Arm backend: Cleanup dim-order and permute handling#19278
Arm backend: Cleanup dim-order and permute handling#19278AdrianLundell merged 1 commit intopytorch:mainfrom
Conversation
- Replace u55 permute dimension check with a u55-only pass decomposing large permutes. This pass checks for support by compiling targeted permutes using Vela to ensure alignment between Executorch and Vela. - Remove passes and testing not required anymore after dim-order update. - Remove all outdated mention of dim-order in the arm backend. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Change-Id: I098db4539179cb223b5c76683720e68c1bbecb8f
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19278
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ You can merge normally! (3 Unrelated Failures)As of commit a606e58 with merge base 69989b7 ( BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
No buck2 changes required according to codex |
There was a problem hiding this comment.
Pull request overview
This PR removes the old Arm backend dim-order/memory-format machinery and shifts permute handling toward explicit graph rewrites, including a new U55-specific large-permute decomposition pass. It mainly refactors TOSA lowering/serialization paths, operator support checks, and the associated Arm backend tests.
Changes:
- Replace dim-order-based TOSA shape/constant handling with direct shape normalization and remove the old transpose/memory-format infrastructure.
- Add
DecomposePermuteForU55Passand update U55 permute/view/select expectations and tests around the new behavior. - Delete legacy passes/operators/tests that were only needed for the previous dim-order approach.
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
backends/arm/tosa/utils.py |
Renames shape helper to normalize_symint. |
backends/arm/tosa/mapping.py |
Drops dim-order from extracted tensor metadata and TosaArg. |
backends/arm/tosa/dialect/ops/transpose.py |
Deletes fake TOSA transpose dialect op. |
backends/arm/tosa/dialect/__init__.py |
Removes transpose dialect registration. |
backends/arm/test/tester/arm_tester.py |
Updates tester helper for new extract_tensor_meta signature. |
backends/arm/test/passes/test_to_tosa_memory_format.py |
Deletes tests for removed memory-format pass. |
backends/arm/test/passes/test_decompose_int16_activation_conv_pass.py |
Deletes tests for removed int16 conv decomposition pass. |
backends/arm/test/ops/test_view.py |
Renames test data sets and removes old U55 non-delegation cases. |
backends/arm/test/ops/test_select.py |
Removes old U55 non-delegation coverage. |
backends/arm/test/ops/test_permute.py |
Expands U55 permute coverage for large-shape decomposition path. |
backends/arm/test/misc/test_transpose_counts.py |
Updates expected transpose count for grouped conv channels-last case. |
backends/arm/process_node.py |
Removes dim-order-aware tensor serialization and uses normalized shapes directly. |
backends/arm/operators/op_while.py |
Stops consulting output dim-order when creating dummy loop outputs. |
backends/arm/operators/op_tosa_transpose.py |
Deletes backend visitor for removed TOSA transpose op. |
backends/arm/operators/op_tosa_shapes.py |
Serializes shape constants without dim-order remapping. |
backends/arm/operators/op_sum.py |
Uses raw reduction axis directly. |
backends/arm/operators/op_permute.py |
Removes dim-order permutation remapping logic. |
backends/arm/operators/op_cat.py |
Uses raw concat dimension directly. |
backends/arm/operators/op_any.py |
Uses raw reduction axis directly. |
backends/arm/operators/op_amin.py |
Uses raw reduction axis directly. |
backends/arm/operators/op_amax.py |
Uses raw reduction axis directly. |
backends/arm/operators/__init__.py |
Unregisters removed transpose visitor. |
backends/arm/operator_support/tosa_supported_operators.py |
Removes old U55 transpose/view support checks from factory. |
backends/arm/operator_support/ethos_u55_support.py |
Deletes legacy U55 view/permute support-check implementations. |
backends/arm/operator_support/convolution_support.py |
Simplifies transpose-conv U55 shape handling. |
backends/arm/_passes/to_tosa_memory_format_pass.py |
Deletes old dim-order/memory-format pass. |
backends/arm/_passes/insert_data_layout_casts_pass.py |
Removes dependency on deleted backend transpose op. |
backends/arm/_passes/decompose_permute_for_u55_pass.py |
Adds new U55 large-permute decomposition/probing pass. |
backends/arm/_passes/decompose_int16_activation_conv_pass.py |
Deletes old int16 activation conv decomposition pass. |
backends/arm/_passes/arm_pass_utils.py |
Removes output dim-order helper. |
backends/arm/_passes/arm_pass_manager.py |
Inserts new U55 permute pass and reorders slice rewriting. |
backends/arm/_passes/annotate_output_dim_order_pass.py |
Deletes old output dim-order annotation pass. |
backends/arm/_passes/__init__.py |
Updates exported pass list for removed/added passes. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| dtype = map_dtype(val.dtype) | ||
| shape = tuple(val.size()) | ||
|
|
||
| dim_order = tuple(range(len(shape))) | ||
| return (dtype, shape, dim_order) | ||
| return (dtype, shape) |
| # This is a quick check to avoid the overhead of the Vela compilation in 99% of cases. | ||
| if not self._violates_u55_worst_case_constraint(input_shape): | ||
| return super().call_operator(op, args, kwargs, meta) | ||
|
|
| assert isinstance(tensor, torch.Tensor), ( | ||
| f"Expected lifted tensor constant '{node.name}' to be a torch.Tensor, got " | ||
| f"{type(tensor).__name__}" | ||
| ) |
| """Extract dtype, shape, and dimension order from FX metadata. | ||
|
|
||
| Args: | ||
| meta (dict): FX node ``meta`` containing a ``val`` FakeTensor (or tuple). | ||
|
|
||
| Returns: | ||
| tuple[ts.DType, tuple[int, ...], tuple[int, ...]]: Tuple containing | ||
| tensor dtype, shape, and dimension order. | ||
| tensor dtype and shape. |
| permutation and dtype to check wheter it is supported. | ||
| """ | ||
|
|
||
| if dtype not in (torch.int8, torch.bool, torch.int16): |
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell