Skip to content

System for algebraic reasoning about linear alegbra#2032

Open
jessegrabowski wants to merge 50 commits intopymc-devs:mainfrom
jessegrabowski:assumption-system
Open

System for algebraic reasoning about linear alegbra#2032
jessegrabowski wants to merge 50 commits intopymc-devs:mainfrom
jessegrabowski:assumption-system

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski commented Apr 7, 2026

This PR is a proposal for a typing system for linear algebra primitives. The purpose is to enable graph-wide reasoning about the kinds of matrices, so that we can rewrite to efficient computational forms.

Current State

We currently have several linear algebra rewrites and plan to add more. These are tracked in #573. This is important because linear algebra is 1) ubiquotous, 2) expensive, and 3) inscrutiable. Pytensor's static graph representation and rewrite system is well positioned to provide users help writing the best possible programs involving heavy linear algebra, if only we can figure out what is going on.

Consider the motivating case of solve(A, b), when A = pt.diag(pt.arange(100_000)). This is an O(n^3) operation that will call out to specialized routines. But there is no need for this. Since A is diagonal, we can write this as an elementwise division b / pt.extract_diag(A).

How do we decide to do this? We:

  • Track the Solve Op
  • Check if either input "seems diagonal"
  • If so, do the rewrite

What seems diagonal? We assume an input is diagonal if it was created by pt.eye or pt.diag. Users cannot specify themselves whether input data is diagonal. If an Op get inbetween the known "diagonalish" Op and the Solve, we cannot detect diagonality. For example, we cannot rewrite solve(A * 3, b), because now the first input is Elemwise(Mul)(A, 3). Multiplication is diagonal-preserving (because it is zero-preserving), but since the known diagonal op is now buried inside the Elemwise, we're out of luck.

Proposal

My proposal is to reason about algebraic properties of matrices the same way we reason about shapes. For shapes, we attach a ShapeFeature to FunctionGraphs. Each Op has an infer_shape method that explains how the static shape propagates. Likewise, I propose an AssumptionFeature. Ops do not have infer_assumption methods. That would be too messy. Instead, we have a central ASSUMPTION_INFER_REGISTRY with keys (Op, AssumptionKey ) and values InferFactFn.

  • An AssumptionKey is just a marker class corresponding to an algebraic fact about a matrix, like DIAGONAL, LOWER_TRIANGULAR, ORTHOGONAL, POSITIVE, SEMIDEFINITE, and so on.
  • An InferFactFn has the following signature:
def infer_diagonal(op: Op, assumption_feature: AssumptionFeature, fgraph: FunctionGraph, node: Apply, input_facts: list[FactState]) -> list[FactState]:

Like other symbolic operations, the InferFactFunction takes an Op (plus global information about the graph it lives in, fgraph and assumption_feature), and information about its inputs (the list of FactStates) and returns a list of information about its outputs.

A FactState is a three-valued logic for assumption inference. The possible values are UNKNOWN, TRUE, or FALSE. A fourth state, CONFLICT, exists, but should never arise.

All facts about all Ops are assumed to be UNKNOWN unless we can prove otherwise. Proof comes from each Op's registered InferFactFunctions. The AssumptionFeature is responsible for gathering all the rules of fact propagation. An example of a simple fact is that all Eye Ops are DIAGONAL, provided it is 1) square and 2) offset of zero:

def true_if(cond: bool) -> list[FactState]:
    """``[TRUE]`` when *cond* holds, ``[UNKNOWN]`` otherwise."""
    return [FactState.TRUE] if cond else [FactState.UNKNOWN]

def eye_is_identity(node) -> bool:
    """True when an :class:`Eye` node produces the identity matrix (square, k == 0)."""
    n, m, k = node.inputs
    if not (isinstance(k, Constant) and k.data.item() == 0):
        return False
    if n is m:
        return True
    if isinstance(n, Constant) and isinstance(m, Constant):
        return n.data.item() == m.data.item()
    return False

@register_assumption(DIAGONAL, Eye)
def _eye(op, feature, fgraph, node, input_states):
    return true_if(eye_is_identity(node))

