Add MultiDot op and rewrites for optimal contraction #2060
jessegrabowski wants to merge 5 commits into pymc-devs:main
Conversation
 rval = args[0]
 for a in args[1:]:
-    rval = ptm.dot(rval, a)
+    rval = ptm.matmul(rval, a)
Remove this helper altogether; it's not a standard name in numpy/scipy, right?
My gut doesn't love this approach.
Adding multi_dot to the IR is going to make us miss / complicate regular dot graphs.
Flattening by default may also break the original associativity, which may have been optimal in the absence of statically known shape information.
My suggestion: after specialize, have a single GraphRewriter that collects nested matmuls and "re-associates" them if it can prove the new order is strictly better than the old one. It doesn't need an OpFromGraph imo.
Something like (bot generated):
class ReassociateMatmulChain(GraphRewriter):
    """Post-specialize: find matmul chains and reassociate if provably cheaper."""

    def apply(self, fgraph):
        visited = set()
        for node in fgraph.toposort():
            if node in visited or not _is_matmul_node(node):
                continue

            # 1. Extend chain through single-client intermediates only.
            #    This should ignore expand_dims / squeeze (maybe even transposes somehow?)
            inputs, chain_nodes = self._extend_chain(node, fgraph, visited)
            visited.update(chain_nodes)
            if len(inputs) < 3:
                continue

            # 2. Symbolic shapes for every input. Each is a tuple of dim
            #    expressions (batch dims..., m_i, k_i) built from static shape
            #    where available and shape_of(var) otherwise.
            shapes = [_symbolic_shape(x, fgraph) for x in inputs]

            # 2b. Canonicalize all dim entries via a single shape-unification pass.
            shapes = _unify_shapes(shapes, fgraph.shape_feature)

            # 3. DP over parenthesizations.
            #    dp[i, j] = (cost_expr, split_k, result_shape) for chain[i..j]
            n = len(inputs)
            dp = {(i, i): (_zero(), None, shapes[i]) for i in range(n)}
            for length in range(2, n + 1):
                for i in range(n - length + 1):
                    j = i + length - 1
                    best = None
                    for k in range(i, j):
                        lc, _, ls = dp[i, k]
                        rc, _, rs = dp[k + 1, j]
                        step = _contract_cost(ls, rs)
                        total = lc + rc + step
                        result = _matmul_result_shape(ls, rs)
                        if best is None or _provably_less(total, best[0]):
                            best = (total, k, result)
                    dp[i, j] = best

            new_cost, *_ = dp[0, n - 1]
            old_cost = _current_order_cost(chain_nodes, shapes)

            # 4. Only replace when provably strictly cheaper.
            if not _provably_less(new_cost, old_cost):
                continue

            # _build_tree should return all nodes so we can add them to `visited`
            new_out = _build_tree(inputs, dp, 0, n - 1)  # plain matmul nodes
            copy_stack_trace(chain_nodes[-1].outputs[0], new_out)
            fgraph.replace(chain_nodes[-1].outputs[0], new_out,
                           reason="reassoc_matmul")

Helpers — batch-aware shape & cost
def _matmul_result_shape(left, right):
    """left = (*bl, m, k), right = (*br, k, n) -> (*broadcast(bl, br), m, n).

    Align batch dims from the right; missing dims on the shorter side are
    treated as literal 1. After _unify_shapes, each aligned pair is either
    (1, x), (x, 1), or (x, x) — pick the non-literal-1 side.
    """
    batch = []
    for da, db in zip_longest_right(left[:-2], right[:-2], fill=ONE):
        if _is_literal_one(da):
            batch.append(db)
        elif _is_literal_one(db):
            batch.append(da)
        else:
            assert _same_symbol(da, db)  # unification guarantees this
            batch.append(da)
    return (*batch, left[-2], right[-1])

def _contract_cost(left, right):
    """FLOPs of (left @ right). Batch broadcast enters as a multiplier."""
    result = _matmul_result_shape(left, right)
    m, k, n = left[-2], left[-1], right[-1]
    return _prod(result[:-2]) * m * k * n
def _unify_shapes(shapes, shape_feature):
    """Canonicalize dim entries for a matmul chain using all known equalities.

    Three sources of equality feed in:
    1. Contracting dims (matmul semantics): shapes[i][-1] == shapes[i+1][-2]
       for every adjacent pair. Applies to *every* chain, adjacent only.
    2. Batch dims required equal at runtime: for any pair (i, j), align
       their batch dims from the right. Dims that are both non-literal-1
       MUST be equal (broadcasting rule) — unify them for costing.
       Applies to non-adjacent pairs too, transitively.
    3. ShapeFeature same_shape classes: if the fgraph's ShapeFeature
       already knows two shape entries are equal (from earlier rewrites
       or op-level declarations), use it directly — no need to re-derive.
       This is why the helper takes `shape_feature` rather than just
       looking at the raw shape graphs.

    Strategy: union-find over all dim entries in the chain. Add edges
    from (1), (2), (3). Pick a representative per class, preferring
    literal ints > static-shape ints > shape_of symbols. Rewrite every
    shape tuple with representatives.

    TODO: ideally ShapeFeature itself carries the edges from (1) and (2)
    (matmul's Op declares "my input ks are equal"; blockwise declares
    "my batch dims broadcast"), so `same_shape` works everywhere and this
    helper collapses to "read canonical reps from ShapeFeature." For now
    we do it locally, but the long-term home is ShapeFeature.
    """
Proving a < b symbolically

def _provably_less(a, b):
    """Expand a, b into sums of monomials in positive dim symbols.

    Use the invariant: every dim symbol >= 1 (matmul dims are positive).
    Return True iff we can match each monomial of `a` to a DISTINCT
    monomial of `b` s.t. the a-term is dominated term-wise by its b-term,
    and b has at least one unmatched monomial (strict). Otherwise False.
    False means 'not provable'; it does NOT claim b <= a.
    """

Term-wise dominance: monomial c·x1^a1·x2^a2… is dominated by
d·x1^b1·x2^b2… when c ≤ d and every ai ≤ bi, given all xi ≥ 1.
Cheap, and catches the common wins (a factor of m·k·p dominates k·p for
any positive m). Won't decide genuinely shape-dependent ties — that's
fine; bail and keep the original order.
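To make the dominance test concrete, here is a standalone sketch using sympy purely for illustration (the actual rewrite would work on PyTensor shape expressions, and the greedy matching below is a simplification of the distinct-matching requirement):

import sympy as sp

def provably_less(a, b):
    """True only if a < b is provable given that every symbol is >= 1."""

    def factors(term):
        coeff, rest = term.as_coeff_Mul()
        powers = {} if rest == 1 else dict(rest.as_powers_dict())
        return coeff, powers

    def dominated(ta, tb):
        ca, pa = factors(ta)
        cb, pb = factors(tb)
        return ca <= cb and all(e <= pb.get(s, 0) for s, e in pa.items())

    a_terms = sp.expand(a).as_ordered_terms()
    b_terms = list(sp.expand(b).as_ordered_terms())
    for ta in a_terms:
        # greedy matching: each a-monomial needs its own dominating b-monomial
        match = next((tb for tb in b_terms if dominated(ta, tb)), None)
        if match is None:
            return False  # not provable (says nothing about b <= a)
        b_terms.remove(match)
    return len(b_terms) > 0  # strict: b must have leftover cost

m, k, n, p = sp.symbols("m k n p", positive=True)
assert provably_less(k*n*p + m*k*p, m*k*n + k*n*p + m*k*p)  # b has an extra monomial
assert not provably_less(m*k*n, k*n*p)  # not decidable: depends on m vs p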
I've wanted this for a while. Adds a MultiDot Op that we can track with rewrites. We look for sequences of matrix multiplications in the graph and fuse them into a MultiDot during canonicalization. For example: A @ B @ C -> MultiDot(A, B, C).

By default, MultiDot is just an OpFromGraph that does simple left-to-right matrix multiplication, so MultiDot(A, B, C) -> A @ B @ C during inlining. If all shapes of A, B, and C are statically known, however, we solve the dynamic programming problem to figure out the optimal ordering of matmuls. For details see the wiki here: https://en.wikipedia.org/wiki/Matrix_chain_multiplication

We could probably try to do something more heroic, but I think this is a good start.
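For anyone who wants the flavor of that DP without reading the wiki, a minimal NumPy-only sketch (the helper names `optimal_order` and `multiply` are illustrative, not the ones used in the PR):

import numpy as np

def optimal_order(shapes):
    """Matrix-chain DP over static 2D shapes: returns (min_flops, split table)."""
    n = len(shapes)
    dims = [shapes[0][0]] + [s[1] for s in shapes]  # m0, m1, ..., mn
    cost = [[0] * n for _ in range(n)]
    split = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            cost[i][j] = float("inf")
            for k in range(i, j):
                c = cost[i][k] + cost[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                if c < cost[i][j]:
                    cost[i][j], split[i][j] = c, k
    return cost[0][n - 1], split

def multiply(mats, split, i, j):
    """Recursively apply the optimal parenthesization."""
    if i == j:
        return mats[i]
    k = split[i][j]
    return multiply(mats, split, i, k) @ multiply(mats, split, k + 1, j)

# Shapes (1000, 10), (10, 1000), (1000, 10): left-to-right (A @ B) @ C costs
# ~20M flops, while the optimal A @ (B @ C) costs only ~200k.
A, B, C = (np.ones(s) for s in [(1000, 10), (10, 1000), (1000, 10)])
flops, split = optimal_order([m.shape for m in (A, B, C)])
result = multiply([A, B, C], split, 0, 2)  # flops == 200_000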