System for algebraic reasoning about linear alegbra#2032
Conversation
cece59a to
9f6e7f5
Compare
|
I like the top message (didn't look at the code). Two notes: 1Missing some discussion on how to preserve information across rewrites. ShapeFeature combines shape information when you replace a->b and you knew something of a and something of b (or just one of them). This is specially relevant for constant folding (e.g, But then, when do you decide to ask/start checking assumption? Because this can also be useless work (e.g. a graph without any linalg stuff). Ordering is hard 2Why do you assume that you can do all the reasoning you need to from op and inputs? It's a small note and I don't think a requirement/restriction in your proposal. But eg checking for tridiagonal matrix creation (which we do) requires checking 2/3 nested set_subtensor nodes. |
|
Very disappointed that German romanticism is out of scope But seriously, this looks amazing. I'm not really capable of reviewing this but I am very excited for this one. |
| try: | ||
| val = get_underlying_scalar_constant_value(node.inputs[0]) | ||
| if val == 0: | ||
| return [FactState.TRUE] |
There was a problem hiding this comment.
Do we want FactState=False? when it is not zero but still known constant?
| return output_grads | ||
|
|
||
|
|
||
| def specify_assumptions( |
There was a problem hiding this comment.
Shorter name? pytensor.assume?
(Also this need not be in tensor module, seems more generic that it)
There was a problem hiding this comment.
I think it could be shorter too, the other arguments a user has to give to it I think will make it clear what it does, so the name can be on the short side. Outside of tensors/linear algebra stuff, what kinda examples are you thinking of @ricardoV94?
There was a problem hiding this comment.
xtensor / scalar modules can also have assumptions, no reason it needs to be tied to tensor/
| return true_if(eye_is_identity(node)) | ||
|
|
||
|
|
||
| @register_assumption(ORTHOGONAL, MatrixInverse) |
There was a problem hiding this comment.
I have a helper in assumptions/blockwise.py that pushes all the assumptions through blockwise. Thoughts?
lucianopaz
left a comment
There was a problem hiding this comment.
I really like the goal of the PR. I couldn't finish reviewing but I'll just leave the comments I've written down so far. I'll try to come back to this later and write something more decent.
| ) -> list[FactState]: | ||
| """Determine the *key* fact for every output of *node*. | ||
|
|
||
| Resolution order: |
There was a problem hiding this comment.
This is the opposite resolution order than the one used in blockwise. Which should be preferred?
There was a problem hiding this comment.
in assumptions/blockwise.py or in the actual blockwise Op?
| state = FactState(state) | ||
| cache_key = (var, key) | ||
| old = self.user_facts.get(cache_key, FactState.UNKNOWN) | ||
| new = FactState.join(old, state) |
There was a problem hiding this comment.
If old is CONFLICT, wont join also return a conflict? It wont' overwrite the old fact. Is this intended?
There was a problem hiding this comment.
Related to below, I also hate the CONFLICT state. If we switch to True / Unknown, this is moot. But in this case this was intended, I was thinking about conflict as "nan" behaves in float
|
|
||
| @register_assumption(DIAGONAL, Dot) | ||
| def _dot(op, feature, fgraph, node, input_states): | ||
| return true_if(all(input_states)) |
There was a problem hiding this comment.
I just came back to this after having looked at the orthogonal assumptions module. How would you handle the case where Q @ Q.T = eye? In other words, if the two inputs are orthogonal and one is the transpose of the other. Their product would produce an identity matrix. Would you be able to get the assumption from a different set (orthogonality) while working through diagonality?
There was a problem hiding this comment.
Good question. The system as it exists don't have good tooling for handling cross-facts like that. Need to pause and ponder.
There was a problem hiding this comment.
In the latest commit I showed how this would work. The system can actually handle cross-facts, but it will be important to explain the rules:
- You register fact functions according to the output. So for your example of Q @ Q.T = eye, we register it as a diagonal fact.
- The fact function takes the feature as an input, so you can always query against other facts, like
feature.check(Q, ORTHOGONAL)
So the fact function looks like:
@register_assumption(DIAGONAL, Dot)
def _dot_orthogonal_xxt(op, feature, fgraph, node, input_states):
"""x @ x.T is diagonal (identity) when x is orthogonal."""
a, b = node.inputs
b_owner = b.owner
if (
feature.check(a, ORTHOGONAL)
and b_owner is not None
and isinstance(b_owner.op, DimShuffle)
and b_owner.op.is_matrix_transpose
and b_owner.inputs[0] is a
):
return [FactState.TRUE]
return [FactState.UNKNOWN]
8ac904d to
5767b43
Compare
|
This proposal looks great! I'm totally in favour. Related to non-goals, are there any formal properties we would like to maintain about our system? Obviously users can introduce invalid rewrites, but in terms of the core implementation, any invariants and properties we want to maintain? |
1d87d8e to
cc01e9d
Compare
fa48d6d to
4676b07
Compare
I think we want to be super defensive. I am already pushing it bit with this one: I included it to be provocative and generate discussion. In most realistic cases it's true that |
4016443 to
a7394bd
Compare
769df10 to
54758e8
Compare
7d6b5b3 to
95cceda
Compare
`_is_provably_positive(var, strict=...)` proves `var > 0` (strict) or `var >= 0` (non-strict); `_is_provably_non_negative` becomes a thin `strict=False` wrapper. The `uint`-dtype, `Shape`/`Shape_i`, and `Cast` cases prove non-negativity but not strict positivity, so they are recognized only when `strict=False`. Results cache on `tag.is_positive` or `tag.is_non_negative` per the flag.
d94d905 to
6d4187c
Compare
|
I added a import pytensor
import pytensor.tensor as pt
from pytensor.assumptions import assume
X = pt.matrix("X")
ls = pt.scalar("ls")
eta = pt.scalar("eta")
v = pt.vector("v")
sq = pt.sum(X**2, axis=1)
D2 = sq[:, None] + sq[None, :] - 2 * (X @ X.T)
K = eta**2 * pt.exp(-0.5 * D2 / ls**2) + 1e-6 * pt.eye(X.shape[0])
# domain knowledge the structural engine can't derive: a valid kernel is PD
K = assume(K, positive_definite=True)
n = X.shape[0]
logp = (
-0.5 * (v @ pt.linalg.inv(K) @ v)
- 0.5 * pt.log(pt.linalg.det(K))
- 0.5 * n * pt.log(2 * pt.pi)
)
f = pytensor.function([X, ls, eta, v], logp)
f.dprint(print_assumptions=True)Output: We can see the gram matrix (id R) was automatically tagged as positive definite |
d481ac9 to
f4bb66f
Compare
|
I reorganized the tests in the way that I like (file |
f4bb66f to
7e999ff
Compare
| def eye_identity_rule(key, op, feature, fgraph, node, input_states) -> list[FactState]: | ||
| """Rule body: TRUE when an :class:`Eye` node produces the identity matrix (square, k == 0).""" | ||
| n, m, k = node.inputs | ||
| if not (isinstance(k, TensorConstant) and k.data.item() == 0): |
There was a problem hiding this comment.
Similarly, if it's a constant you can prove True/False, only if not is it unknown?
There was a problem hiding this comment.
Yes, but the case was a bit more subtle because this utility was used by diagonal, posdef, and triangular. The rule only returns False as written for posdef and diagonal, even if k !=0 we are stilll triangular. I split the helpers.
920e5d7 to
adbcf32
Compare
f6aa532 to
b4c1a3c
Compare
Per-Op inference rules for `DIAGONAL`, `LOWER_TRIANGULAR`, `UPPER_TRIANGULAR`, `SYMMETRIC`, `POSITIVE_DEFINITE`, and `ORTHOGONAL`, queryable from rewrites via `check_assumption(fgraph, var, key)`. Facts are inferred lazily and cached on a `FunctionGraph` feature attached on first use. `pt.assume(x, diagonal=True, ...)` asserts facts directly via a `SpecifyAssumptions` no-op wrapper. Implication graph wires `DIAGONAL ⇒ LOWER_/UPPER_TRIANGULAR, SYMMETRIC` and `POSITIVE_DEFINITE ⇒ SYMMETRIC`. Migrate `cholesky_ldotlt`, `cholesky_of_diag`, `inv_of_diag_to_diag_reciprocal`, `det_of_diag`, `psd_solve_to_chol_solve`, `useless_symmetric_transpose`, and `inv_to_solve` to `check_assumption`, keeping the original tag fallbacks alongside for PyMC compatibility. Add `paired_triangular_solves_to_cho_solve` to fuse the Cholesky→solve pair the migrated PSD path leaves behind. Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
`SpecifyShape`, `Reshape`, `JoinDims`, `SplitDims`, `Alloc`, `DimShuffle`, and `IncSubtensor` only rearrange batch axes when they leave the trailing two (core) axes intact. A per-matrix property then carries through, so each new rule forwards the input fact in that case and reports UNKNOWN otherwise. Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
98e9b5a to
6c8cb31
Compare
This PR is a proposal for a typing system for linear algebra primitives. The purpose is to enable graph-wide reasoning about the kinds of matrices, so that we can rewrite to efficient computational forms.
Current State
We currently have several linear algebra rewrites and plan to add more. These are tracked in #573. This is important because linear algebra is 1) ubiquotous, 2) expensive, and 3) inscrutiable. Pytensor's static graph representation and rewrite system is well positioned to provide users help writing the best possible programs involving heavy linear algebra, if only we can figure out what is going on.
Consider the motivating case of
solve(A, b), whenA = pt.diag(pt.arange(100_000)). This is an O(n^3) operation that will call out to specialized routines. But there is no need for this. Since A is diagonal, we can write this as an elementwise divisionb / pt.extract_diag(A).How do we decide to do this? We:
SolveOpWhat seems diagonal? We assume an input is diagonal if it was created by
pt.eyeorpt.diag. Users cannot specify themselves whether input data is diagonal. If anOpget inbetween the known "diagonalish"Opand theSolve, we cannot detect diagonality. For example, we cannot rewritesolve(A * 3, b), because now the first input isElemwise(Mul)(A, 3). Multiplication is diagonal-preserving (because it is zero-preserving), but since the known diagonal op is now buried inside the Elemwise, we're out of luck.Proposal
My proposal is to reason about algebraic properties of matrices the same way we reason about shapes. For shapes, we attach a
ShapeFeaturetoFunctionGraphs. EachOphas aninfer_shapemethod that explains how the static shape propagates. Likewise, I propose anAssumptionFeature.Opsdo not haveinfer_assumptionmethods. That would be too messy. Instead, we have a centralASSUMPTION_INFER_REGISTRYwith keys(Op, AssumptionKey )and valuesInferFactFn.AssumptionKeyis just a marker class corresponding to an algebraic fact about a matrix, likeDIAGONAL,LOWER_TRIANGULAR,ORTHOGONAL,POSITIVE,SEMIDEFINITE, and so on.InferFactFnhas the following signature:Like other symbolic operations, the
InferFactFunctiontakes an Op (plus global information about the graph it lives in,fgraphandassumption_feature), and information about its inputs (the list ofFactStates) and returns a list of information about its outputs.A
FactStateis a three-valued logic for assumption inference. The possible values areUNKNOWN,TRUE, orFALSE. A fourth state,CONFLICT, exists, but should never arise.All facts about all Ops are assumed to be
UNKNOWNunless we can prove otherwise. Proof comes from each Op's registeredInferFactFunctions. TheAssumptionFeatureis responsible for gathering all the rules of fact propagation. An example of a simple fact is that allEyeOpsareDIAGONAL, provided it is 1) square and 2) offset of zero:In a program, we can use the
AssumptionFeature.get(x, AssumptionKey)to query about the state of aVariable. Here, we ask "is x diagonal?". Obviously it is:Where this becomes powerful is by accumulating
InferFactFunctions. ManyOpspreserve diagonality, likeCholeskyorInverse. We can use information about the inputs to theseOpsto propogate fact information through the graph:Now we don't lose information about
xin deeper graphs, and are free to do more rewrites:Of course can also reason conditionally. An
IncSubtensormight be diagonal-preserving if we can prove that we're setting a value on the diagonal of the matrix. Otherwise we fall back toUNKNOWN:Facts can also imply other facts.
DIAGONALmatrices are also symmetrical. These general relationships can be registered and encoded as well. Continuing the example above:Finally, users can specify facts about matrices using
pt.specify_assumptions, the same way they are able to specify shapes.Benefits for rewriting:
FactStatesare trivial to check. We can check any fact about any Variable in 5 lines:We can reason globally about the graph
As noted above, information flows through the graph. As long as we have good coverage of fact rules for Ops, we can make statements about Variables at all levels of computation
InferFactFunctions are lightweight and easy to writeIt is trivial to add new
InferFactFunctions. LLMs can bang them out. They require only local reasoning about theOpand its immediate inputs.FactStatesallow non-trivial combinations of rewritesOne example I hit while working on this was a rewrite for
DirectSolveLyapunovgiven diagonal inputs. Because there is a rule thatkronpreserves diagonality if both inputs are diagonal, the chain of rewrites fromsolve_discrete_lyapunov(diag(a), Q) -> Q.ravel() / (1 - outer(a, a).ravel())is discovered by the rewrite system via:rewrite_kron_diag_to_diag_outer: kron(diag(a), diag(b)) → diag(outer(a, b).ravel())rewrite_solve_diag_to_division: solve(diag(x), b) → b / xThe key is that after the first rewrite produces a diagonal matrix (via alloc_diag), the assumption system recognizes it as diagonal (via AllocDiag being registered with DIAGONAL), and then the second rewrite kicks in.
Non-Goals
The purpose of this system is not to introduce an complete, closed algebra over all types. That is impossible. The goal is also not to complete the project of German romanticism. That is also impossible.
Assuming). Assumptions live on FunctionGraphs. There is never a need to to deal with global context, logical combinations, relational assumptions.FactStateof theNode. We check during rewrites and that's it.LinearAlgebra[IsDefinite]. No symbolic conditionals as part of the core API.