internal use of join/split dims #2055
Conversation
…nstead of reshape to reduce rounds of rewriting
|
We need joindims/splitdims implementations (C/Jax/Numba/etc...) now that we want to push this as the canonical Op. We also need to extend any shape related rewrites that apply to these ops to them. Finally in cases we can't avoid reshape, we should merge it. join/split_dims(reshape(x)) -> reshape (single op). And we need to convert reshapes that are obviously join/split_dims into those ops. |
b9aeaf0 to
cc652c3
Compare
I drafted something for joindims implementations, will you check if this is what you mean/correct places? In the function in reshape.py, joindims/splitdims will fall back to reshape if it's not the special cases. In canonicalizing joindims/splitdims, we will only implement the special cases - correct? |
| def jax_funcify_JoinDims(op, node, **kwargs): | ||
| start_axis = op.start_axis | ||
| n_axes = op.n_axes | ||
| def join_dims(x): |
There was a problem hiding this comment.
I think you can implement all as reshape like op.perform does. We already have the rewrites for these special cases
There was a problem hiding this comment.
i just realized we are talking about the broader joindims/splitdims! this whole time i was thinking about the local join dims and local split dims in rewriting/reshape cuz that's what i was looking at for the previous pr. ok the task makes more sense, might come back with more questions
|
are joindims and splitdims supposed to stay in tensor/reshape.py? reshape and other shape related functions are currently in tensor/shape.py, and their jax/numba link files follow the same name. i put the jax/numba function for joindims and splitdims in their corresponding shape.py - is this ok/will it be confusing? |
Yeah, we want to move more stuff into reshape.py, we started just with the newer ones |
so should their jax/numba files be reshape.py too? |
|
Yeah, wherever it's currently defined (mirrored structure) |
@ricardoV94 the issue in the comment above is closed now, anything to do there? |
|
@mbaldourw I think that issue only addresses the np.asarray part, we still need the contiguous thing as long as numba/numba#3353 is open. We should link to that. The asarray part doesn't really cost us anything, and dropping would mean older numba versions would fail |
TODO:
This is a follow-up of #1843. After adding the special cases to join/split dims, the next step is to identify places that reshape can be replaced by join/split dims to reduce rounds of rewriting and enhance efficiency.
Will be making slow and small progress one at a time!