
Add MultiDot op and rewrites for optimal contraction#2060

Open
jessegrabowski wants to merge 5 commits into pymc-devs:main from jessegrabowski:multi-dot-via-contraction

Conversation

@jessegrabowski
Member

I've wanted this for a while. Adds a MultiDot Op that we can track with rewrites. We look for sequences of matrix multiplications in the graph and fuse them into a MultiDot during canonicalization. For example: A @ B @ C -> MultiDot(A, B, C).

By default, MultiDot is just an OpFromGraph that does simple left-to-right matrix multiplication, so MultiDot(A, B, C) -> A @ B @ C during inlining. If all shapes of A, B, C are statically known, however, we solve the dynamic programming problem to figure out the optimal ordering of matmuls. For details, see the Wikipedia article: https://en.wikipedia.org/wiki/Matrix_chain_multiplication
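For reference, here is a toy standalone sketch of the matrix-chain DP from the linked article (not the code in this PR); NumPy exposes the same optimization as np.linalg.multi_dot:

# Classic matrix-chain-order DP, assuming all shapes are static.
# `dims` has length n + 1: matrix i has shape (dims[i], dims[i + 1]).
import numpy as np

def matrix_chain_order(dims):
    n = len(dims) - 1
    cost = np.zeros((n, n))
    split = np.zeros((n, n), dtype=int)
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            cost[i, j] = np.inf
            for k in range(i, j):
                q = cost[i, k] + cost[k + 1, j] + dims[i] * dims[k + 1] * dims[j + 1]
                if q < cost[i, j]:
                    cost[i, j], split[i, j] = q, k
    return cost[0, n - 1], split

# A: 10x100, B: 100x5, C: 5x50
# (A @ B) @ C costs 7,500 scalar multiplications; A @ (B @ C) costs 75,000.
print(matrix_chain_order([10, 100, 5, 50])[0])  # 7500.0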

We could probably try to do something more heroic, but I think this is a good start.

@jessegrabowski jessegrabowski added the enhancement (New feature or request), NumPy compatibility, and linalg (Linear algebra) labels on Apr 19, 2026
rval = args[0]
for a in args[1:]:
-    rval = ptm.dot(rval, a)
+    rval = ptm.matmul(rval, a)
Member

@ricardoV94 ricardoV94 Apr 19, 2026

remove this helper altogether, it's not a standard name in numpy/scipy, right?

Member

@ricardoV94 ricardoV94 left a comment


My gut doesn't love this approach.

Adding multi_dot to the IR is going to make us miss / complicate regular dot graphs.

The flattening by default may break the original associativity, which may have been optimal in the absence of statically known information.

My suggestion: After specialize, have a single GraphRewrite that collects nested matmuls and "re-associates" them if it can prove the new order is strictly better than the old one. It doesn't need an OpFromGraph imo.

Something like (bot generated):

from pytensor.graph.rewriting.basic import GraphRewriter, copy_stack_trace

# NOTE: the underscore-prefixed helpers below are placeholders, not existing functions.
class ReassociateMatmulChain(GraphRewriter):
    """Post-specialize: find matmul chains and reassociate if provably cheaper."""

    def apply(self, fgraph):
        visited = set()
        for node in fgraph.toposort():
            if node in visited or not _is_matmul_node(node):
                continue

            # 1. Extend chain through single-client intermediates only.
            # This should ignore expand_dims / squeeze (maybe even transposes somehow?)
            inputs, chain_nodes = self._extend_chain(node, fgraph, visited)
            visited.update(chain_nodes)
            if len(inputs) < 3:
                continue

            # 2. Symbolic shapes for every input. Each is a tuple of dim
            #    expressions (batch dims..., m_i, k_i) built from static shape
            #    where available and shape_of(var) otherwise.
            shapes = [_symbolic_shape(x, fgraph) for x in inputs]

            # 2b. Canonicalize all dim entries via a single shape-unification pass.
            shapes = _unify_shapes(shapes, fgraph.shape_feature)

            # 3. DP over parenthesizations.
            #    dp[i, j] = (cost_expr, split_k, result_shape) for chain[i..j]
            n = len(inputs)
            dp = {(i, i): (_zero(), None, shapes[i]) for i in range(n)}
            for length in range(2, n + 1):
                for i in range(n - length + 1):
                    j = i + length - 1
                    best = None
                    for k in range(i, j):
                        lc, _, ls = dp[i, k]
                        rc, _, rs = dp[k + 1, j]
                        step = _contract_cost(ls, rs)
                        total = lc + rc + step
                        result = _matmul_result_shape(ls, rs)
                        if best is None or _provably_less(total, best[0]):
                            best = (total, k, result)
                    dp[i, j] = best

            new_cost, *_ = dp[0, n - 1]
            old_cost = _current_order_cost(chain_nodes, shapes)

            # 4. Only replace when provably strictly cheaper.
            if not _provably_less(new_cost, old_cost):
                continue

            # _build_tree should return all nodes so we can add them to `seen`
            new_out = _build_tree(inputs, dp, 0, n - 1)  # plain matmul nodes
            copy_stack_trace(chain_nodes[-1].outputs[0], new_out)
            fgraph.replace(chain_nodes[-1].outputs[0], new_out,
                           reason="reassoc_matmul")

Helpers — batch-aware shape & cost

def _matmul_result_shape(left, right):
    """left = (*bl, m, k), right = (*br, k, n) -> (*broadcast(bl, br), m, n).

    Align batch dims from the right; missing dims on the shorter side are
    treated as literal 1. After _unify_shapes, each aligned pair is either
    (1, x), (x, 1), or (x, x) — pick the non-literal-1 side.
    """
    batch = []
    for da, db in zip_longest_right(left[:-2], right[:-2], fill=ONE):
        if _is_literal_one(da):
            batch.append(db)
        elif _is_literal_one(db):
            batch.append(da)
        else:
            assert _same_symbol(da, db)   # unification guarantees this
            batch.append(da)
    return (*batch, left[-2], right[-1])

def _contract_cost(left, right):
    """FLOPs of (left @ right). Batch broadcast enters as a multiplier."""
    result = _matmul_result_shape(left, right)
    m, k, n = left[-2], left[-1], right[-1]
    return _prod(result[:-2]) * m * k * n
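For instance, with left = (B, m, k) and right = (k, n), _matmul_result_shape gives (B, m, n) and _contract_cost returns B * m * k * n: the broadcast batch size just multiplies the usual m * k * n scalar-multiplication count.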

def _unify_shapes(shapes, shape_feature):
    """Canonicalize dim entries for a matmul chain using all known equalities.

    Three sources of equality feed in:

    1. Contracting dims (matmul semantics): shapes[i][-1] == shapes[i+1][-2]
       for every adjacent pair. Applies to *every* chain, adjacent only.

    2. Batch dims required equal at runtime: for any pair (i, j), align
       their batch dims from the right. Dims that are both non-literal-1
       MUST be equal (broadcasting rule) — unify them for costing.
       Applies to non-adjacent pairs too, transitively.

    3. ShapeFeature same_shape classes: if the fgraph's ShapeFeature
       already knows two shape entries are equal (from earlier rewrites
       or op-level declarations), use it directly — no need to re-derive.
       This is why the helper takes `shape_feature` rather than just
       looking at the raw shape graphs.

    Strategy: union-find over all dim entries in the chain. Add edges
    from (1), (2), (3). Pick a representative per class preferring
    literal ints > static-shape ints > shape_of symbols. Rewrite every
    shape tuple with representatives.

    TODO: ideally ShapeFeature itself carries the edges from (1) and (2)
    (matmul's Op declares "my input ks are equal"; blockwise declares
    "my batch dims broadcast"), so `same_shape` works everywhere and this
    helper collapses to "read canonical reps from ShapeFeature." For now
    we do it locally, but the long-term home is ShapeFeature.
    """

Proving a < b symbolically

def _provably_less(a, b):
    """Expand a, b into sum-of-monomials in positive dim symbols.
       Use the invariant: every dim symbol >= 1 (matmul dims are positive).
       Return True iff we can match each monomial of `a` to a DISTINCT
       monomial of `b` s.t. the a-term is dominated term-wise by its b-term,
       and b has at least one unmatched monomial (strict). Otherwise False.
       False means 'not provable'; it does NOT claim b <= a."""

Term-wise dominance: monomial c·x1^a1·x2^a2… is dominated by
d·x1^b1·x2^b2… when c ≤ d and every ai ≤ bi, given all xi ≥ 1.
Cheap, and catches the common wins (a factor of m·k·p dominates k·p for
any positive m). Won't decide genuinely shape-dependent ties — that's
fine; bail and keep the original order.
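A rough standalone sketch of that dominance test, using sympy for the monomial expansion (the function name and the greedy matching are illustrative assumptions, not the PR's code; a failed match simply returns False, i.e. "not provable"):

import sympy as sp

def provably_less(a, b, syms):
    """True iff each monomial of `a` is dominated by a distinct monomial of `b`,
    with at least one b-monomial left unmatched (strictness)."""
    a_terms = sp.Poly(sp.expand(a), *syms).terms()  # [(exponent_tuple, coeff), ...]
    b_terms = sp.Poly(sp.expand(b), *syms).terms()

    def dominated(ta, tb):
        (ea, ca), (eb, cb) = ta, tb
        return ca <= cb and all(x <= y for x, y in zip(ea, eb))

    used = set()
    for ta in a_terms:
        # Greedy: take the lowest-degree unused b-monomial that dominates ta.
        # Failing here only means "not provable", so the shortcut stays conservative.
        candidates = [i for i, tb in enumerate(b_terms) if i not in used and dominated(ta, tb)]
        if not candidates:
            return False
        used.add(min(candidates, key=lambda i: sum(b_terms[i][0])))
    return len(used) < len(b_terms)

m, k, n, p = sp.symbols("m k n p", positive=True)
print(provably_less(m*k*n, m*k*n + m*n*p, (m, k, n, p)))  # True
print(provably_less(m*n*p, k*n*p, (m, k, n, p)))          # False: not provable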

@jessegrabowski jessegrabowski force-pushed the multi-dot-via-contraction branch from 513b760 to edcb675 on May 1, 2026, 04:19