
internal use of join/split dims #2055

Open
mengxingbw wants to merge 4 commits into pymc-devs:main from mengxingbw:use-joinsplit

Conversation


@mengxingbw mengxingbw commented Apr 17, 2026

TODO:

  • joindims/splitdims COp
  • numba test_reshape
  • jax test_reshape

This is a follow-up to #1843. After adding the special cases to join/split dims, the next step is to identify places where reshape can be replaced by join/split dims, to reduce rounds of rewriting and improve efficiency.
I will be making slow, small progress, one step at a time!
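For readers landing here, a minimal numpy sketch of the intended semantics, assuming the start_axis/n_axes parametrization used in this PR; join_dims and split_dims below are hypothetical stand-ins for the Ops, not the actual implementations:

```python
import numpy as np

# join_dims merges n_axes consecutive axes starting at start_axis into one
# axis; split_dims is its inverse. Both are special cases of reshape.
def join_dims(x, start_axis, n_axes):
    shape = x.shape[:start_axis] + (-1,) + x.shape[start_axis + n_axes:]
    return x.reshape(shape)

def split_dims(x, axis, sizes):
    shape = x.shape[:axis] + tuple(sizes) + x.shape[axis + 1:]
    return x.reshape(shape)

x = np.zeros((2, 3, 4, 5))
joined = join_dims(x, start_axis=1, n_axes=2)   # shape (2, 12, 5)
restored = split_dims(joined, axis=1, sizes=(3, 4))  # shape (2, 3, 4, 5)
```

Because the output shape is fully determined by the input shape plus two static integers (or a static tuple of sizes), these ops are easier to reason about in rewrites than a general reshape with a symbolic target shape.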

Commit: "…nstead of reshape to reduce rounds of rewriting"
@ricardoV94
Member

We need joindims/splitdims implementations (C/Jax/Numba/etc...) now that we want to push these as the canonical Ops.

We also need to extend any shape-related rewrites that apply to Reshape so that they also cover these ops.

Finally, in cases where we can't avoid reshape, we should merge it: join/split_dims(reshape(x)) -> reshape (a single op).

And we need to convert reshapes that are obviously join/split_dims into those ops.
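A quick numpy illustration of the merging idea: a reshape followed by a join of dims is itself just one reshape to the final shape, so the pair can always be fused. The join_dims helper here is a hypothetical stand-in for the Op:

```python
import numpy as np

# Stand-in for the JoinDims Op: merge n_axes consecutive axes into one.
def join_dims(x, start_axis, n_axes):
    return x.reshape(
        x.shape[:start_axis] + (-1,) + x.shape[start_axis + n_axes:]
    )

x = np.arange(24)
composed = join_dims(x.reshape(2, 3, 4), start_axis=0, n_axes=2)  # two ops
fused = x.reshape(6, 4)                                           # one reshape
# Both paths produce the same (6, 4) array, so the rewrite is safe.
```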

@mbaldourw
Contributor

We need joindims/splitdims implementations (C/Jax/Numba/etc...) now that we want to push these as the canonical Ops.

I drafted something for the joindims implementations; could you check whether this is what you meant and whether the code is in the right places?

In the function in reshape.py, joindims/splitdims will fall back to reshape when the input doesn't match the special cases. In canonicalizing joindims/splitdims, we will only implement the special cases, correct?

Comment thread on pytensor/link/jax/dispatch/shape.py (outdated)

    def jax_funcify_JoinDims(op, node, **kwargs):
        start_axis = op.start_axis
        n_axes = op.n_axes

        def join_dims(x):
            ...
Member


I think you can implement all as reshape like op.perform does. We already have the rewrites for these special cases
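A sketch of that suggestion, assuming the op attributes shown in the snippet above; np.reshape stands in for jnp.reshape (same call signature), and FakeJoinDims is a hypothetical stub, not the real Op:

```python
import numpy as np

# Hypothetical stub carrying the attributes a JoinDims op would expose.
class FakeJoinDims:
    start_axis = 1
    n_axes = 2

# Lower JoinDims to a single reshape, mirroring what op.perform does;
# the special-case rewrites already exist, so the backend can stay simple.
def jax_funcify_JoinDims(op, node=None, **kwargs):
    start_axis, n_axes = op.start_axis, op.n_axes

    def join_dims(x):
        shape = x.shape[:start_axis] + (-1,) + x.shape[start_axis + n_axes:]
        return np.reshape(x, shape)

    return join_dims

f = jax_funcify_JoinDims(FakeJoinDims())
out = f(np.zeros((2, 3, 4, 5)))  # shape (2, 12, 5)
```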

Contributor


I just realized we are talking about the broader joindims/splitdims! This whole time I was thinking about the local join dims and local split dims rewrites in rewriting/reshape, because that's what I was looking at for the previous PR. OK, the task makes more sense now; I might come back with more questions.

@mbaldourw
Contributor

Are joindims and splitdims supposed to stay in tensor/reshape.py? reshape and other shape-related functions are currently in tensor/shape.py, and their jax/numba link files follow the same name. I put the jax/numba functions for joindims and splitdims in the corresponding shape.py files; is this OK, or will it be confusing?

@ricardoV94
Member

Are joindims and splitdims supposed to stay in tensor/reshape.py? reshape and other shape-related functions are currently in tensor/shape.py, and their jax/numba link files follow the same name. I put the jax/numba functions for joindims and splitdims in the corresponding shape.py files; is this OK, or will it be confusing?

Yeah, we want to move more stuff into reshape.py; we started with just the newer ones.

@mbaldourw
Contributor

Yeah, we want to move more stuff into reshape.py; we started with just the newer ones.

so should their jax/numba files be reshape.py too?

@ricardoV94
Member

Yeah, wherever it's currently defined (mirrored structure)

@mbaldourw
Contributor

    @numba_basic.numba_njit
    def reshape(x, shape):
        # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
        return np.reshape(
            np.ascontiguousarray(np.asarray(x)),
            numba_ndarray.to_fixed_tuple(shape, ndim),
        )

@ricardoV94 the issue in the comment above is closed now; is there anything to do there?
When I use np.reshape, do I need to wrap the input inside np.ascontiguousarray(np.asarray())?

@ricardoV94
Member

@mbaldourw I think that issue only addresses the np.asarray part; we still need the contiguity wrapper as long as numba/numba#3353 is open.

We should link to that issue instead. The asarray part doesn't really cost us anything, and dropping it would mean older numba versions would fail.
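A small numpy sketch of why the contiguity wrapper matters, as I understand it: numba's reshape support is limited to contiguous arrays, so non-contiguous inputs (e.g. transposed views) need a contiguous copy first. Pure numpy is shown here to illustrate the layouts involved:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
t = x.T                          # transposed view: not C-contiguous
c = np.ascontiguousarray(t)      # contiguous copy: safe to reshape in numba

# Both have the same values, but only `c` has a C-contiguous layout;
# numpy's reshape handles either, while numba's would reject `t`.
flat_from_view = t.reshape(6)
flat_from_copy = c.reshape(6)
```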