In a program, we can use the AssumptionFeature.get(x, AssumptionKey) to query about the state of a Variable. Here, we ask "is x diagonal?". Obviously it is:

import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.assumptions import AssumptionFeature, DIAGONAL
x = pt.eye(5)
fg = FunctionGraph([], [x])
af = AssumptionFeature()
fg.attach_feature(af)

af.get(x, DIAGONAL) # <FactState.TRUE: 1>

Where this becomes powerful is by accumulating InferFactFunctions. Many Ops preserve diagonality, like Cholesky or Inverse. We can use information about the inputs to these Ops to propogate fact information through the graph:

@register_assumption(DIAGONAL, Cholesky)
def _cholesky(op, feature, fgraph, node, input_states):
    return true_if(input_states[0])

@register_assumption(DIAGONAL, MatrixInverse)
def _inv(op, feature, fgraph, node, input_states):
    return true_if(input_states[0])

Now we don't lose information about x in deeper graphs, and are free to do more rewrites:

x = pt.eye(5)
y = pt.linalg.cholesky(x)
z = pt.linalg.inv(y)

fg = FunctionGraph([z], [x])
af = AssumptionFeature()
fg.attach_feature(af)

af.get(y, DIAGONAL) # <FactState.TRUE: 1>
af.get(z, DIAGONAL) # <FactState.TRUE: 1>

Of course can also reason conditionally. An IncSubtensor might be diagonal-preserving if we can prove that we're setting a value on the diagonal of the matrix. Otherwise we fall back to UNKNOWN:

from pytensor.tensor.subtensor import IncSubtensor

x = pt.eye(5)
i, j = pt.iscalars('i', 'j')
y = x[i, j].inc(3) # Cannot prove i != j at runtime, so UNKNOWN
z = x[i, i].inc(3) # i == i provable, so this is diagonal-preserving
fg = FunctionGraph([y, z], [x])
af = AssumptionFeature()
fg.attach_feature(af)

af.get(y, DIAGONAL) #<FactState.UNKNOWN: 0>
af.get(z, DIAGONAL) #<FactState.TRUE: 1>

Facts can also imply other facts. DIAGONAL matrices are also symmetrical. These general relationships can be registered and encoded as well. Continuing the example above:

from pytensor.tensor.assumptions import SYMMETRIC
af.get(z, SYMMETRIC) # <FactState.TRUE: 1>

Finally, users can specify facts about matrices using pt.specify_assumptions, the same way they are able to specify shapes.

x = pt.specify_assumptions(pt.dmatrix('x'), diagonal=True)
fg = FunctionGraph([x], [x])
af = AssumptionFeature()
fg.attach_feature(af)

af.get(x, DIAGONAL) # <FactState.TRUE: 1>

Benefits for rewriting:

  • FactStates are trivial to check. We can check any fact about any Variable in 5 lines:
def _is_diagonal(var, fgraph):
    """Check if *var* is diagonal using the AssumptionFeature."""

    af = getattr(fgraph, "assumption_feature", None)
    if af is None:
        af = AssumptionFeature()
        fgraph.attach_feature(af)

    return af.check(var, DIAGONAL)
  • We can reason globally about the graph
    As noted above, information flows through the graph. As long as we have good coverage of fact rules for Ops, we can make statements about Variables at all levels of computation

  • InferFactFunctions are lightweight and easy to write
    It is trivial to add new InferFactFunctions. LLMs can bang them out. They require only local reasoning about the Op and its immediate inputs.

  • FactStates allow non-trivial combinations of rewrites
    One example I hit while working on this was a rewrite for DirectSolveLyapunov given diagonal inputs. Because there is a rule that kron preserves diagonality if both inputs are diagonal, the chain of rewrites from solve_discrete_lyapunov(diag(a), Q) -> Q.ravel() / (1 - outer(a, a).ravel()) is discovered by the rewrite system via:

  • rewrite_kron_diag_to_diag_outer: kron(diag(a), diag(b)) → diag(outer(a, b).ravel())

  • rewrite_solve_diag_to_division: solve(diag(x), b) → b / x

The key is that after the first rewrite produces a diagonal matrix (via alloc_diag), the assumption system recognizes it as diagonal (via AllocDiag being registered with DIAGONAL), and then the second rewrite kicks in.

Non-Goals

The purpose of this system is not to introduce an complete, closed algebra over all types. That is impossible. The goal is also not to complete the project of German romanticism. That is also impossible.

  • There should never be any "global state" (no Wolfram Assuming). Assumptions live on FunctionGraphs. There is never a need to to deal with global context, logical combinations, relational assumptions.
  • We are not trying to be smart or figure out as much as possible. On the contrary, we want to be very dumb. The Assumptions should be maximally conservative, and fall back to UNKNOWN whenever there is runtime ambiguity.
  • Assumption logic lives separately from all other machinery. No other part of Pytensor needs to know anything about them. Evaluation logic does not check assumptions. Perform methods don't dispatch on the FactState of the Node. We check during rewrites and that's it.
  • We do not aspire to provide Maple-style conditional dispatches. If we don't know, we just don't know. There is no LinearAlgebra[IsDefinite]. No symbolic conditionals as part of the core API.
  • We are not and cannot be a theorm prover. We do not present assumptions to the user as a tool for this. Assumptions are primarily an inner-api, and rewrite-focused.

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 7, 2026

I like the top message (didn't look at the code). Two notes:

1

Missing some discussion on how to preserve information across rewrites. ShapeFeature combines shape information when you replace a->b and you knew something of a and something of b (or just one of them).

This is specially relevant for constant folding (e.g, eye(int(1e6))) where it's much cheaper/obvious before vs after

But then, when do you decide to ask/start checking assumption? Because this can also be useless work (e.g. a graph without any linalg stuff).

Ordering is hard

2

Why do you assume that you can do all the reasoning you need to from op and inputs? It's a small note and I don't think a requirement/restriction in your proposal.

But eg checking for tridiagonal matrix creation (which we do) requires checking 2/3 nested set_subtensor nodes.

@maresb
Copy link
Copy Markdown
Contributor

maresb commented Apr 7, 2026

Very disappointed that German romanticism is out of scope

But seriously, this looks amazing. I'm not really capable of reviewing this but I am very excited for this one.

Comment thread pytensor/tensor/assumptions/diagonal.py Outdated
try:
val = get_underlying_scalar_constant_value(node.inputs[0])
if val == 0:
return [FactState.TRUE]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want FactState=False? when it is not zero but still known constant?

Comment thread pytensor/tensor/assumptions/specify.py Outdated
Comment thread pytensor/tensor/assumptions/specify.py Outdated
return output_grads


def specify_assumptions(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shorter name? pytensor.assume?

(Also this need not be in tensor module, seems more generic that it)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be shorter too, the other arguments a user has to give to it I think will make it clear what it does, so the name can be on the short side. Outside of tensors/linear algebra stuff, what kinda examples are you thinking of @ricardoV94?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xtensor / scalar modules can also have assumptions, no reason it needs to be tied to tensor/

Comment thread pytensor/tensor/assumptions/utils.py Outdated
Comment thread pytensor/tensor/assumptions/utils.py Outdated
Comment thread pytensor/tensor/rewriting/linalg.py Outdated
Comment thread tests/tensor/rewriting/test_assumptions.py Outdated
return true_if(eye_is_identity(node))


@register_assumption(ORTHOGONAL, MatrixInverse)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use BlockwiseOf?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a helper in assumptions/blockwise.py that pushes all the assumptions through blockwise. Thoughts?

Copy link
Copy Markdown
Member

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like the goal of the PR. I couldn't finish reviewing but I'll just leave the comments I've written down so far. I'll try to come back to this later and write something more decent.

Comment thread pytensor/tensor/assumptions/core.py Outdated
) -> list[FactState]:
"""Determine the *key* fact for every output of *node*.

Resolution order:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the opposite resolution order than the one used in blockwise. Which should be preferred?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in assumptions/blockwise.py or in the actual blockwise Op?

state = FactState(state)
cache_key = (var, key)
old = self.user_facts.get(cache_key, FactState.UNKNOWN)
new = FactState.join(old, state)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If old is CONFLICT, wont join also return a conflict? It wont' overwrite the old fact. Is this intended?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to below, I also hate the CONFLICT state. If we switch to True / Unknown, this is moot. But in this case this was intended, I was thinking about conflict as "nan" behaves in float

"""Return ``True`` iff the assumption is definitively TRUE for ``var``."""
return bool(self.get(var, key))

def set_user_fact(self, var: Any, key: AssumptionKey, state: FactState) -> None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found the name confusing. I thought that it would have set the state without joining it with the existing state. I noticed the docstring after I looked at the following method, replace_user_fact

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this is all over-engineered. I'm going to think about how to slim it down.

@register_assumption(DIAGONAL, DimShuffle)
def _dimshuffle(op, feature, fgraph, node, input_states):
if not input_states[0]:
return [FactState.UNKNOWN]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not return a input_states[0] directly? I'm having a hard time understanding the need for the FactState.FALSE. It looks like everything is either true or unknown.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm starting to agree. I included it because it what all the big boy logical systems do (wolfram, maple, etc). But for our purposes here, I can't think of a situation where knowing something is false gives us any useful information

Comment thread pytensor/tensor/assumptions/orthogonal.py
Comment thread pytensor/tensor/assumptions/diagonal.py Outdated

@register_assumption(DIAGONAL, Dot)
def _dot(op, feature, fgraph, node, input_states):
return true_if(all(input_states))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just came back to this after having looked at the orthogonal assumptions module. How would you handle the case where Q @ Q.T = eye? In other words, if the two inputs are orthogonal and one is the transpose of the other. Their product would produce an identity matrix. Would you be able to get the assumption from a different set (orthogonality) while working through diagonality?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. The system as it exists don't have good tooling for handling cross-facts like that. Need to pause and ponder.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the latest commit I showed how this would work. The system can actually handle cross-facts, but it will be important to explain the rules:

  • You register fact functions according to the output. So for your example of Q @ Q.T = eye, we register it as a diagonal fact.
  • The fact function takes the feature as an input, so you can always query against other facts, like feature.check(Q, ORTHOGONAL)

So the fact function looks like:

@register_assumption(DIAGONAL, Dot)
def _dot_orthogonal_xxt(op, feature, fgraph, node, input_states):
    """x @ x.T is diagonal (identity) when x is orthogonal."""
    a, b = node.inputs
    b_owner = b.owner
    if (
        feature.check(a, ORTHOGONAL)
        and b_owner is not None
        and isinstance(b_owner.op, DimShuffle)
        and b_owner.op.is_matrix_transpose
        and b_owner.inputs[0] is a
    ):
        return [FactState.TRUE]

    return [FactState.UNKNOWN]

@jessegrabowski jessegrabowski force-pushed the assumption-system branch 2 times, most recently from 8ac904d to 5767b43 Compare April 11, 2026 21:40
@zaxtax
Copy link
Copy Markdown
Contributor

zaxtax commented Apr 15, 2026

This proposal looks great! I'm totally in favour. Related to non-goals, are there any formal properties we would like to maintain about our system? Obviously users can introduce invalid rewrites, but in terms of the core implementation, any invariants and properties we want to maintain?

@jessegrabowski jessegrabowski force-pushed the assumption-system branch 2 times, most recently from 1d87d8e to cc01e9d Compare April 17, 2026 03:42
@jessegrabowski jessegrabowski force-pushed the assumption-system branch 2 times, most recently from fa48d6d to 4676b07 Compare April 18, 2026 22:08
@jessegrabowski
Copy link
Copy Markdown
Member Author

jessegrabowski commented Apr 18, 2026

Obviously users can introduce invalid rewrites, but in terms of the core implementation, any invariants and properties we want to maintain?

I think we want to be super defensive. I am already pushing it bit with this one:

def _dot(op, feature, fgraph, node, input_states):

I included it to be provocative and generate discussion. In most realistic cases it's true that X @ X.T will be positive definite, but we can never prove that X is full rank at graph construction time (well, unless it's data we can inspect...). We have a concept of "shape_unsafe" rewrites that might create invalid graphs if things shift underneath the rewrite. If we want to be more aggressive we would need to introduce something like that. My issue here is that "shape_unsafe" assumes that users know about shape unsafe and know to turn it off in what cases. I think that's a pretty heroic assumption.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants