System for algebraic reasoning about linear alegbra#2032
System for algebraic reasoning about linear alegbra#2032jessegrabowski wants to merge 50 commits intopymc-devs:mainfrom
Conversation
cece59a to
9f6e7f5
Compare
|
I like the top message (didn't look at the code). Two notes: 1Missing some discussion on how to preserve information across rewrites. ShapeFeature combines shape information when you replace a->b and you knew something of a and something of b (or just one of them). This is specially relevant for constant folding (e.g, But then, when do you decide to ask/start checking assumption? Because this can also be useless work (e.g. a graph without any linalg stuff). Ordering is hard 2Why do you assume that you can do all the reasoning you need to from op and inputs? It's a small note and I don't think a requirement/restriction in your proposal. But eg checking for tridiagonal matrix creation (which we do) requires checking 2/3 nested set_subtensor nodes. |
|
Very disappointed that German romanticism is out of scope But seriously, this looks amazing. I'm not really capable of reviewing this but I am very excited for this one. |
| try: | ||
| val = get_underlying_scalar_constant_value(node.inputs[0]) | ||
| if val == 0: | ||
| return [FactState.TRUE] |
There was a problem hiding this comment.
Do we want FactState=False? when it is not zero but still known constant?
| return output_grads | ||
|
|
||
|
|
||
| def specify_assumptions( |
There was a problem hiding this comment.
Shorter name? pytensor.assume?
(Also this need not be in tensor module, seems more generic that it)
There was a problem hiding this comment.
I think it could be shorter too, the other arguments a user has to give to it I think will make it clear what it does, so the name can be on the short side. Outside of tensors/linear algebra stuff, what kinda examples are you thinking of @ricardoV94?
There was a problem hiding this comment.
xtensor / scalar modules can also have assumptions, no reason it needs to be tied to tensor/
| return true_if(eye_is_identity(node)) | ||
|
|
||
|
|
||
| @register_assumption(ORTHOGONAL, MatrixInverse) |
There was a problem hiding this comment.
I have a helper in assumptions/blockwise.py that pushes all the assumptions through blockwise. Thoughts?
lucianopaz
left a comment
There was a problem hiding this comment.
I really like the goal of the PR. I couldn't finish reviewing but I'll just leave the comments I've written down so far. I'll try to come back to this later and write something more decent.
| ) -> list[FactState]: | ||
| """Determine the *key* fact for every output of *node*. | ||
|
|
||
| Resolution order: |
There was a problem hiding this comment.
This is the opposite resolution order than the one used in blockwise. Which should be preferred?
There was a problem hiding this comment.
in assumptions/blockwise.py or in the actual blockwise Op?
| state = FactState(state) | ||
| cache_key = (var, key) | ||
| old = self.user_facts.get(cache_key, FactState.UNKNOWN) | ||
| new = FactState.join(old, state) |
There was a problem hiding this comment.
If old is CONFLICT, wont join also return a conflict? It wont' overwrite the old fact. Is this intended?
There was a problem hiding this comment.
Related to below, I also hate the CONFLICT state. If we switch to True / Unknown, this is moot. But in this case this was intended, I was thinking about conflict as "nan" behaves in float
| """Return ``True`` iff the assumption is definitively TRUE for ``var``.""" | ||
| return bool(self.get(var, key)) | ||
|
|
||
| def set_user_fact(self, var: Any, key: AssumptionKey, state: FactState) -> None: |
There was a problem hiding this comment.
I found the name confusing. I thought that it would have set the state without joining it with the existing state. I noticed the docstring after I looked at the following method, replace_user_fact
There was a problem hiding this comment.
yeah this is all over-engineered. I'm going to think about how to slim it down.
| @register_assumption(DIAGONAL, DimShuffle) | ||
| def _dimshuffle(op, feature, fgraph, node, input_states): | ||
| if not input_states[0]: | ||
| return [FactState.UNKNOWN] |
There was a problem hiding this comment.
Why not return a input_states[0] directly? I'm having a hard time understanding the need for the FactState.FALSE. It looks like everything is either true or unknown.
There was a problem hiding this comment.
Yeah I'm starting to agree. I included it because it what all the big boy logical systems do (wolfram, maple, etc). But for our purposes here, I can't think of a situation where knowing something is false gives us any useful information
|
|
||
| @register_assumption(DIAGONAL, Dot) | ||
| def _dot(op, feature, fgraph, node, input_states): | ||
| return true_if(all(input_states)) |
There was a problem hiding this comment.
I just came back to this after having looked at the orthogonal assumptions module. How would you handle the case where Q @ Q.T = eye? In other words, if the two inputs are orthogonal and one is the transpose of the other. Their product would produce an identity matrix. Would you be able to get the assumption from a different set (orthogonality) while working through diagonality?
There was a problem hiding this comment.
Good question. The system as it exists don't have good tooling for handling cross-facts like that. Need to pause and ponder.
There was a problem hiding this comment.
In the latest commit I showed how this would work. The system can actually handle cross-facts, but it will be important to explain the rules:
- You register fact functions according to the output. So for your example of Q @ Q.T = eye, we register it as a diagonal fact.
- The fact function takes the feature as an input, so you can always query against other facts, like
feature.check(Q, ORTHOGONAL)
So the fact function looks like:
@register_assumption(DIAGONAL, Dot)
def _dot_orthogonal_xxt(op, feature, fgraph, node, input_states):
"""x @ x.T is diagonal (identity) when x is orthogonal."""
a, b = node.inputs
b_owner = b.owner
if (
feature.check(a, ORTHOGONAL)
and b_owner is not None
and isinstance(b_owner.op, DimShuffle)
and b_owner.op.is_matrix_transpose
and b_owner.inputs[0] is a
):
return [FactState.TRUE]
return [FactState.UNKNOWN]
8ac904d to
5767b43
Compare
|
This proposal looks great! I'm totally in favour. Related to non-goals, are there any formal properties we would like to maintain about our system? Obviously users can introduce invalid rewrites, but in terms of the core implementation, any invariants and properties we want to maintain? |
1d87d8e to
cc01e9d
Compare
fa48d6d to
4676b07
Compare
I think we want to be super defensive. I am already pushing it bit with this one: I included it to be provocative and generate discussion. In most realistic cases it's true that |
4016443 to
a7394bd
Compare
a7394bd to
894eb8e
Compare
This PR is a proposal for a typing system for linear algebra primitives. The purpose is to enable graph-wide reasoning about the kinds of matrices, so that we can rewrite to efficient computational forms.
Current State
We currently have several linear algebra rewrites and plan to add more. These are tracked in #573. This is important because linear algebra is 1) ubiquotous, 2) expensive, and 3) inscrutiable. Pytensor's static graph representation and rewrite system is well positioned to provide users help writing the best possible programs involving heavy linear algebra, if only we can figure out what is going on.
Consider the motivating case of
solve(A, b), whenA = pt.diag(pt.arange(100_000)). This is an O(n^3) operation that will call out to specialized routines. But there is no need for this. Since A is diagonal, we can write this as an elementwise divisionb / pt.extract_diag(A).How do we decide to do this? We:
SolveOpWhat seems diagonal? We assume an input is diagonal if it was created by
pt.eyeorpt.diag. Users cannot specify themselves whether input data is diagonal. If anOpget inbetween the known "diagonalish"Opand theSolve, we cannot detect diagonality. For example, we cannot rewritesolve(A * 3, b), because now the first input isElemwise(Mul)(A, 3). Multiplication is diagonal-preserving (because it is zero-preserving), but since the known diagonal op is now buried inside the Elemwise, we're out of luck.Proposal
My proposal is to reason about algebraic properties of matrices the same way we reason about shapes. For shapes, we attach a
ShapeFeaturetoFunctionGraphs. EachOphas aninfer_shapemethod that explains how the static shape propagates. Likewise, I propose anAssumptionFeature.Opsdo not haveinfer_assumptionmethods. That would be too messy. Instead, we have a centralASSUMPTION_INFER_REGISTRYwith keys(Op, AssumptionKey )and valuesInferFactFn.AssumptionKeyis just a marker class corresponding to an algebraic fact about a matrix, likeDIAGONAL,LOWER_TRIANGULAR,ORTHOGONAL,POSITIVE,SEMIDEFINITE, and so on.InferFactFnhas the following signature:Like other symbolic operations, the
InferFactFunctiontakes an Op (plus global information about the graph it lives in,fgraphandassumption_feature), and information about its inputs (the list ofFactStates) and returns a list of information about its outputs.A
FactStateis a three-valued logic for assumption inference. The possible values areUNKNOWN,TRUE, orFALSE. A fourth state,CONFLICT, exists, but should never arise.All facts about all Ops are assumed to be
UNKNOWNunless we can prove otherwise. Proof comes from each Op's registeredInferFactFunctions. TheAssumptionFeatureis responsible for gathering all the rules of fact propagation. An example of a simple fact is that allEyeOpsareDIAGONAL, provided it is 1) square and 2) offset of zero:In a program, we can use the
AssumptionFeature.get(x, AssumptionKey)to query about the state of aVariable. Here, we ask "is x diagonal?". Obviously it is:Where this becomes powerful is by accumulating
InferFactFunctions. ManyOpspreserve diagonality, likeCholeskyorInverse. We can use information about the inputs to theseOpsto propogate fact information through the graph:Now we don't lose information about
xin deeper graphs, and are free to do more rewrites:Of course can also reason conditionally. An
IncSubtensormight be diagonal-preserving if we can prove that we're setting a value on the diagonal of the matrix. Otherwise we fall back toUNKNOWN:Facts can also imply other facts.
DIAGONALmatrices are also symmetrical. These general relationships can be registered and encoded as well. Continuing the example above:Finally, users can specify facts about matrices using
pt.specify_assumptions, the same way they are able to specify shapes.Benefits for rewriting:
FactStatesare trivial to check. We can check any fact about any Variable in 5 lines:We can reason globally about the graph
As noted above, information flows through the graph. As long as we have good coverage of fact rules for Ops, we can make statements about Variables at all levels of computation
InferFactFunctions are lightweight and easy to writeIt is trivial to add new
InferFactFunctions. LLMs can bang them out. They require only local reasoning about theOpand its immediate inputs.FactStatesallow non-trivial combinations of rewritesOne example I hit while working on this was a rewrite for
DirectSolveLyapunovgiven diagonal inputs. Because there is a rule thatkronpreserves diagonality if both inputs are diagonal, the chain of rewrites fromsolve_discrete_lyapunov(diag(a), Q) -> Q.ravel() / (1 - outer(a, a).ravel())is discovered by the rewrite system via:rewrite_kron_diag_to_diag_outer: kron(diag(a), diag(b)) → diag(outer(a, b).ravel())rewrite_solve_diag_to_division: solve(diag(x), b) → b / xThe key is that after the first rewrite produces a diagonal matrix (via alloc_diag), the assumption system recognizes it as diagonal (via AllocDiag being registered with DIAGONAL), and then the second rewrite kicks in.
Non-Goals
The purpose of this system is not to introduce an complete, closed algebra over all types. That is impossible. The goal is also not to complete the project of German romanticism. That is also impossible.
Assuming). Assumptions live on FunctionGraphs. There is never a need to to deal with global context, logical combinations, relational assumptions.FactStateof theNode. We check during rewrites and that's it.LinearAlgebra[IsDefinite]. No symbolic conditionals as part of the core API.