Add MultiDot op and rewrites for optimal contraction #2060

Open
jessegrabowski wants to merge 7 commits into pymc-devs:main from jessegrabowski:multi-dot-via-contraction

@jessegrabowski
Member

I've wanted this for a while. Adds a MultiDot Op that we can track with rewrites. We look for sequences of matrix multiplications in the graph and fuse them into a MultiDot during canonicalization. For example: A @ B @ C -> MultiDot(A, B, C).

By default, MultiDot is just an OpFromGraph that does simple left-to-right matrix multiplication, so MultiDot(A, B, C) -> A @ B @ C during inlining. If all shapes of A, B, C are statically known, however, we solve the dynamic programming problem to figure out the optimal ordering of matmuls. For details see the Wikipedia article: https://en.wikipedia.org/wiki/Matrix_chain_multiplication
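To make the dynamic programming step concrete, here is a minimal pure-Python sketch of the textbook matrix-chain ordering problem on fully known 2D shapes. It is not the code in this PR: the function names (`matrix_chain_order`, `parenthesize`, `left_to_right_cost`) are made up for illustration, and the cost model (m·k·n scalar multiplications per matmul) is an assumption.

```python
def matrix_chain_order(dims):
    """Classic matrix-chain DP.

    Matrix i has shape (dims[i], dims[i + 1]). Returns (min_cost, split),
    where split[i][j] is the index k at which the sub-chain i..j is best
    split into (i..k) @ (k+1..j).
    """
    n = len(dims) - 1
    cost = [[0] * n for _ in range(n)]
    split = [[None] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            best = None
            for k in range(i, j):
                # Cost of both halves plus the final (dims[i] x dims[k+1]) @ (dims[k+1] x dims[j+1]).
                c = cost[i][k] + cost[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                if best is None or c < best:
                    best = c
                    split[i][j] = k
            cost[i][j] = best
    return cost[0][n - 1], split


def parenthesize(split, names, i, j):
    """Rebuild the optimal association as a string from the split table."""
    if i == j:
        return names[i]
    k = split[i][j]
    return f"({parenthesize(split, names, i, k)} @ {parenthesize(split, names, k + 1, j)})"


def left_to_right_cost(dims):
    """Cost of the default ((A @ B) @ C) @ ... evaluation order."""
    rows = dims[0]
    return sum(rows * dims[t] * dims[t + 1] for t in range(1, len(dims) - 1))


# A (100x100) @ B (100x100) @ C (100x1): multiplying into the vector first is far cheaper.
dims = [100, 100, 100, 1]
best_cost, split = matrix_chain_order(dims)
print(left_to_right_cost(dims), best_cost, parenthesize(split, ["A", "B", "C"], 0, 2))
```

For this example the left-to-right order costs 1,010,000 multiplications, while the optimal order `(A @ (B @ C))` costs 20,000.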

We could probably try to do something more heroic, but I think this is a good start.

@jessegrabowski added the enhancement, NumPy compatibility, and linalg labels Apr 19, 2026
Comment thread pytensor/tensor/linalg/products.py
Member

@ricardoV94 ricardoV94 left a comment

My gut doesn't love this approach.

Adding multi_dot to the IR is going to make us miss / complicate regular dot graphs.

The flattening by default may break the original associativity, which may have been optimal in the absence of statically known information.

My suggestion: after specialize, have a single GraphRewriter that collects nested matmuls and "re-associates" them if it can prove the new order is strictly better than the old one. It doesn't need an OpFromGraph imo.

Something like (bot generated):

class ReassociateMatmulChain(GraphRewriter):
    """Post-specialize: find matmul chains and reassociate if provably cheaper."""

    def apply(self, fgraph):
        visited = set()
        for node in fgraph.toposort():
            if node in visited or not _is_matmul_node(node):
                continue

            # 1. Extend chain through single-client intermediates only.
            # This should ignore expand_dims / squeeze (maybe even transposes somehow?)
            inputs, chain_nodes = self._extend_chain(node, fgraph, visited)
            visited.update(chain_nodes)
            if len(inputs) < 3:
                continue

            # 2. Symbolic shapes for every input. Each is a tuple of dim
            #    expressions (batch dims..., m_i, k_i) built from static shape
            #    where available and shape_of(var) otherwise.
            shapes = [_symbolic_shape(x, fgraph) for x in inputs]

            # 2b. Canonicalize all dim entries via a single shape-unification pass.
            shapes = _unify_shapes(shapes, fgraph.shape_feature)

            # 3. DP over parenthesizations.
            #    dp[i, j] = (cost_expr, split_k, result_shape) for chain[i..j]
            n = len(inputs)
            dp = {(i, i): (_zero(), None, shapes[i]) for i in range(n)}
            for length in range(2, n + 1):
                for i in range(n - length + 1):
                    j = i + length - 1
                    best = None
                    for k in range(i, j):
                        lc, _, ls = dp[i, k]
                        rc, _, rs = dp[k + 1, j]
                        step = _contract_cost(ls, rs)
                        total = lc + rc + step
                        result = _matmul_result_shape(ls, rs)
                        if best is None or _provably_less(total, best[0]):
                            best = (total, k, result)
                    dp[i, j] = best

            new_cost, *_ = dp[0, n - 1]
            old_cost = _current_order_cost(chain_nodes, shapes)

            # 4. Only replace when provably strictly cheaper.
            if not _provably_less(new_cost, old_cost):
                continue

            # _build_tree should return all nodes so we can add them to `visited`
            new_out = _build_tree(inputs, dp, 0, n - 1)  # plain matmul nodes
            copy_stack_trace(chain_nodes[-1].outputs[0], new_out)
            fgraph.replace(chain_nodes[-1].outputs[0], new_out,
                           reason="reassoc_matmul")

Helpers — batch-aware shape & cost

def _matmul_result_shape(left, right):
    """left = (*bl, m, k), right = (*br, k, n) -> (*broadcast(bl, br), m, n).

    Align batch dims from the right; missing dims on the shorter side are
    treated as literal 1. After _unify_shapes, each aligned pair is either
    (1, x), (x, 1), or (x, x) — pick the non-literal-1 side.
    """
    batch = []
    for da, db in zip_longest_right(left[:-2], right[:-2], fill=ONE):
        if _is_literal_one(da):
            batch.append(db)
        elif _is_literal_one(db):
            batch.append(da)
        else:
            assert _same_symbol(da, db)   # unification guarantees this
            batch.append(da)
    return (*batch, left[-2], right[-1])

def _contract_cost(left, right):
    """FLOPs of (left @ right). Batch broadcast enters as a multiplier."""
    result = _matmul_result_shape(left, right)
    m, k, n = left[-2], left[-1], right[-1]
    return _prod(result[:-2]) * m * k * n
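As a rough illustration of what the two helpers above compute, here is a sketch restricted to fully static integer shapes. The names `matmul_result_shape` and `contract_cost` mirror the pseudocode but are stand-ins, not the PR's implementation; the symbolic case (where dims are expressions rather than ints) is not covered.

```python
from itertools import zip_longest
from math import prod


def matmul_result_shape(left, right):
    """(*bl, m, k) @ (*br, k, n) -> (*broadcast(bl, br), m, n) for int shapes."""
    if left[-1] != right[-2]:
        raise ValueError(f"contracting dims differ: {left[-1]} vs {right[-2]}")
    # Align batch dims from the right: reverse, pad the shorter side with 1, reverse back.
    pairs = list(zip_longest(reversed(left[:-2]), reversed(right[:-2]), fillvalue=1))
    batch = []
    for da, db in reversed(pairs):
        if da == 1:
            batch.append(db)
        elif db == 1 or da == db:
            batch.append(da)
        else:
            raise ValueError(f"batch dims {da} and {db} do not broadcast")
    return (*batch, left[-2], right[-1])


def contract_cost(left, right):
    """Multiply-adds of left @ right; the broadcast batch size is a multiplier."""
    result = matmul_result_shape(left, right)
    m, k, n = left[-2], left[-1], right[-1]
    return prod(result[:-2]) * m * k * n


# Batch dims (8, 1) and (5,) broadcast to (8, 5); each batch entry is a (3x4) @ (4x2).
print(matmul_result_shape((8, 1, 3, 4), (5, 4, 2)))  # (8, 5, 3, 2)
print(contract_cost((8, 1, 3, 4), (5, 4, 2)))        # 960
```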

def _unify_shapes(shapes, shape_feature):
    """Canonicalize dim entries for a matmul chain using all known equalities.

    Three sources of equality feed in:

    1. Contracting dims (matmul semantics): shapes[i][-1] == shapes[i+1][-2]
       for every adjacent pair. Applies to *every* chain, adjacent only.

    2. Batch dims required equal at runtime: for any pair (i, j), align
       their batch dims from the right. Dims that are both non-literal-1
       MUST be equal (broadcasting rule) — unify them for costing.
       Applies to non-adjacent pairs too, transitively.

    3. ShapeFeature same_shape classes: if the fgraph's ShapeFeature
       already knows two shape entries are equal (from earlier rewrites
       or op-level declarations), use it directly — no need to re-derive.
       This is why the helper takes `shape_feature` rather than just
       looking at the raw shape graphs.

    Strategy: union-find over all dim entries in the chain. Add edges
    from (1), (2), (3). Pick a representative per class preferring
    literal ints > static-shape ints > shape_of symbols. Rewrite every
    shape tuple with representatives.

    TODO: ideally ShapeFeature itself carries the edges from (1) and (2)
    (matmul's Op declares "my input ks are equal"; blockwise declares
    "my batch dims broadcast"), so `same_shape` works everywhere and this
    helper collapses to "read canonical reps from ShapeFeature." For now
    we do it locally, but the long-term home is ShapeFeature.
    """

Proving a < b symbolically

def _provably_less(a, b):
    """Expand a, b into sum-of-monomials in positive dim symbols.
       Use the invariant: every dim symbol >= 1 (matmul dims are positive).
       Return True iff we can match each monomial of `a` to a DISTINCT
       monomial of `b` s.t. the a-term is dominated term-wise by its b-term,
       and b has at least one unmatched monomial (strict). Otherwise False.
       False means 'not provable'; it does NOT claim b <= a."""

Term-wise dominance: monomial c·x1^a1·x2^a2… is dominated by
d·x1^b1·x2^b2… when c ≤ d and every ai ≤ bi, given all xi ≥ 1.
Cheap, and catches the common wins (a factor of m·k·p dominates k·p for
any positive m). Won't decide genuinely shape-dependent ties — that's
fine; bail and keep the original order.
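One way the dominance check could be sketched, assuming each cost is already expanded into a list of monomials with positive coefficients and like terms combined. A monomial is represented here as `(coeff, {symbol: exponent})`; this representation and the names `dominated` and `provably_less` are illustrative choices, not the PR's code. Matching uses plain backtracking, which should be adequate for the handful of monomials a short chain produces.

```python
def dominated(term_a, term_b):
    """True if c * prod(x^a) <= d * prod(x^b) whenever every symbol is >= 1."""
    (c, exps_a), (d, exps_b) = term_a, term_b
    return c <= d and all(p <= exps_b.get(s, 0) for s, p in exps_a.items())


def provably_less(a, b):
    """True only if a < b can be shown by matching each monomial of `a` to a
    distinct dominating monomial of `b`, leaving at least one monomial of `b`
    unmatched. False means 'not provable', not 'b <= a'."""
    if len(a) >= len(b):
        return False  # no unmatched monomial would remain in b

    def match(i, used):
        if i == len(a):
            return True
        return any(
            j not in used and dominated(a[i], b[j]) and match(i + 1, used | {j})
            for j in range(len(b))
        )

    return match(0, frozenset())


xy = (1, {"x": 1, "y": 1})
xyz = (1, {"x": 1, "y": 1, "z": 1})
x = (1, {"x": 1})

print(provably_less([xy], [xyz, x]))  # True: x*y <= x*y*z, and the leftover x is >= 1
print(provably_less([xy], [xyz]))     # False: equal when z == 1, so not strict
```

The second call shows why the unmatched-monomial requirement matters: a single dominated term only proves `a <= b`, and the two sides tie whenever the extra factor equals 1.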

@jessegrabowski force-pushed the multi-dot-via-contraction branch 2 times, most recently from edcb675 to 7f53de5 May 2, 2026 06:09
@jessegrabowski
Member Author

@ricardoV94 took another cut at this, lmk if it looks better to you

@ricardoV94
Member

> @ricardoV94 took another cut at this, lmk if it looks better to you

does it look better to you?

@jessegrabowski
Member Author

it's definitely thinking about the problem in a much different way. You can see from my first pass how I think about pytensor in a very Op-oriented way. Here we have a pure graph reasoning implementation. I would never have come up with the monomial stuff by myself, so it's hard for me to assess if it's doing the right thing from a design level. It's also being very cute by trying to work 100% with shapes instead of passing around variables, so you end up with stuff like the broadcast checkers. I think you can argue those are bloat, but only if I refactor it again to work with variables (so we get access to .broadcastable and whatnot)

@ricardoV94
Member

ricardoV94 commented May 2, 2026

> it's definitely thinking about the problem in a much different way. You can see from my first pass how I think about pytensor in a very Op-oriented way. Here we have a pure graph reasoning implementation.

Could be lack of familiarity with graph-level rewrites? Maybe it would be a similar weirdness if you were implementing the fusion rewriter from scratch. You start with an eager local thing that just tries to expand one node at a time -> Composite A -> Composite A + 1. And then I come and suggest you analyze the graph all at once, break it into all the Composites you can see, and try to use some advanced data structure to verify convexity cheaply.

@ricardoV94
Member

Or it's just not the right answer for this problem ...

@jessegrabowski
Member Author

I think it's nicer, and I think you're right that I've mostly only been working with node rewriters. I don't love that it's a +1000 line PR for a feature that won't be relevant much of the time. On the plus side though, it's contained to a single file, so it's easy to iterate on if it ends up sucking. I want to run it through the ASV benchmarks to make sure it doesn't drag on graphs overall.

@ricardoV94
Member

Curious if it fires anywhere in the CI or pymc model catalogue

@ricardoV94
Member

ricardoV94 commented May 2, 2026

If it's rare but the bail out is fast, that's also ok

@jessegrabowski requested a review from ricardoV94 May 3, 2026 23:10
Comment on lines +178 to +182
if not (0 <= x < len(shape)):
    raise _BailOutError(
        f"DimShuffle.new_order references index {x} outside operand shape "
        f"of length {len(shape)}; lift cannot legally apply."
    )
Member

?

Member Author

yeah idk, i got a bit lazy letting all the bot slop through

Comment thread pytensor/tensor/rewriting/linalg/reassociate_matmul.py Outdated
return False, False


def _is_chain_link(node: Apply) -> bool:
Member

inline function (fine to still be a function) in `_find_chain_top`

Member Author

it's used in _decompose_operand, _find_chain_top, and ReassociateMatmulChain. I think this one is defensible as a helper.

Comment thread pytensor/tensor/rewriting/linalg/reassociate_matmul.py Outdated
Member

@ricardoV94 ricardoV94 left a comment

pretty nice, some nits. I'll check if this fires anywhere in the pymc catalogue out of curiosity

@jessegrabowski
Member Author

> pretty nice, some nits. I'll check if this fires anywhere in the pymc catalogue out of curiosity

Curious what you find. If we have a SEM example it might? It's pretty unusual for a GLM to do fat chains of matmuls

@ricardoV94
Member

ricardoV94 commented May 4, 2026

Gets called in 5 models (two of which are just for demoing); it doesn't apply in any of them:

| Model | Chain | Outcome |
| --- | --- | --- |
| car_scotland_lipcancer.py | (56×56) @ (56×56) @ (56×56) | Unchanged: already optimal |
| factor_analysis_ppca.py | (10×2) @ (2×?) @ (?×250), (?×?) @ (?×10) @ (10×250) | Unknown: dim missing, fell back to left-to-right |
| gp_kron_latent.py | (50×50) @ (50×30) @ (30×30), (30×50) @ (50×50) @ (50×30) | Unchanged: already optimal |
| smc2_gaussians_n4.py | (1×4) @ (4×4) @ (4×1) ×2 | Unchanged: already optimal |
| smc2_gaussians_n80.py | (1×80) @ (80×80) @ (80×1) ×2 | Unchanged: already optimal |
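As a sanity check on the "already optimal" rows, the sketch below compares the two possible associations of each fully known three-matrix chain from the table, assuming the usual m·k·n multiply-add cost per matmul. `chain3_costs` is a hypothetical helper, not part of the PR.

```python
def chain3_costs(a, b, c):
    """Return (left_cost, right_cost) for (A @ B) @ C versus A @ (B @ C)."""
    (m, k), (_, n), (_, p) = a, b, c
    left = m * k * n + m * n * p   # (A @ B) is m x n, then times C
    right = k * n * p + m * k * p  # (B @ C) is k x p, then A times it
    return left, right


chains = {
    "car_scotland_lipcancer": [(56, 56), (56, 56), (56, 56)],
    "gp_kron_latent (1)": [(50, 50), (50, 30), (30, 30)],
    "gp_kron_latent (2)": [(30, 50), (50, 50), (50, 30)],
    "smc2_gaussians_n4": [(1, 4), (4, 4), (4, 1)],
    "smc2_gaussians_n80": [(1, 80), (80, 80), (80, 1)],
}

for name, shapes in chains.items():
    left, right = chain3_costs(*shapes)
    print(f"{name}: left={left}, right={right}")
```

Under this cost model both orders tie for every one of these chains, so a rule that only replaces on a strictly cheaper order would leave them alone.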

@jessegrabowski
Member Author

I imagine this will be the most common case tbh. But statespace cares!

@jessegrabowski force-pushed the multi-dot-via-contraction branch from e435b76 to 9d1a298 May 6, 2026 03:39
@jessegrabowski requested a review from ricardoV94 May 6, 2026 03:44
@ricardoV94
Member

For statespace it will help because the package always puts the dots in the same order, which is not necessarily the optimal one from the get-go. Regardless of this rewrite, do we have the best default order there (say if static shapes never trigger)?

@jessegrabowski
Member Author

> Regardless of this rewrite, do we have the best default order there (say if static shapes never trigger)?

Yes, in principle the right answer is known and I could just go in and put in the parentheses. But I want my software to magically do it for me :(

Also in expressions like these it gets uglier and uglier to actually do that in code. These types of huge dot chains also appear in optimal control problems. Again, the shapes of all those objects are known statically ahead of time (even if pytensor doesn't know that), so one could simply optimize it by hand.

@ricardoV94
Member

I believe it's useful. I'm a bit bummed that we haven't seen it be useful yet. Do you have any STS example where it triggers (not one purposely built now to prove it)?

Not a blocker regardless, as long as this rewrite is cheap to bail out of (which I believe it will be).

@jessegrabowski
Member Author

Ok so a place we really, really should be getting gains is in low-rank projections, like the inducing-point approximation for GPs. Here's some code due to @bwengals :

Kuf = self.kernel(Z, X_train)  # (M, N)
Kus = self.kernel(Z, X_new)  # (M, N*)
Kss_diag = self.kernel.diag(X_new)

# Sigma = Kuu + Kuf @ Kuf.T / sigma^2
Sigma = Kuu + Kuf @ Kuf.T / sigma2
Sigma_inv = pt.linalg.inv(Sigma)

mu_train = self.mean(X_train)
alpha = Sigma_inv @ Kuf @ (y_train - mu_train) / sigma2

fmean = self.mean(X_new) + Kus.T @ alpha

# fvar = Kss - Kus.T @ (Kuu^{-1} - Sigma^{-1}) @ Kus
Kuu_inv = pt.linalg.inv(Kuu)
diff_inv = Kuu_inv - Sigma_inv
fvar = Kss_diag - pt.sum(Kus * (diff_inv @ Kus), axis=0)

Bot analysis:

VFE GP: rewrite breaks before chain detection runs.

The candidate chain is Sigma_inv @ Kuf @ (y - mu). With M=50 inducing, N=2000 training, and res = y - mu:
- User-written left-fold cost: M²N + MN ≈ 5.1M flops.
- Optimal right-fold (Sigma_inv @ (Kuf @ res)): MN + M² ≈ 103k flops. ~50× cheaper.

But pytensor's inv(A) @ B → solve(A, B) rewrite fires before reassociate_matmul_chain (it lives in canonicalize/specialize, position < 2). It rewrites Sigma_inv @ Kuf into Solve[(50,50), (50,2000)], and then what's left is Solve @ Dot, not a chain of Dot/Dot22/BatchedDot nodes. Our _is_chain_link doesn't recognize Solve as a chain link, so the chain is invisible by the time we run. The graph ends up at the slow plan: solve over the full (50, 2000) RHS, then dot with (2000, 1). That's a real ~50× opportunity our rewrite misses.

This isn't fixable by extending the matmul rewrite alone. Two real options:
- A separate Solve-aware reordering pass: recognize solve(A, B) @ C and rewrite to solve(A, B @ C) when C's trailing dim is narrow enough that the cost flips.
- Move reassociate_matmul_chain earlier so it sees the chain before inv→solve. But this would mean running before BLAS specialization, which would force changes to the chain-link set (no Dot22/Gemm yet) and probably break other invariants.
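For reference, the two cost formulas quoted in the analysis can be evaluated directly for M=50 and N=2000. This sketch counts one multiply-add per scalar product and treats Sigma_inv as an explicit M×M matrix (the cost of a Solve is not modelled); the variable names are illustrative.

```python
M, N = 50, 2000  # inducing points, training points

# (Sigma_inv @ Kuf) @ res: an (M x M) @ (M x N) product, then (M x N) @ (N x 1).
left_fold = M * M * N + M * N

# Sigma_inv @ (Kuf @ res): an (M x N) @ (N x 1) product, then (M x M) @ (M x 1).
right_fold = M * N + M * M

print(left_fold, right_fold, left_fold / right_fold)
```

With these numbers the left fold comes to 5,100,000 multiply-adds and the right fold to 102,500, a ratio of a little under 50.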

@ricardoV94
Member

ricardoV94 commented May 6, 2026

Solve / Dot seems like a hard ordering problem. You could re-canonicalize as MatrixInverse here and be sure to always reintroduce Solve when done.

As for the BLAS thing, it may be time to pull the plug: it should def be a very late-stage rewrite (after specialize, before fusion). It hurts us every time we want to work with matmuls (all the time), and I don't see why it should/must be eager at all
