feat(dpa4): multiple updates for DPA4/SeZM #5503
feat(sezm): add activation checkpoint option
feat(sezm): add cross node tensor product
feat(sezm): add message node tensor product
feat(sezm): separate node lmax from edge lmax
feat(sezm): add custom kernel for lmax=5-10
feat(sezm): add so3 grid projection
refactor: change default values
feat(sezm): tf32 infer
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR expands the PyTorch backend’s neighbor-list strategy support (adding an NVIDIA Toolkit-Ops O(N) option) and updates SeZM/DPA4 configuration, docs, and tests to reflect new grid features and inference-time env controls.
Changes:
- Add `NvNeighborList` strategy (Toolkit-Ops / `nvalchemi-toolkit-ops`) plus dedicated unit tests, and broaden `nlist_backend` dispatch/equivalence tests to cover multiple strategies.
- Extend SeZM/DPA4 descriptor argument schema (e.g., `extra_node_l`, `kmax`, `ffn_so3_grid`, grid branch controls, and validation-time `tf32_infer`) and wire env-var translation in the PT trainer.
- Refactor/replace several SeZM grid/nonlinearity components and update example inputs + DPA4 docs accordingly; relax certain numeric tolerances in tests.
Reviewed changes
Copilot reviewed 47 out of 47 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| source/tests/pt/model/test_sezm_model.py | Adjust tolerances for compiled/eager comparisons (GPU). |
| source/tests/pt/model/test_nv_nlist.py | New unit tests for NvNeighborList edge topology/geometry equivalence and error behavior. |
| source/tests/pt/model/test_nlist_backend.py | Generalize backend strategy tests (vesin/nv) and improve dispatch coverage. |
| source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py | Removes S2 equivariance tests (file deleted). |
| source/tests/pt/model/test_descriptor_sezm.py | Add tests for new zonal gather and special Wigner blocks; adjust Wigner tests and remove one deserialize regression test. |
| examples/water/dpa4/lora_ft.json | Update example descriptor/optimizer settings to newer compact DPA4 config. |
| examples/water/dpa4/input_multitask.json | Update multitask example config to newer compact DPA4 config. |
| examples/water/dpa4/input_dens.json | Update dens example config to newer compact DPA4 config. |
| examples/water/dpa4/input.json | Update baseline example config and add validating inference flags. |
| examples/water/dpa4/input-zbl.json | Update ZBL example config to newer compact DPA4 config. |
| examples/water/dpa4/input-spin.json | Update spin example config to newer compact DPA4 config. |
| examples/water/dpa4/README.md | Align README text with updated “compact/DPA4-Neo-style” examples. |
| doc/model/dpa4.md | Major DPA4/SeZM documentation expansion and updates, including env var behavior. |
| deepmd/utils/argcheck.py | Add get_argument, extend DPA4/SeZM schema and validation options, adjust defaults/variants. |
| deepmd/pt/utils/nv_nlist.py | New NVIDIA Toolkit-Ops neighbor-list strategy implementation. |
| deepmd/pt/train/training.py | Translate validating.* options into env vars before model construction. |
| deepmd/pt/model/model/__init__.py | Default descriptor.type to dpa4 in SeZM scaffolds. |
| deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py | Remove old Triton kernels (file deleted). |
| deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py | Remove old Triton kernels (file deleted). |
| deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py | Remove old Triton dispatch (file deleted). |
| deepmd/pt/model/descriptor/sezm_nn/triton/constants.py | Remove old Triton constants/flags (file deleted). |
| deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py | Remove old Triton autograd wrapper/API (file deleted). |
| deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py | Switch Triton exports to new compile-composable rotation API. |
| deepmd/pt/model/descriptor/sezm_nn/so2.py | Extend SO(2) conv with grid nets/branches + inference-time Triton rotation toggle. |
| deepmd/pt/model/descriptor/sezm_nn/projection.py | New projection utilities (S2 + SO3 projectors and grid resolution helpers). |
| deepmd/pt/model/descriptor/sezm_nn/indexing.py | Add build_gie_zonal_index helper for zonal coupling gather. |
| deepmd/pt/model/descriptor/sezm_nn/ffn.py | Refactor FFN nonlinearity onto new grid-net abstractions and add SO3-grid option. |
| deepmd/pt/model/descriptor/sezm_nn/embedding.py | Use new zonal index helper and allow passing precomputed zonal coupling. |
| deepmd/pt/model/descriptor/sezm_nn/edge_cache.py | Remove geometry/RBF Triton path and simplify edge cache construction. |
| deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py | New CuTe accelerated SO(2) rotation package exports. |
| deepmd/pt/model/descriptor/sezm_nn/block.py | Add node/message grid branches, node-lmax decoupling, and inference checkpointing hooks. |
| deepmd/pt/model/descriptor/sezm_nn/activation.py | Remove grid-projector/grid-activation implementations (moved/refactored). |
| deepmd/pt/model/descriptor/sezm_nn/__init__.py | Re-export new grid/projector utilities and indexing helpers. |
| deepmd/pt/model/atomic_model/sezm_atomic_model.py | Use node_l_schedule if present for dens-head reconstruction. |
| deepmd/pt/infer/deep_eval.py | Add nv strategy handling, refactor strategy eval path, and broaden backend validation. |
| deepmd/pt/entrypoints/freeze_pt2.py | Patch Inductor config (triton.max_tiles) during AOTI packaging. |
| backend/find_pytorch.py | Add optional dependency for Toolkit-Ops neighbor list on Python >= 3.11. |
Comments suppressed due to low confidence (4)
deepmd/utils/argcheck.py:1
- The
extra_checkcan raiseTypeErrorfor non-bool, non-sequence inputs (e.g.,len(1)), and it also accepts any length-3 sequence without verifying element types arebool. Make the check robust by first verifyingisinstance(x, list)(orSequence) and then validatinglen(x) == 3andall(isinstance(v, bool) for v in x).
deepmd/pt/utils/nv_nlist.py:1 NvNeighborListis a GPU Toolkit-Ops path, butbuild()does not enforce CUDA tensors. If called with CPU tensors (or in a CPU-only environment), this will likely fail insidenvalchemiopswith a less actionable error. Consider adding an explicit check likecoord.device.type == 'cuda'(and similarly foratype/box) and raising a clearValueErrorwhen not on CUDA.
deepmd/pt/utils/nv_nlist.py:1- This builds a flattened
validmask of sizetotal_atoms * max_neighborsand then materializesedge_idxvianonzero, which can be very memory/latency heavy for large systems (the primary target for this strategy). Consider constructing the flattened edge indices without forming the full 2D mask (e.g., usingrepeat_interleave(num_neighbors)to builddstindices andarange(sum(num_neighbors))-based offsets to buildslotindices), so the intermediate scales with the number of real edges rather than the dense capacity.
source/tests/pt/model/test_descriptor_sezm.py:1 - A deserialize/backward-compatibility regression test was removed in this area, and the PR also deletes
test_descriptor_sezm_s2_equivariance.pyentirely. Since this PR introduces substantial new grid/projection behavior and new serialization-relevant config keys (e.g.,extra_node_l, grid branches, SO3-grid options), it would be good to retain equivalent coverage: add/restore tests that (1) validate descriptors deserialize correctly when older config fields are missing and (2) assert key equivariance/invariance properties for the new grid-net paths.
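The first suppressed comment, on `deepmd/utils/argcheck.py`, asks for a type-safe `extra_check`. A minimal sketch of such a validator follows, assuming the option accepts either a single `bool` or a sequence of exactly three `bool` values. The function name `is_bool_or_bool_triple` is hypothetical and does not come from the PR.

```python
from collections.abc import Sequence


def is_bool_or_bool_triple(x: object) -> bool:
    """Return True for a bool, or a list/tuple of exactly three bools.

    Hypothetical helper: the name and the accepted shapes are assumptions
    drawn from the review comment, not from the PR's code.
    """
    if isinstance(x, bool):
        return True
    # Strings are Sequences too, so reject them before the generic check.
    # Checking the type first also avoids calling len() on an int.
    if isinstance(x, (str, bytes)) or not isinstance(x, Sequence):
        return False
    return len(x) == 3 and all(isinstance(v, bool) for v in x)
```

Checking the type before the length is what removes the `TypeError` path: `len(1)` is never evaluated, and `[1, 0, 1]` is rejected because its elements are `int`, not `bool`.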
📝 Walkthrough
Adds SeZM grid/projector/net stack, replaces Triton rotation plumbing with new Triton/CuTe rotation operators, adds NV Toolkit-Ops neighbor-list backend and selection, adjusts compile/export/inference TF32/Triton options, extends descriptor/config/docs, and updates/introduces comprehensive tests.
Changes
SeZM grid, descriptor, and config surface
Descriptor and grid implementation
Rotation kernels, Triton/CuTe, and triton package surface
Neighbor-list backend, model compile/export, and runtime
Sequence Diagram(s)
sequenceDiagram
participant DeepEval
participant SeZMModel
participant NvNeighborList
DeepEval->>SeZMModel: _setup_nlist_backend(nlist_backend, auto)
SeZMModel->>NvNeighborList: NvNeighborList.build(coord, atype, box, rcut, sel)
DeepEval->>SeZMModel: _eval_lower_strategy(... using _nlist_builder)
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/water/dpa4/input.json (1)
1-124: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Use `input_torch.json` for this PyTorch example.
This file is a PyTorch training config (`model.type: "DPA4"`). Keeping it as `input.json` conflicts with the repo rule for TensorFlow `input.json` usage and creates command ambiguity for users/docs.
As per coding guidelines, `**/input.json`: TensorFlow backend training configuration should use `input.json` format and be invoked with `dp train input.json`.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/water/dpa4/input.json` around lines 1 - 124, The config is a PyTorch training config (model.type: "DPA4") but is named as the TensorFlow-style input.json, causing ambiguity; rename this file to input_torch.json (and update any references) so PyTorch examples use input_torch.json, ensure any documentation or scripts that reference this example (e.g., training invocation) point to input_torch.json instead of input.json, and verify consumers expect a PyTorch config when parsing model.type == "DPA4".
Source: Coding guidelines
🧹 Nitpick comments (1)
source/tests/pt/model/test_descriptor_sezm_triton.py (1)
76-177: ⚡ Quick win
The new dense rotation path still has no coverage.
Every fixture builds `coeff_index = build_m_major_index(lmax, 1, ...)`, so `rotate_to_local`/`rotate_back` always auto-select the block-diagonal kernels. The dense kernels, inverse-index path, and dense registered backwards introduced in `so2_rotation.py` can regress without any test failing. Please add at least one non-`mmax==1` case or direct `*_dense` assertions.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/pt/model/test_descriptor_sezm_triton.py` around lines 76 - 177, Tests always build coeff_index via build_m_major_index(lmax, 1, ...) so rotate_to_local and rotate_back always pick the block-diagonal (mmax==1) path; add a test case that constructs a coeff_index with mmax>1 (or directly call the *_dense variants) so the dense/inverse-index code paths and their registered backwards in so2_rotation.py are exercised; update test_rotate_to_local_matches_reference and test_rotate_back_matches_reference (or add a new subTest) to call build_m_major_index(lmax, mmax>1, ...) or to invoke rotate_to_local_dense / rotate_back_dense equivalents and assert outputs and gradients match the *_reference functions (and still apply the mask checks for dense behavior).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/infer/deep_eval.py`:
- Around line 309-343: The nvalchemi ('nv') backend selection in the
_setup_nlist_backend logic (symbols: nlist_backend, builder, NvNeighborList,
VesinNeighborList, is_nv_available, is_vesin_torch_available) must be gated on
the actual inference device (DEVICE from deepmd/pt/utils/env.py) rather than
host CUDA availability; update both the explicit "nv" branch and the "auto"
branch to check DEVICE.type == "cuda" (or the device of the input tensors)
before choosing NvNeighborList and, if DEVICE.type != "cuda" but nlist_backend
== "nv" was requested, raise a clear ValueError explaining that nv requires
CUDA-device inference; keep the existing checks for vesin (using
is_vesin_torch_available) but ensure vesin selection also verifies the
input/DEVICE compatibility if needed.
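The device gating requested for `deepmd/pt/infer/deep_eval.py` can be sketched as a pure function over plain values. This is only an illustration of the decision order, not the PR's implementation: the function name, its parameters, and the `"fallback"` return value are hypothetical, and the real code would read the device type from `DEVICE` and the availability flags from `is_nv_available` / `is_vesin_torch_available`.

```python
def choose_nlist_backend(
    requested: str,
    device_type: str,
    nv_available: bool,
    vesin_available: bool,
) -> str:
    """Resolve a backend name from the request and the inference device type."""
    if requested == "nv":
        # An explicit request on a non-CUDA device is an error, not a silent fallback.
        if device_type != "cuda":
            raise ValueError(
                "nlist_backend='nv' requires CUDA-device inference, "
                f"but the inference device type is {device_type!r}"
            )
        if not nv_available:
            raise ValueError("nlist_backend='nv' requested but Toolkit-Ops is unavailable")
        return "nv"
    if requested == "vesin":
        if not vesin_available:
            raise ValueError("nlist_backend='vesin' requested but vesin is unavailable")
        return "vesin"
    if requested == "auto":
        # Gate on the inference device, not on whether the host has CUDA.
        if nv_available and device_type == "cuda":
            return "nv"
        if vesin_available:
            return "vesin"
        return "fallback"  # hypothetical placeholder for the built-in path
    raise ValueError(f"unknown nlist_backend: {requested!r}")
```

Keeping the decision free of tensor or environment lookups makes each branch easy to unit-test on a CPU-only machine.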
In `@deepmd/pt/model/descriptor/sezm_nn/__init__.py`:
- Around line 45-51: The module removed package-level re-exports (notably
SwiGLUS2Activation) which breaks downstream imports; restore a deprecated
re-export shim in __init__.py that re-exports the original symbol name (e.g.,
SwiGLUS2Activation) while importing/forwarding to the new implementation
location or wrapper and emit a DeprecationWarning when the shim is used; apply
the same pattern for the other removed symbols referenced in the file (the
exports around the later ranges) so callers keep compatibility for one release
before fully removing them.
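One common way to build the deprecated re-export shim described above is a module-level `__getattr__` (PEP 562), which runs only when normal attribute lookup fails. The sketch below installs such a hook on an arbitrary module object so it can be tried in isolation; `install_deprecated_alias` and the names used with it are hypothetical. In a package `__init__.py` one would usually define `def __getattr__(name)` directly at module level instead.

```python
import types
import warnings


def install_deprecated_alias(
    module: types.ModuleType, old_name: str, new_obj: object
) -> None:
    """Make ``module.<old_name>`` resolve to ``new_obj`` with a DeprecationWarning."""
    previous = module.__dict__.get("__getattr__")

    def __getattr__(name: str):
        # Called only when ``name`` is not found in the module namespace.
        if name == old_name:
            warnings.warn(
                f"{module.__name__}.{old_name} is deprecated and will be removed; "
                "import the replacement instead",
                DeprecationWarning,
                stacklevel=2,
            )
            return new_obj
        if previous is not None:
            return previous(name)
        raise AttributeError(f"module {module.__name__!r} has no attribute {name!r}")

    # Storing the hook in the module namespace activates PEP 562 lookup.
    module.__getattr__ = __getattr__
```

Because the warning fires on attribute access rather than at import time, callers that never touch the old name see no noise.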
In `@deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py`:
- Around line 833-839: The CuTe fast-path guard (_cute_usable) must also verify
kernel channel-width constraints: reject CuTe when channel count C would produce
invalid launch geometry or tails; update _cute_usable to compute the channel
dimension C (from the relevant tensor shape used by so2_rotation) and return
False unless C >= _TN and C % _TN == 0 (i.e., divisible by the kernel tile size
_TN), in addition to the existing checks (SEZM_CUTE_AVAILABLE, is_cuda, dtype);
apply the same additional gating logic to the other CuTe guard sites referenced
(the checks around lines 875-879 and 912-916) so they all fall back to the eager
path when channel widths are unsupported.
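The extra channel-width condition for the CuTe guard amounts to two integer checks. A minimal sketch follows, with the tile width passed in as a parameter because the actual value of `_TN` is not shown in this review; both function names and the boolean parameters are hypothetical stand-ins for the existing checks.

```python
def cute_channels_supported(channels: int, tile_n: int) -> bool:
    """Return True when ``channels`` fills a whole number of tiles of width ``tile_n``."""
    if tile_n <= 0:
        raise ValueError("tile_n must be positive")
    # C >= _TN rules out an empty launch grid; C % _TN == 0 rules out a tail tile.
    return channels >= tile_n and channels % tile_n == 0


def cute_usable(
    available: bool, is_cuda: bool, dtype_ok: bool, channels: int, tile_n: int
) -> bool:
    """Combine the existing availability checks with the channel-width gate."""
    return available and is_cuda and dtype_ok and cute_channels_supported(channels, tile_n)
```

When the guard returns `False`, the caller is expected to take the eager path, as the comment requests.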
In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 1327-1343: The S2GridNet for the edge-local S2 branch
(instantiated as node_wise_grid_product) is using the default "packed"
coefficient layout while x_local and x_dst_local are reduced m-major tensors;
change the S2GridNet constructor to pass layout="m_major" so the projector
matches the m-major ordering; also update the other analogous S2GridNet
instantiation used for the edge-local branch later in the file (the second
S2GridNet for the same local branch) to use layout="m_major" as well.
In `@deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py`:
- Around line 1168-1191: The dispatch helper _block_layout_lmax currently only
checks shapes and may misidentify non-canonical reduced layouts as the m-major
mmax=1 block layout; update _block_layout_lmax to also validate that coeff_index
contains the canonical ordering produced by build_m_major_index(lmax, 1) (i.e.,
compare coeff_index values to the expected index tensor using value-safe
operations) and return -1 if it does not match so the code falls back to the
dense path or raises; apply the same value-validation wherever you currently use
this shape-only detection (the rotate_to_local / rotate_back block dispatch and
the explicit *_block helpers) so block kernels are only used for truly canonical
layouts.
- Line 649: The loop variable named `l` used in the `for ... in
tl.static_range(0, LMAX + 1):` constructs triggers Ruff E741; rename the
lone-letter variable to a non-ambiguous identifier (e.g. `ell` or `l_idx`) in
every occurrence in this module (all loops using `tl.static_range(0, LMAX + 1)`
including the other three occurrences flagged) and update any inner references
accordingly so the code compiles and passes `ruff check .`.
In `@deepmd/pt/model/descriptor/sezm.py`:
- Around line 1491-1524: The concatenation in _build_gie_zonal_coupling mixes
frames because mp_coupling comes from edge_cache.Dt_full (which was built with
training-time random-γ) while extra_coupling is recomputed from raw edge_vec
without that γ; fix by reusing the same augmented quaternion/frame used for
Dt_full when computing extra_coupling. Concretely, obtain the edge
quaternion/frame that matches Dt_full (e.g. edge_cache.edge_quat or
edge_cache.random_gamma-applied edge_vec) instead of recomputing from raw
edge_vec, then call gie_zonal_wigner_calc.forward_zonal with that quaternion so
extra_coupling and mp_coupling share the same local frame before concatenation;
update references in _build_gie_zonal_coupling, mp_coupling, extra_coupling,
build_edge_quaternion and gie_zonal_wigner_calc.forward_zonal accordingly.
- Line 624: The list comprehension assigning self.ebed_dims uses the ambiguous
one-letter variable name "l" which triggers Ruff E741; update the comprehension
in the expression self.ebed_dims = [get_so3_dim_of_lmax(l) for l in
self.l_schedule] to use a clear, unambiguous variable name (e.g., "ell" or
"l_val") and update the call to get_so3_dim_of_lmax accordingly so it becomes
self.ebed_dims = [get_so3_dim_of_lmax(ell) for ell in self.l_schedule], ensuring
only the variable name is changed and no other behavior is altered.
- Around line 619-623: The kmax validation currently compares to the raw
constructor arg int(lmax) instead of the resolved schedule max; update the check
so self.kmax is validated against the resolved effective lmax (self.lmax) after
l_schedule is applied — i.e., ensure l_schedule resolution runs before
validating self.kmax and replace the comparison using int(lmax) with a
comparison to self.lmax (or int(self.lmax) if needed) in the initialization code
around self.kmax and self.lmax.
- Around line 846-848: The zip calls over the schedules should enforce equal
lengths: update the loop zip(self.l_schedule, self.node_l_schedule,
self.m_schedule) inside the SEZM descriptor and the zip(self.m_schedule,
self.l_schedule) used in the any(m > l for m, l in zip(...)) check to include
strict=True (e.g., zip(..., strict=True)) so mismatched schedule lengths raise
immediately; locate these by referencing the attributes self.l_schedule,
self.node_l_schedule, self.m_schedule and the any(...) check and add the
strict=True argument to each zip call.
In `@deepmd/pt/train/training.py`:
- Around line 184-192: The code uses os.environ.setdefault for
DP_COMPILE_INFER/DP_TF32_INFER (based on validating_params) which makes these
flags process-global and sticky; change this to temporarily set the env vars
only while constructing the SeZMModel instance and then restore the previous
environment values immediately after construction so other Trainer instances are
not affected—specifically, around the SeZMModel creation in training.py capture
os.environ.get for DP_COMPILE_INFER and DP_TF32_INFER, set/override them
according to validating_params, instantiate SeZMModel, and finally restore the
saved values (or delete if previously absent).
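The save, override, restore sequence described in the `training.py` item maps naturally onto a context manager. The sketch below is one possible shape, not the PR's code: `temporary_env` is a hypothetical helper, and whether a user-set value should take precedence over the config (as `setdefault` did) is a policy choice it leaves to the caller.

```python
import os
from contextlib import contextmanager


@contextmanager
def temporary_env(overrides: dict[str, str]):
    """Set environment variables inside the ``with`` block, then restore them."""
    # Remember the previous value of each key; None means it was not set.
    saved = {key: os.environ.get(key) for key in overrides}
    try:
        os.environ.update(overrides)
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)  # key was absent before: delete it
            else:
                os.environ[key] = old  # key existed before: put the old value back
```

Model construction would then sit inside the `with` block, so a second trainer created later in the same process starts from the original environment.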
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 4fc847a9-e9f4-477b-a673-749e1a3a07e5
📒 Files selected for processing (47)
- backend/find_pytorch.py
- deepmd/pt/entrypoints/freeze_pt2.py
- deepmd/pt/infer/deep_eval.py
- deepmd/pt/model/atomic_model/sezm_atomic_model.py
- deepmd/pt/model/descriptor/sezm.py
- deepmd/pt/model/descriptor/sezm_nn/__init__.py
- deepmd/pt/model/descriptor/sezm_nn/activation.py
- deepmd/pt/model/descriptor/sezm_nn/block.py
- deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py
- deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py
- deepmd/pt/model/descriptor/sezm_nn/edge_cache.py
- deepmd/pt/model/descriptor/sezm_nn/embedding.py
- deepmd/pt/model/descriptor/sezm_nn/ffn.py
- deepmd/pt/model/descriptor/sezm_nn/grid_net.py
- deepmd/pt/model/descriptor/sezm_nn/indexing.py
- deepmd/pt/model/descriptor/sezm_nn/projection.py
- deepmd/pt/model/descriptor/sezm_nn/so2.py
- deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py
- deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py
- deepmd/pt/model/descriptor/sezm_nn/triton/constants.py
- deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py
- deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py
- deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py
- deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py
- deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py
- deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py
- deepmd/pt/model/descriptor/sezm_nn/wignerd.py
- deepmd/pt/model/model/__init__.py
- deepmd/pt/model/model/sezm_model.py
- deepmd/pt/train/training.py
- deepmd/pt/utils/nv_nlist.py
- deepmd/utils/argcheck.py
- doc/model/dpa4.md
- examples/water/dpa4/README.md
- examples/water/dpa4/input-spin.json
- examples/water/dpa4/input-zbl.json
- examples/water/dpa4/input.json
- examples/water/dpa4/input_dens.json
- examples/water/dpa4/input_multitask.json
- examples/water/dpa4/lora_ft.json
- source/tests/pt/model/test_descriptor_sezm.py
- source/tests/pt/model/test_descriptor_sezm_grid_projection.py
- source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py
- source/tests/pt/model/test_descriptor_sezm_triton.py
- source/tests/pt/model/test_nlist_backend.py
- source/tests/pt/model/test_nv_nlist.py
- source/tests/pt/model/test_sezm_model.py
💤 Files with no reviewable changes (8)
- deepmd/pt/model/descriptor/sezm_nn/triton/constants.py
- deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py
- deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py
- deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py
- deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py
- source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py
- deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py
- deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #5503 +/- ##
==========================================
+ Coverage 81.41% 81.46% +0.04%
==========================================
Files 871 870 -1
Lines 96952 97617 +665
Branches 4242 4242
==========================================
+ Hits 78938 79519 +581
- Misses 16711 16796 +85
+ Partials 1303 1302 -1
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py`:
- Around line 1192-1201: The check in _canonical_block_index_values currently
returns True for FakeTensor (fake_mode) and meta tensors, allowing
_block_layout_lmax to pick the block-diagonal/triton path during torch.compile
tracing even when the real coeff_index is noncanonical; change the logic so
fake/meta tensors do not unconditionally return True — instead detect fake/meta
and conservatively return False (or defer decision) so the non-block codepath is
used during compile-time tracing; update references in
_canonical_block_index_values, _block_layout_lmax and any callsites relying on
coeff_index.fake_mode or coeff_index.device.type == "meta" to ensure the block
kernels are only chosen when coeff_index is explicitly canonical (e.g., by
actually comparing to build_m_major_index for real tensors).
In `@deepmd/pt/model/descriptor/sezm.py`:
- Line 1687: The generator expression uses a single-letter variable name `l`
which triggers Ruff E741; update the expression in the condition that reads "if
any(m > l for m, l in zip(self.m_schedule, self.l_schedule, strict=True))" to
use a clearer binding (e.g. `l_val` or `sched_l`) instead of `l`, and update the
comparison accordingly so the check uses `m > l_val` while keeping the same
zip(self.m_schedule, self.l_schedule, strict=True) semantics.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 7213fdd3-1a94-4027-8ff2-372f1789b514
📒 Files selected for processing (11)
- deepmd/pt/infer/deep_eval.py
- deepmd/pt/model/descriptor/sezm.py
- deepmd/pt/model/descriptor/sezm_nn/block.py
- deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py
- deepmd/pt/model/descriptor/sezm_nn/edge_cache.py
- deepmd/pt/model/descriptor/sezm_nn/so2.py
- deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py
- deepmd/pt/train/training.py
- source/tests/pt/model/test_descriptor_sezm_grid_projection.py
- source/tests/pt/model/test_descriptor_sezm_triton.py
- source/tests/pt/model/test_nv_nlist.py
🚧 Files skipped from review as they are similar to previous changes (5)
- source/tests/pt/model/test_nv_nlist.py
- source/tests/pt/model/test_descriptor_sezm_grid_projection.py
- deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py
- deepmd/pt/model/descriptor/sezm_nn/so2.py
- deepmd/pt/model/descriptor/sezm_nn/block.py
def _canonical_block_index_values(coeff_index: Tensor, lmax: int) -> bool:
    """Return whether real ``coeff_index`` values match the m-major mmax=1 layout."""
    if getattr(coeff_index, "fake_mode", None) is not None:
        return True
    if coeff_index.device.type == "meta":
        return True
    try:
        expected = build_m_major_index(int(lmax), 1, device=coeff_index.device)
        return torch.equal(coeff_index, expected)
    except Exception:  # pragma: no cover - exotic tensor subclasses
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"
echo "== Lines 1160-1235 in $FILE =="
sed -n '1160,1235p' "$FILE" | cat -n
echo "== Searching for rotate_to_local / rotate_back dispatch =="
rg -n "def (rotate_to_local|rotate_back)\b|rotate_to_local\(|rotate_back\(" "$FILE"
echo "== Show rotate_to_local / rotate_back implementations =="
rg -n "def rotate_to_local\b|def rotate_back\b" -n "$FILE" -S
# Print a window around each definition
python3 - <<'PY'
import re, subprocess, os, sys
file="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"
text=open(file,'r',encoding='utf-8').read().splitlines()
for i,line in enumerate(text,1):
    if re.match(r'\s*def rotate_to_local\b', line) or re.match(r'\s*def rotate_back\b', line):
        start=max(1,i-40); end=min(len(text), i+120)
        print("\n--- Window around line", i, "---")
        for j in range(start,end+1):
            print(f"{j:5d}: {text[j-1]}")
PY
echo "== Inspect _canonical_block_index_values definition fully =="
rg -n "def _canonical_block_index_values\b" "$FILE"
python3 - <<'PY'
file="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"
lines=open(file,'r',encoding='utf-8').read().splitlines()
start=None
for i,l in enumerate(lines,1):
    if "def _canonical_block_index_values" in l:
        start=i; break
assert start
end=min(len(lines), start+120)
for j in range(start,end+1):
    print(f"{j:5d}: {lines[j-1]}")
PY
echo "== Find build_m_major_index usage for canonical reconstruction =="
rg -n "build_m_major_index\(" deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py
echo "== Search for fake_mode/meta tensor handling elsewhere in dispatch =="
rg -n "fake_mode|device\.type == \"meta\"|device\.type==\"meta\"" deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py
echo "== Find block-vs-dense dispatch points around custom-op boundary (custom_op / library / Triton) =="
rg -n "custom_op|torch\.library|Triton|triton|kernel" deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py
echo "== Test coverage: locate referenced test file and check compile coverage =="
fd -a "test_descriptor_sezm_triton.py" -t f . || true
for f in $(fd -a "test_descriptor_sezm_triton.py" -t f . || true); do
echo "\n== Contents around tests in $f =="
sed -n '1,260p' "$f" | cat -n
echo "\n== Search for compile / torch.compile usage in $f =="
rg -n "torch\.compile|compile\(" "$f" || true
done
echo "== Ensure ruff guidelines? Just check that file is .py =="
test -f "$FILE" && echo "OK: $FILE exists"
Repository: deepmodeling/deepmd-kit
Length of output: 36084
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"
echo "== Dense/reference: rotate_to_local_reference and rotate_back_reference =="
rg -n "def rotate_to_local_reference|def rotate_back_reference" "$FILE"
python3 - <<'PY'
import re
file="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"
lines=open(file,'r',encoding='utf-8').read().splitlines()
# print windows around the two defs
targets=["def rotate_to_local_reference","def rotate_back_reference"]
for t in targets:
    for i,l in enumerate(lines,1):
        if l.startswith(t):
            start=max(1,i-30); end=min(len(lines), i+160)
            print("\n--- Window around", t, "line", i, "---")
            for j in range(start,end+1):
                print(f"{j:5d}: {lines[j-1]}")
            break
PY
echo "== Search where coeff_index is actually used in dense (non-block) kernels =="
rg -n "coeff_index" "$FILE" | head -n 80
echo "== Inspect custom_op registrations for fake/meta behavior for dense vs block ops =="
# show around custom_op definitions
sed -n '1420,1585p' "$FILE" | cat -n
echo "== Inspect custom_op forward implementations to see if block ops ignore coeff_index =="
# Show where _rotate_to_local_op and _block_to_local_op are defined
sed -n '1455,1565p' "$FILE" | cat -n
Repository: deepmodeling/deepmd-kit
Length of output: 34178
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"
echo "== Inspect _use_triton helper =="
sed -n '1340,1420p' "$FILE" | cat -n
echo "== Inspect dense _rotate_to_local_impl / _rotate_back_impl use of coeff_index =="
sed -n '1340,1460p' "$FILE" | cat -n
echo "== Inspect block _block_rotate_to_local_impl / _block_rotate_back_impl inputs (no coeff_index) =="
sed -n '1460,1580p' "$FILE" | cat -n
echo "== Inspect inverse/_inverse_index or coeff_index-dependent ops (to ensure dense path actually uses values) =="
rg -n "def _inverse_index|inv_index|index_select|coeff_index\[" "$FILE" | head -n 80
Repository: deepmodeling/deepmd-kit
Length of output: 13548
Fix torch.compile fake/meta tracing misdispatch to block kernels for noncanonical coeff_index
In deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py (~1168-1201, ~1614-1687), _canonical_block_index_values() returns True for FakeTensor (fake_mode) and meta tensors, so _block_layout_lmax() can select the block-diagonal path under torch.compile based on shape/dim_full alone; the block custom ops do not take coeff_index and instead rebuild m-major coeff from lmax, so noncanonical same-length inputs can produce incorrect results. Eager tests cover noncanonical dispatch, but compiled noncanonical coverage is missing (only canonical coeff_index is used in test_torch_compile_composability).
🧰 Tools
🪛 Ruff (0.15.15)
[warning] 1201-1201: Do not catch blind exception: Exception
(BLE001)
  if any(x < 0 for x in self.m_schedule):
      raise ValueError("`m_schedule` entries must be non-negative")
- if any(m > l for m, l in zip(self.m_schedule, self.l_schedule)):
+ if any(m > l for m, l in zip(self.m_schedule, self.l_schedule, strict=True)):
Rename the zipped l binding before Ruff blocks CI.
Line 1687 reintroduces Ruff E741, so this file still fails the required lint step.
Suggested fix
- if any(m > l for m, l in zip(self.m_schedule, self.l_schedule, strict=True)):
+ if any(
+     m > degree
+     for m, degree in zip(self.m_schedule, self.l_schedule, strict=True)
+ ):
As per coding guidelines, `**/*.py`: Install linter and run `ruff check .` before committing changes or the CI will fail.
🧰 Tools
🪛 Ruff (0.15.15)
[error] 1687-1687: Ambiguous variable name: l
(E741)
Sources: Coding guidelines, Linters/SAST tools
Summary by CodeRabbit
New Features
Performance
Bug Fixes
Documentation