
feat(dpa4): multiple updates for DPA4/SeZM #5503

Open
OutisLi wants to merge 4 commits into deepmodeling:master from OutisLi:pr/dpa4

@OutisLi (Collaborator) commented Jun 7, 2026

Summary by CodeRabbit

  • New Features

    • O(N) neighbor-list support (NV toolkit) for large periodic CUDA systems; automatic backend selection.
    • New S2/SO3 grid nonlinearities and projectors with configurable grid branches, kmax, and node-degree scheduling.
    • Optional CuTe/CUDA-accelerated fused rotation kernels for faster SO(2)/Wigner ops.
    • Inference TF32 and compile controls exposed via environment variables.
  • Performance

    • Improved compile/inference lowering and caching; eval-time activation checkpointing to reduce memory.
  • Bug Fixes

    • Fixed neighbor-list backend selection and validation logic.
  • Documentation

    • Expanded DPA4/SeZM docs and inference env var guidance.

OutisLi added 3 commits June 7, 2026 20:23
feat(sezm): add activation checkpoint option

feat(sezm): add cross node tensor product

feat(sezm): add message node tensor product

feat(sezm): separate node lmax from edge lmax

feat(sezm): add custom kernel for lmax=5-10

feat(sezm): add so3 grid projection

refactor: change default values

feat(sezm): tf32 infer
Copilot AI review requested due to automatic review settings June 7, 2026 12:24

Copilot AI (Contributor) left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR expands the PyTorch backend’s neighbor-list strategy support (adding an NVIDIA Toolkit-Ops O(N) option) and updates SeZM/DPA4 configuration, docs, and tests to reflect new grid features and inference-time env controls.

Changes:

  • Add NvNeighborList strategy (Toolkit-Ops / nvalchemi-toolkit-ops) plus dedicated unit tests and broaden nlist_backend dispatch/equivalence tests to cover multiple strategies.
  • Extend SeZM/DPA4 descriptor argument schema (e.g., extra_node_l, kmax, ffn_so3_grid, grid branch controls, and validation-time tf32_infer) and wire env-var translation in the PT trainer.
  • Refactor/replace several SeZM grid/nonlinearity components and update example inputs + DPA4 docs accordingly; relax certain numeric tolerances in tests.

Reviewed changes

Copilot reviewed 47 out of 47 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| `source/tests/pt/model/test_sezm_model.py` | Adjust tolerances for compiled/eager comparisons (GPU). |
| `source/tests/pt/model/test_nv_nlist.py` | New unit tests for NvNeighborList edge topology/geometry equivalence and error behavior. |
| `source/tests/pt/model/test_nlist_backend.py` | Generalize backend strategy tests (vesin/nv) and improve dispatch coverage. |
| `source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py` | Removes S2 equivariance tests (file deleted). |
| `source/tests/pt/model/test_descriptor_sezm.py` | Add tests for new zonal gather and special Wigner blocks; adjust Wigner tests and remove one deserialize regression test. |
| `examples/water/dpa4/lora_ft.json` | Update example descriptor/optimizer settings to newer compact DPA4 config. |
| `examples/water/dpa4/input_multitask.json` | Update multitask example config to newer compact DPA4 config. |
| `examples/water/dpa4/input_dens.json` | Update dens example config to newer compact DPA4 config. |
| `examples/water/dpa4/input.json` | Update baseline example config and add validating inference flags. |
| `examples/water/dpa4/input-zbl.json` | Update ZBL example config to newer compact DPA4 config. |
| `examples/water/dpa4/input-spin.json` | Update spin example config to newer compact DPA4 config. |
| `examples/water/dpa4/README.md` | Align README text with updated “compact/DPA4-Neo-style” examples. |
| `doc/model/dpa4.md` | Major DPA4/SeZM documentation expansion and updates, including env var behavior. |
| `deepmd/utils/argcheck.py` | Add get_argument, extend DPA4/SeZM schema and validation options, adjust defaults/variants. |
| `deepmd/pt/utils/nv_nlist.py` | New NVIDIA Toolkit-Ops neighbor-list strategy implementation. |
| `deepmd/pt/train/training.py` | Translate validating.* options into env vars before model construction. |
| `deepmd/pt/model/model/__init__.py` | Default descriptor.type to dpa4 in SeZM scaffolds. |
| `deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py` | Remove old Triton kernels (file deleted). |
| `deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py` | Remove old Triton kernels (file deleted). |
| `deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py` | Remove old Triton dispatch (file deleted). |
| `deepmd/pt/model/descriptor/sezm_nn/triton/constants.py` | Remove old Triton constants/flags (file deleted). |
| `deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py` | Remove old Triton autograd wrapper/API (file deleted). |
| `deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py` | Switch Triton exports to new compile-composable rotation API. |
| `deepmd/pt/model/descriptor/sezm_nn/so2.py` | Extend SO(2) conv with grid nets/branches + inference-time Triton rotation toggle. |
| `deepmd/pt/model/descriptor/sezm_nn/projection.py` | New projection utilities (S2 + SO3 projectors and grid resolution helpers). |
| `deepmd/pt/model/descriptor/sezm_nn/indexing.py` | Add build_gie_zonal_index helper for zonal coupling gather. |
| `deepmd/pt/model/descriptor/sezm_nn/ffn.py` | Refactor FFN nonlinearity onto new grid-net abstractions and add SO3-grid option. |
| `deepmd/pt/model/descriptor/sezm_nn/embedding.py` | Use new zonal index helper and allow passing precomputed zonal coupling. |
| `deepmd/pt/model/descriptor/sezm_nn/edge_cache.py` | Remove geometry/RBF Triton path and simplify edge cache construction. |
| `deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py` | New CuTe accelerated SO(2) rotation package exports. |
| `deepmd/pt/model/descriptor/sezm_nn/block.py` | Add node/message grid branches, node-lmax decoupling, and inference checkpointing hooks. |
| `deepmd/pt/model/descriptor/sezm_nn/activation.py` | Remove grid-projector/grid-activation implementations (moved/refactored). |
| `deepmd/pt/model/descriptor/sezm_nn/__init__.py` | Re-export new grid/projector utilities and indexing helpers. |
| `deepmd/pt/model/atomic_model/sezm_atomic_model.py` | Use node_l_schedule if present for dens-head reconstruction. |
| `deepmd/pt/infer/deep_eval.py` | Add nv strategy handling, refactor strategy eval path, and broaden backend validation. |
| `deepmd/pt/entrypoints/freeze_pt2.py` | Patch Inductor config (triton.max_tiles) during AOTI packaging. |
| `backend/find_pytorch.py` | Add optional dependency for Toolkit-Ops neighbor list on Python >= 3.11. |

Comments suppressed due to low confidence (4)

deepmd/utils/argcheck.py:1

  • The extra_check can raise TypeError for non-bool, non-sequence inputs (e.g., len(1)), and it also accepts any length-3 sequence without verifying element types are bool. Make the check robust by first verifying isinstance(x, list) (or Sequence) and then validating len(x) == 3 and all(isinstance(v, bool) for v in x).

deepmd/pt/utils/nv_nlist.py:1

  • NvNeighborList is a GPU Toolkit-Ops path, but build() does not enforce CUDA tensors. If called with CPU tensors (or in a CPU-only environment), this will likely fail inside nvalchemiops with a less actionable error. Consider adding an explicit check like coord.device.type == 'cuda' (and similarly for atype/box) and raising a clear ValueError when not on CUDA.

deepmd/pt/utils/nv_nlist.py:1

  • This builds a flattened valid mask of size total_atoms * max_neighbors and then materializes edge_idx via nonzero, which can be very memory/latency heavy for large systems (the primary target for this strategy). Consider constructing the flattened edge indices without forming the full 2D mask (e.g., using repeat_interleave(num_neighbors) to build dst indices and arange(sum(num_neighbors))-based offsets to build slot indices), so the intermediate scales with the number of real edges rather than the dense capacity.

source/tests/pt/model/test_descriptor_sezm.py:1

  • A deserialize/backward-compatibility regression test was removed in this area, and the PR also deletes test_descriptor_sezm_s2_equivariance.py entirely. Since this PR introduces substantial new grid/projection behavior and new serialization-relevant config keys (e.g., extra_node_l, grid branches, SO3-grid options), it would be good to retain equivalent coverage: add/restore tests that (1) validate descriptors deserialize correctly when older config fields are missing and (2) assert key equivariance/invariance properties for the new grid-net paths.
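
The suppressed `nv_nlist.py` comment about building edge indices from per-atom neighbor counts, rather than from a dense mask, can be illustrated with a small pure-Python sketch. The function names and the `num_neighbors` / `max_neighbors` arguments are hypothetical and are not taken from `nv_nlist.py`. The sketch also assumes that valid neighbors are packed at the front of each row. A tensor implementation would presumably use `repeat_interleave` and `arange` as the comment suggests.

```python
from itertools import chain


def flat_edge_indices_dense(num_neighbors, max_neighbors):
    """Reference path: build the full (atoms x max_neighbors) mask, then take nonzero."""
    mask = [
        slot < count
        for count in num_neighbors
        for slot in range(max_neighbors)
    ]
    return [i for i, valid in enumerate(mask) if valid]


def flat_edge_indices_sparse(num_neighbors, max_neighbors):
    """Build the same flat indices while only touching real edges."""
    # dst[e]: owning atom of edge e (the list analogue of repeat_interleave).
    dst = list(chain.from_iterable([i] * n for i, n in enumerate(num_neighbors)))
    # slot[e]: position of edge e within its atom's neighbor row.
    slot = list(chain.from_iterable(range(n) for n in num_neighbors))
    return [d * max_neighbors + s for d, s in zip(dst, slot)]


counts = [2, 0, 3]  # hypothetical per-atom neighbor counts
assert flat_edge_indices_sparse(counts, 4) == flat_edge_indices_dense(counts, 4)
```

The sparse variant allocates lists whose length is the number of real edges, whereas the dense variant allocates one entry per atom per neighbor slot.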


Comment thread deepmd/pt/model/descriptor/sezm_nn/projection.py
Comment thread deepmd/pt/infer/deep_eval.py
Comment thread deepmd/pt/model/descriptor/sezm_nn/block.py
Comment thread source/tests/pt/model/test_nv_nlist.py Fixed
@OutisLi OutisLi added bug Core CUDA Test CUDA Trigger test CUDA workflow labels Jun 7, 2026
@github-actions github-actions Bot removed the Test CUDA Trigger test CUDA workflow label Jun 7, 2026
@coderabbitai coderabbitai Bot (Contributor) commented Jun 7, 2026

📝 Walkthrough


Adds SeZM grid/projector/net stack, replaces Triton rotation plumbing with new Triton/CuTe rotation operators, adds NV Toolkit-Ops neighbor-list backend and selection, adjusts compile/export/inference TF32/Triton options, extends descriptor/config/docs, and updates/introduces comprehensive tests.

Changes

SeZM grid, descriptor, and config surface

| Layer | File(s) | Summary |
| --- | --- | --- |
| Descriptor config, defaults, and docs | `deepmd/utils/argcheck.py`, `deepmd/pt/model/model/__init__.py`, `doc/model/dpa4.md`, `examples/water/dpa4/*`, `backend/find_pytorch.py` | Adds kmax, extra_node_l, grid_mlp/grid_branch, ffn_so3_grid, and related schema/docs; defaults SeZM descriptor type to dpa4; updates PyPI requirement to include nvalchemi-toolkit-ops with Python version marker. |
| Training env and trainer wiring | `deepmd/pt/train/training.py` | Adds scoped env defaults for inference compile/TF32 flags and applies them around model construction. |

Descriptor and grid implementation

| Layer | File(s) | Summary |
| --- | --- | --- |
| SeZM descriptor and node-schedule | `deepmd/pt/model/descriptor/sezm.py`, `deepmd/pt/model/atomic_model/sezm_atomic_model.py` | Adds extra_node_l, kmax, node-level node_l_schedule, zonal coupling support, and propagates node-schedule into interaction blocks and dens head sizing. |
| GIE / embedding / indexing | `deepmd/pt/model/descriptor/sezm_nn/indexing.py`, `deepmd/pt/model/descriptor/sezm_nn/embedding.py` | Adds build_gie_zonal_index and uses it to precompute/register zonal indexing buffers; embedding forward accepts optional zonal_coupling. |
| Grid projector and net stack | `deepmd/pt/model/descriptor/sezm_nn/projection.py`, `deepmd/pt/model/descriptor/sezm_nn/grid_net.py`, `deepmd/pt/model/descriptor/sezm_nn/ffn.py` | New BaseGridProjector, S2/SO3 projectors, resolve helpers; BaseGridNet, S2GridNet, SO3GridNet, GridMLP, GridBranch, FrameExpand/Contract; EquivariantFFN wired to grid-net path with kmax/grid_branch options. |
| Interaction blocks and SO(2) paths | `deepmd/pt/model/descriptor/sezm_nn/block.py`, `deepmd/pt/model/descriptor/sezm_nn/so2.py` | Blocks accept node_lmax/kmax, grid-branch flags, node/message/FFN grid toggles; adds eval-time activation checkpoint wrappers; SO2Convolution integrates S2/SO3 grid-product branches and uses rotate_to_local/rotate_back calls when inference rotation is enabled. |
| Edge cache updates | `deepmd/pt/model/descriptor/sezm_nn/edge_cache.py` | Removes Triton fused geometry/RBF path; always uses gather-based RBF path; adds edge_quat to cache and plumbing. |
| Wigner D improvements | `deepmd/pt/model/descriptor/sezm_nn/wignerd.py` | Adds specialized small-order kernels up to l=10, shifts polynomial path to l>=11, and adds forward_zonal for zonal coupling construction. |
| Activation tweaks and CuTe entry | `deepmd/pt/model/descriptor/sezm_nn/activation.py`, `deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py` | Docstring and SwiGLU gate change; CuTe package entry added. |

Rotation kernels, Triton/CuTe, and triton package surface

| Layer | File(s) | Summary |
| --- | --- | --- |
| New Triton rotation and projector ops | `deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py`, `deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py` | Adds dense and block-diagonal Triton-backed rotate_to_local / rotate_back implementations, exposes custom ops, and simplifies triton package exports to rotation APIs only. |
| CuTe rotation kernels | `deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py` | Adds experimental CuTe fused kernels and custom-op registrations for fast rotate_to_local / rotate_back with autograd support. |
| Removed legacy Triton autograd kernels | `deepmd/pt/model/descriptor/sezm_nn/triton/* (autograd/custom_ops/constants/dispatch/kernels_*)` | Removes older Triton autograd wrappers and many Triton kernel modules in favor of the new rotation modules. |
| SO(2) test updates | `source/tests/pt/model/test_descriptor_sezm_triton.py` | Rewrites tests to focus on block-diagonal rotation correctness and dispatch behavior. |

Neighbor-list backend, model compile/export, and runtime

| Layer | File(s) | Summary |
| --- | --- | --- |
| NV neighbor-list implementation | `deepmd/pt/utils/nv_nlist.py` | Adds NvNeighborList using nvalchemi Toolkit-Ops, matrix→extended conversion, dynamic capacity growth, and compiled truncation helper; also availability and method selection helpers. |
| Model compile/export and inference changes | `deepmd/pt/model/model/sezm_model.py`, `deepmd/pt/entrypoints/freeze_pt2.py`, `deepmd/pt/infer/deep_eval.py` | Adds shared compiled-graph cache and structure-key helpers; changes inference lowering to AOT path with Inductor patch (sets triton.max_tiles=1); adds SEZM_NV_NLIST_THRESHOLD and NV path selection; export wraps aoti package call in Inductor config patch. |
| DeepEval neighbor-list backend selection and eval wiring | `deepmd/pt/infer/deep_eval.py` | Adds "nv" backend option, unified _nlist_builder selection, availability/validation checks, and unified _eval_lower_strategy O(N) evaluation path. |
| Freeze/export PT2 adjustment | `deepmd/pt/entrypoints/freeze_pt2.py` | Patches Inductor config around the AOT inference packaging to set triton.max_tiles=1. |
| Tests for nlist backends and NV nlist | `source/tests/pt/model/test_nlist_backend.py`, `source/tests/pt/model/test_nv_nlist.py`, `source/tests/pt/model/test_sezm_model.py` | Tests updated/added to validate backend dispatch, equivalence with native dense builder, NV neighbor-list correctness, and compile-vs-eager tolerance adjustments. |

Sequence Diagram(s)

sequenceDiagram
  participant DeepEval
  participant SeZMModel
  participant NvNeighborList
  DeepEval->>SeZMModel: _setup_nlist_backend(nlist_backend, auto)
  SeZMModel->>NvNeighborList: NvNeighborList.build(coord, atype, box, rcut, sel)
  DeepEval->>SeZMModel: _eval_lower_strategy(... using _nlist_builder)

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • wanghan-iapcm
  • iProzd

@coderabbitai coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 11

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/water/dpa4/input.json (1)

1-124: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use input_torch.json for this PyTorch example.

This file is a PyTorch training config (model.type: "DPA4"). Keeping it as input.json conflicts with the repo rule for TensorFlow input.json usage and creates command ambiguity for users/docs.

As per coding guidelines, **/input.json: TensorFlow backend training configuration should use input.json format and be invoked with dp train input.json.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/water/dpa4/input.json` around lines 1 - 124, The config is a PyTorch
training config (model.type: "DPA4") but is named as the TensorFlow-style
input.json, causing ambiguity; rename this file to input_torch.json (and update
any references) so PyTorch examples use input_torch.json, ensure any
documentation or scripts that reference this example (e.g., training invocation)
point to input_torch.json instead of input.json, and verify consumers expect a
PyTorch config when parsing model.type == "DPA4".

Source: Coding guidelines

🧹 Nitpick comments (1)
source/tests/pt/model/test_descriptor_sezm_triton.py (1)

76-177: ⚡ Quick win

The new dense rotation path still has no coverage.

Every fixture builds coeff_index = build_m_major_index(lmax, 1, ...), so rotate_to_local / rotate_back always auto-select the block-diagonal kernels. The dense kernels, inverse-index path, and dense registered backwards introduced in so2_rotation.py can regress without any test failing. Please add at least one non-mmax==1 case or direct *_dense assertions.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt/model/test_descriptor_sezm_triton.py` around lines 76 - 177,
Tests always build coeff_index via build_m_major_index(lmax, 1, ...) so
rotate_to_local and rotate_back always pick the block-diagonal (mmax==1) path;
add a test case that constructs a coeff_index with mmax>1 (or directly call the
*_dense variants) so the dense/inverse-index code paths and their registered
backwards in so2_rotation.py are exercised; update
test_rotate_to_local_matches_reference and test_rotate_back_matches_reference
(or add a new subTest) to call build_m_major_index(lmax, mmax>1, ...) or to
invoke rotate_to_local_dense / rotate_back_dense equivalents and assert outputs
and gradients match the *_reference functions (and still apply the mask checks
for dense behavior).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/infer/deep_eval.py`:
- Around line 309-343: The nvalchemi ('nv') backend selection in the
_setup_nlist_backend logic (symbols: nlist_backend, builder, NvNeighborList,
VesinNeighborList, is_nv_available, is_vesin_torch_available) must be gated on
the actual inference device (DEVICE from deepmd/pt/utils/env.py) rather than
host CUDA availability; update both the explicit "nv" branch and the "auto"
branch to check DEVICE.type == "cuda" (or the device of the input tensors)
before choosing NvNeighborList and, if DEVICE.type != "cuda" but nlist_backend
== "nv" was requested, raise a clear ValueError explaining that nv requires
CUDA-device inference; keep the existing checks for vesin (using
is_vesin_torch_available) but ensure vesin selection also verifies the
input/DEVICE compatibility if needed.
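
As an illustration of the gating requested above, here is a minimal sketch of backend selection keyed on the inference device. It is not the `DeepEval` code: the function name, its arguments, the `"native"` fallback, and the preference order in the `"auto"` branch are all assumptions, and the size threshold mentioned elsewhere in the PR is ignored.

```python
def select_nlist_backend(requested, device_type, nv_available, vesin_available):
    """Pick a neighbor-list backend name from the request and the inference device.

    Every name here is a hypothetical stand-in, not the DeepEval API.
    """
    if requested == "nv":
        if device_type != "cuda":
            raise ValueError(
                "nlist_backend='nv' requires CUDA-device inference, "
                f"but the inference device is '{device_type}'."
            )
        if not nv_available:
            raise ValueError("nlist_backend='nv' was requested but is not installed.")
        return "nv"
    if requested == "vesin":
        if not vesin_available:
            raise ValueError("nlist_backend='vesin' was requested but is not installed.")
        return "vesin"
    if requested == "auto":
        # Prefer the GPU path only when inference actually runs on CUDA.
        if device_type == "cuda" and nv_available:
            return "nv"
        if vesin_available:
            return "vesin"
        return "native"  # assumed fallback name
    raise ValueError(f"Unknown nlist_backend: {requested!r}")
```

The point of the sketch is that `device_type` comes from the device the model runs on, so a host with CUDA installed but CPU inference never selects `"nv"` implicitly.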

In `@deepmd/pt/model/descriptor/sezm_nn/__init__.py`:
- Around line 45-51: The module removed package-level re-exports (notably
SwiGLUS2Activation) which breaks downstream imports; restore a deprecated
re-export shim in __init__.py that re-exports the original symbol name (e.g.,
SwiGLUS2Activation) while importing/forwarding to the new implementation
location or wrapper and emit a DeprecationWarning when the shim is used; apply
the same pattern for the other removed symbols referenced in the file (the
exports around the later ranges) so callers keep compatibility for one release
before fully removing them.

In `@deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py`:
- Around line 833-839: The CuTe fast-path guard (_cute_usable) must also verify
kernel channel-width constraints: reject CuTe when channel count C would produce
invalid launch geometry or tails; update _cute_usable to compute the channel
dimension C (from the relevant tensor shape used by so2_rotation) and return
False unless C >= _TN and C % _TN == 0 (i.e., divisible by the kernel tile size
_TN), in addition to the existing checks (SEZM_CUTE_AVAILABLE, is_cuda, dtype);
apply the same additional gating logic to the other CuTe guard sites referenced
(the checks around lines 875-879 and 912-916) so they all fall back to the eager
path when channel widths are unsupported.
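
The channel-width condition described above (`C >= _TN and C % _TN == 0`) can be sketched as a standalone predicate. The function and its boolean arguments are hypothetical simplifications of `_cute_usable`, and the tile size is passed in because the actual value of `_TN` is not given here.

```python
def cute_usable(num_channels, tile_n, available=True, is_cuda=True, dtype_ok=True):
    """Return True only when the fused kernel's launch geometry would be valid."""
    if not (available and is_cuda and dtype_ok):
        return False
    # Require at least one full tile and no partial tail tile.
    return num_channels >= tile_n and num_channels % tile_n == 0
```

With a tile size of 64, for example, 128 channels pass the guard while 96 or 32 channels fall back to the eager path.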

In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 1327-1343: The S2GridNet for the edge-local S2 branch
(instantiated as node_wise_grid_product) is using the default "packed"
coefficient layout while x_local and x_dst_local are reduced m-major tensors;
change the S2GridNet constructor to pass layout="m_major" so the projector
matches the m-major ordering; also update the other analogous S2GridNet
instantiation used for the edge-local branch later in the file (the second
S2GridNet for the same local branch) to use layout="m_major" as well.

In `@deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py`:
- Around line 1168-1191: The dispatch helper _block_layout_lmax currently only
checks shapes and may misidentify non-canonical reduced layouts as the m-major
mmax=1 block layout; update _block_layout_lmax to also validate that coeff_index
contains the canonical ordering produced by build_m_major_index(lmax, 1) (i.e.,
compare coeff_index values to the expected index tensor using value-safe
operations) and return -1 if it does not match so the code falls back to the
dense path or raises; apply the same value-validation wherever you currently use
this shape-only detection (the rotate_to_local / rotate_back block dispatch and
the explicit *_block helpers) so block kernels are only used for truly canonical
layouts.
- Line 649: The loop variable named `l` used in the `for ... in
tl.static_range(0, LMAX + 1):` constructs triggers Ruff E741; rename the
lone-letter variable to a non-ambiguous identifier (e.g. `ell` or `l_idx`) in
every occurrence in this module (all loops using `tl.static_range(0, LMAX + 1)`
including the other three occurrences flagged) and update any inner references
accordingly so the code compiles and passes `ruff check .`.

In `@deepmd/pt/model/descriptor/sezm.py`:
- Around line 1491-1524: The concatenation in _build_gie_zonal_coupling mixes
frames because mp_coupling comes from edge_cache.Dt_full (which was built with
training-time random-γ) while extra_coupling is recomputed from raw edge_vec
without that γ; fix by reusing the same augmented quaternion/frame used for
Dt_full when computing extra_coupling. Concretely, obtain the edge
quaternion/frame that matches Dt_full (e.g. edge_cache.edge_quat or
edge_cache.random_gamma-applied edge_vec) instead of recomputing from raw
edge_vec, then call gie_zonal_wigner_calc.forward_zonal with that quaternion so
extra_coupling and mp_coupling share the same local frame before concatenation;
update references in _build_gie_zonal_coupling, mp_coupling, extra_coupling,
build_edge_quaternion and gie_zonal_wigner_calc.forward_zonal accordingly.
- Line 624: The list comprehension assigning self.ebed_dims uses the ambiguous
one-letter variable name "l" which triggers Ruff E741; update the comprehension
in the expression self.ebed_dims = [get_so3_dim_of_lmax(l) for l in
self.l_schedule] to use a clear, unambiguous variable name (e.g., "ell" or
"l_val") and update the call to get_so3_dim_of_lmax accordingly so it becomes
self.ebed_dims = [get_so3_dim_of_lmax(ell) for ell in self.l_schedule], ensuring
only the variable name is changed and no other behavior is altered.
- Around line 619-623: The kmax validation currently compares to the raw
constructor arg int(lmax) instead of the resolved schedule max; update the check
so self.kmax is validated against the resolved effective lmax (self.lmax) after
l_schedule is applied — i.e., ensure l_schedule resolution runs before
validating self.kmax and replace the comparison using int(lmax) with a
comparison to self.lmax (or int(self.lmax) if needed) in the initialization code
around self.kmax and self.lmax.
- Around line 846-848: The zip calls over the schedules should enforce equal
lengths: update the loop zip(self.l_schedule, self.node_l_schedule,
self.m_schedule) inside the SEZM descriptor and the zip(self.m_schedule,
self.l_schedule) used in the any(m > l for m, l in zip(...)) check to include
strict=True (e.g., zip(..., strict=True)) so mismatched schedule lengths raise
immediately; locate these by referencing the attributes self.l_schedule,
self.node_l_schedule, self.m_schedule and the any(...) check and add the
strict=True argument to each zip call.

In `@deepmd/pt/train/training.py`:
- Around line 184-192: The code uses os.environ.setdefault for
DP_COMPILE_INFER/DP_TF32_INFER (based on validating_params) which makes these
flags process-global and sticky; change this to temporarily set the env vars
only while constructing the SeZMModel instance and then restore the previous
environment values immediately after construction so other Trainer instances are
not affected—specifically, around the SeZMModel creation in training.py capture
os.environ.get for DP_COMPILE_INFER and DP_TF32_INFER, set/override them
according to validating_params, instantiate SeZMModel, and finally restore the
saved values (or delete if previously absent).
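
A save-and-restore pattern of this kind is often written as a context manager. The following is a minimal sketch under that assumption; `scoped_env` is a hypothetical helper, and the value `"1"` in the usage comment is a placeholder rather than a documented setting.

```python
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(overrides):
    """Temporarily set environment variables, restoring prior values on exit."""
    saved = {key: os.environ.get(key) for key in overrides}
    try:
        os.environ.update(overrides)
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)  # was absent before: remove it
            else:
                os.environ[key] = old  # was present before: restore it


# Hypothetical usage around model construction:
#   with scoped_env({"DP_TF32_INFER": "1"}):
#       model = build_model()
```

The `finally` clause restores the environment even if model construction raises, so a failed Trainer does not leak flags into later ones.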


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 4fc847a9-e9f4-477b-a673-749e1a3a07e5

📥 Commits

Reviewing files that changed from the base of the PR and between 99c1ece and 905fe7f.

📒 Files selected for processing (47)
  • backend/find_pytorch.py
  • deepmd/pt/entrypoints/freeze_pt2.py
  • deepmd/pt/infer/deep_eval.py
  • deepmd/pt/model/atomic_model/sezm_atomic_model.py
  • deepmd/pt/model/descriptor/sezm.py
  • deepmd/pt/model/descriptor/sezm_nn/__init__.py
  • deepmd/pt/model/descriptor/sezm_nn/activation.py
  • deepmd/pt/model/descriptor/sezm_nn/block.py
  • deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py
  • deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py
  • deepmd/pt/model/descriptor/sezm_nn/edge_cache.py
  • deepmd/pt/model/descriptor/sezm_nn/embedding.py
  • deepmd/pt/model/descriptor/sezm_nn/ffn.py
  • deepmd/pt/model/descriptor/sezm_nn/grid_net.py
  • deepmd/pt/model/descriptor/sezm_nn/indexing.py
  • deepmd/pt/model/descriptor/sezm_nn/projection.py
  • deepmd/pt/model/descriptor/sezm_nn/so2.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/constants.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py
  • deepmd/pt/model/descriptor/sezm_nn/wignerd.py
  • deepmd/pt/model/model/__init__.py
  • deepmd/pt/model/model/sezm_model.py
  • deepmd/pt/train/training.py
  • deepmd/pt/utils/nv_nlist.py
  • deepmd/utils/argcheck.py
  • doc/model/dpa4.md
  • examples/water/dpa4/README.md
  • examples/water/dpa4/input-spin.json
  • examples/water/dpa4/input-zbl.json
  • examples/water/dpa4/input.json
  • examples/water/dpa4/input_dens.json
  • examples/water/dpa4/input_multitask.json
  • examples/water/dpa4/lora_ft.json
  • source/tests/pt/model/test_descriptor_sezm.py
  • source/tests/pt/model/test_descriptor_sezm_grid_projection.py
  • source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py
  • source/tests/pt/model/test_descriptor_sezm_triton.py
  • source/tests/pt/model/test_nlist_backend.py
  • source/tests/pt/model/test_nv_nlist.py
  • source/tests/pt/model/test_sezm_model.py
💤 Files with no reviewable changes (8)
  • deepmd/pt/model/descriptor/sezm_nn/triton/constants.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py
  • source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py

Comment thread deepmd/pt/infer/deep_eval.py
Comment thread deepmd/pt/model/descriptor/sezm_nn/__init__.py
Comment thread deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py Outdated
Comment thread deepmd/pt/model/descriptor/sezm_nn/so2.py
Comment thread deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py
Comment thread deepmd/pt/model/descriptor/sezm.py
Comment thread deepmd/pt/model/descriptor/sezm.py
Comment thread deepmd/pt/model/descriptor/sezm.py
Comment thread deepmd/pt/model/descriptor/sezm.py
Comment thread deepmd/pt/train/training.py Outdated
@codecov

codecov Bot commented Jun 7, 2026

Codecov Report

❌ Patch coverage is 48.71442% with 1117 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.46%. Comparing base (99c1ece) to head (7c54170).
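
As a quick arithmetic check, the two patch figures above imply the size of the patch. The sketch assumes that patch coverage is covered lines divided by covered plus missing lines, and that all non-covered lines are included in the 1117; the variable names are illustrative.

```python
missing = 1117          # lines reported as missing coverage
patch_pct = 48.71442    # reported patch coverage, in percent

# If coverage = covered / (covered + missing), then total = missing / (1 - coverage).
total = round(missing / (1 - patch_pct / 100))
covered = total - missing

print(total, covered)                   # 2178 1061
print(round(100 * covered / total, 5))  # 48.71442
```

Under that assumption the patch touches about 2178 coverable lines, of which 1061 are exercised by tests.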

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...d/pt/model/descriptor/sezm_nn/cute/so2_rotation.py | 0.00% | 450 Missing ⚠️ |
| ...pt/model/descriptor/sezm_nn/triton/so2_rotation.py | 22.82% | 382 Missing ⚠️ |
| deepmd/pt/utils/nv_nlist.py | 16.66% | 100 Missing ⚠️ |
| deepmd/pt/model/descriptor/sezm_nn/grid_net.py | 79.01% | 51 Missing ⚠️ |
| deepmd/pt/model/descriptor/sezm_nn/projection.py | 82.25% | 44 Missing ⚠️ |
| deepmd/pt/model/descriptor/sezm_nn/wignerd.py | 89.79% | 20 Missing ⚠️ |
| deepmd/pt/model/model/sezm_model.py | 84.80% | 19 Missing ⚠️ |
| deepmd/pt/model/descriptor/sezm_nn/so2.py | 74.00% | 13 Missing ⚠️ |
| deepmd/pt/infer/deep_eval.py | 72.41% | 8 Missing ⚠️ |
| deepmd/pt/model/descriptor/sezm.py | 87.69% | 8 Missing ⚠️ |

... and 6 more

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5503      +/-   ##
==========================================
+ Coverage   81.41%   81.46%   +0.04%     
==========================================
  Files         871      870       -1     
  Lines       96952    97617     +665     
  Branches     4242     4242              
==========================================
+ Hits        78938    79519     +581     
- Misses      16711    16796      +85     
+ Partials     1303     1302       -1     
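
As a rough cross-check of the percentages in this report, the arithmetic can be sketched as follows. This assumes patch coverage is computed as covered / (covered + missing) over the changed lines, which the report does not state; the variable names are illustrative only.

```python
# Figures copied from the Codecov comment.
patch_pct = 48.71442          # reported patch coverage, in percent
patch_missing = 1117          # changed lines without coverage
head_hits, head_lines = 79519, 97617  # "Hits" and "Lines" on the head commit

# Assumption: patch coverage = covered / (covered + missing).
# Solving for the total number of changed lines that were measured:
patch_total = round(patch_missing / (1 - patch_pct / 100))
patch_covered = patch_total - patch_missing

# Project coverage on head is hits over lines.
project_pct = 100 * head_hits / head_lines

print(f"patch lines: {patch_total}, covered: {patch_covered}")
print(f"project coverage on head: {project_pct:.2f}%")
```

Under that assumption, the patch touches roughly 2178 measured lines, of which about 1061 are covered, and the head ratio reproduces the reported 81.46%.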


@OutisLi OutisLi added the Test CUDA Trigger test CUDA workflow label Jun 7, 2026
@github-actions github-actions Bot removed the Test CUDA Trigger test CUDA workflow label Jun 7, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py`:
- Around line 1192-1201: The check in _canonical_block_index_values currently
returns True for FakeTensor (fake_mode) and meta tensors, allowing
_block_layout_lmax to pick the block-diagonal/triton path during torch.compile
tracing even when the real coeff_index is noncanonical; change the logic so
fake/meta tensors do not unconditionally return True — instead detect fake/meta
and conservatively return False (or defer decision) so the non-block codepath is
used during compile-time tracing; update references in
_canonical_block_index_values, _block_layout_lmax and any callsites relying on
coeff_index.fake_mode or coeff_index.device.type == "meta" to ensure the block
kernels are only chosen when coeff_index is explicitly canonical (e.g., by
actually comparing to build_m_major_index for real tensors).

In `@deepmd/pt/model/descriptor/sezm.py`:
- Line 1687: The generator expression uses a single-letter variable name `l`
which triggers Ruff E741; update the expression in the condition that reads "if
any(m > l for m, l in zip(self.m_schedule, self.l_schedule, strict=True))" to
use a clearer binding (e.g. `l_val` or `sched_l`) instead of `l`, and update the
comparison accordingly so the check uses `m > l_val` while keeping the same
zip(self.m_schedule, self.l_schedule, strict=True) semantics.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 7213fdd3-1a94-4027-8ff2-372f1789b514

📥 Commits

Reviewing files that changed from the base of the PR and between 905fe7f and 7c54170.

📒 Files selected for processing (11)
  • deepmd/pt/infer/deep_eval.py
  • deepmd/pt/model/descriptor/sezm.py
  • deepmd/pt/model/descriptor/sezm_nn/block.py
  • deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py
  • deepmd/pt/model/descriptor/sezm_nn/edge_cache.py
  • deepmd/pt/model/descriptor/sezm_nn/so2.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py
  • deepmd/pt/train/training.py
  • source/tests/pt/model/test_descriptor_sezm_grid_projection.py
  • source/tests/pt/model/test_descriptor_sezm_triton.py
  • source/tests/pt/model/test_nv_nlist.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • source/tests/pt/model/test_nv_nlist.py
  • source/tests/pt/model/test_descriptor_sezm_grid_projection.py
  • deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py
  • deepmd/pt/model/descriptor/sezm_nn/so2.py
  • deepmd/pt/model/descriptor/sezm_nn/block.py

Comment on lines +1192 to +1201
def _canonical_block_index_values(coeff_index: Tensor, lmax: int) -> bool:
    """Return whether real ``coeff_index`` values match the m-major mmax=1 layout."""
    if getattr(coeff_index, "fake_mode", None) is not None:
        return True
    if coeff_index.device.type == "meta":
        return True
    try:
        expected = build_m_major_index(int(lmax), 1, device=coeff_index.device)
        return torch.equal(coeff_index, expected)
    except Exception:  # pragma: no cover - exotic tensor subclasses

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"

echo "== Lines 1160-1235 in $FILE =="
sed -n '1160,1235p' "$FILE" | cat -n

echo "== Searching for rotate_to_local / rotate_back dispatch =="
rg -n "def (rotate_to_local|rotate_back)\b|rotate_to_local\(|rotate_back\(" "$FILE"

echo "== Show rotate_to_local / rotate_back implementations =="
rg -n "def rotate_to_local\b|def rotate_back\b" -n "$FILE" -S
# Print a window around each definition
python3 - <<'PY'
import re, subprocess, os, sys
file="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"
text=open(file,'r',encoding='utf-8').read().splitlines()
for i,line in enumerate(text,1):
    if re.match(r'\s*def rotate_to_local\b', line) or re.match(r'\s*def rotate_back\b', line):
        start=max(1,i-40); end=min(len(text), i+120)
        print("\n--- Window around line", i, "---")
        for j in range(start,end+1):
            print(f"{j:5d}: {text[j-1]}")
PY

echo "== Inspect _canonical_block_index_values definition fully =="
rg -n "def _canonical_block_index_values\b" "$FILE"
python3 - <<'PY'
file="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"
lines=open(file,'r',encoding='utf-8').read().splitlines()
start=None
for i,l in enumerate(lines,1):
    if "def _canonical_block_index_values" in l:
        start=i; break
assert start
end=min(len(lines), start+120)
for j in range(start,end+1):
    print(f"{j:5d}: {lines[j-1]}")
PY

echo "== Find build_m_major_index usage for canonical reconstruction =="
rg -n "build_m_major_index\(" deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py

echo "== Search for fake_mode/meta tensor handling elsewhere in dispatch =="
rg -n "fake_mode|device\.type == \"meta\"|device\.type==\"meta\"" deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py

echo "== Find block-vs-dense dispatch points around custom-op boundary (custom_op / library / Triton) =="
rg -n "custom_op|torch\.library|Triton|triton|kernel" deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py

echo "== Test coverage: locate referenced test file and check compile coverage =="
fd -a "test_descriptor_sezm_triton.py" -t f . || true
for f in $(fd -a "test_descriptor_sezm_triton.py" -t f . || true); do
  echo "\n== Contents around tests in $f =="
  sed -n '1,260p' "$f" | cat -n
  echo "\n== Search for compile / torch.compile usage in $f =="
  rg -n "torch\.compile|compile\(" "$f" || true
done

echo "== Ensure ruff guidelines? Just check that file is .py =="
test -f "$FILE" && echo "OK: $FILE exists"

Repository: deepmodeling/deepmd-kit

Length of output: 36084


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"

echo "== Dense/reference: rotate_to_local_reference and rotate_back_reference =="
rg -n "def rotate_to_local_reference|def rotate_back_reference" "$FILE"
python3 - <<'PY'
import re
file="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"
lines=open(file,'r',encoding='utf-8').read().splitlines()
# print windows around the two defs
targets=["def rotate_to_local_reference","def rotate_back_reference"]
for t in targets:
    for i,l in enumerate(lines,1):
        if l.startswith(t):
            start=max(1,i-30); end=min(len(lines), i+160)
            print("\n--- Window around", t, "line", i, "---")
            for j in range(start,end+1):
                print(f"{j:5d}: {lines[j-1]}")
            break
PY

echo "== Search where coeff_index is actually used in dense (non-block) kernels =="
rg -n "coeff_index" "$FILE" | head -n 80

echo "== Inspect custom_op registrations for fake/meta behavior for dense vs block ops =="
# show around custom_op definitions
sed -n '1420,1585p' "$FILE" | cat -n

echo "== Inspect custom_op forward implementations to see if block ops ignore coeff_index =="
# Show where _rotate_to_local_op and _block_to_local_op are defined
sed -n '1455,1565p' "$FILE" | cat -n

Repository: deepmodeling/deepmd-kit

Length of output: 34178


🏁 Script executed:

#!/bin/bash
set -euo pipefail
FILE="deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py"

echo "== Inspect _use_triton helper =="
sed -n '1340,1420p' "$FILE" | cat -n

echo "== Inspect dense _rotate_to_local_impl / _rotate_back_impl use of coeff_index =="
sed -n '1340,1460p' "$FILE" | cat -n

echo "== Inspect block _block_rotate_to_local_impl / _block_rotate_back_impl inputs (no coeff_index) =="
sed -n '1460,1580p' "$FILE" | cat -n

echo "== Inspect inverse/_inverse_index or coeff_index-dependent ops (to ensure dense path actually uses values) =="
rg -n "def _inverse_index|inv_index|index_select|coeff_index\[" "$FILE" | head -n 80

Repository: deepmodeling/deepmd-kit

Length of output: 13548


Fix torch.compile fake/meta tracing misdispatch to block kernels for noncanonical coeff_index

In deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py (~1168-1201, ~1614-1687), _canonical_block_index_values() returns True for FakeTensor (fake_mode) and meta tensors, so under torch.compile _block_layout_lmax() can select the block-diagonal path from shape/dim_full alone. The block custom ops do not take coeff_index; they rebuild the m-major coefficient layout from lmax, so a noncanonical coeff_index of the same length can produce incorrect results. Eager tests cover noncanonical dispatch, but there is no compiled test with a noncanonical coeff_index (test_torch_compile_composability uses only the canonical one).
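
A minimal, torch-free sketch of the conservative dispatch this finding asks for. Everything here is hypothetical: `canonical_index_values` and `choose_path` are stand-ins rather than functions from the PR, `values=None` models a fake/meta tensor whose contents cannot be read, and `expected` stands in for whatever `build_m_major_index` would return.

```python
from typing import Optional, Sequence


def canonical_index_values(
    values: Optional[Sequence[int]], expected: Sequence[int]
) -> bool:
    """Return True only when the index values are readable and canonical.

    ``values is None`` models a fake/meta tensor (hypothetical stand-in):
    its contents are unknown at trace time, so we refuse to assume it
    matches the canonical layout.
    """
    if values is None:
        return False  # conservative: unknown contents are not canonical
    return list(values) == list(expected)


def choose_path(values: Optional[Sequence[int]], expected: Sequence[int]) -> str:
    """Pick the block kernels only for an explicitly canonical index."""
    return "block" if canonical_index_values(values, expected) else "dense"


canonical = [0, 1, 2, 3]  # placeholder layout, not the real m-major order
print(choose_path(canonical, canonical))     # readable and canonical
print(choose_path([0, 2, 1, 3], canonical))  # readable but permuted
print(choose_path(None, canonical))          # contents unavailable while tracing
```

The trade-off in this sketch is that tracing always falls back to the dense path, which costs performance under compilation but cannot silently apply the block kernels to a permuted index.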

🧰 Tools
🪛 Ruff (0.15.15)

[warning] 1201-1201: Do not catch blind exception: Exception

(BLE001)


         if any(x < 0 for x in self.m_schedule):
             raise ValueError("`m_schedule` entries must be non-negative")
-        if any(m > l for m, l in zip(self.m_schedule, self.l_schedule)):
+        if any(m > l for m, l in zip(self.m_schedule, self.l_schedule, strict=True)):

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Rename the zipped l binding before Ruff blocks CI.

Line 1687 reintroduces Ruff E741, so this file still fails the required lint step.

Suggested fix
-        if any(m > l for m, l in zip(self.m_schedule, self.l_schedule, strict=True)):
+        if any(
+            m > degree
+            for m, degree in zip(self.m_schedule, self.l_schedule, strict=True)
+        ):

As per coding guidelines, **/*.py: Install linter and run ruff check . before committing changes or the CI will fail.

🧰 Tools
🪛 Ruff (0.15.15)

[error] 1687-1687: Ambiguous variable name: l

(E741)


Sources: Coding guidelines, Linters/SAST tools


3 participants