From 7de87308b143493ae9508e9633a26399e5df3206 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 4 May 2026 17:55:19 +0100 Subject: [PATCH 01/10] compiler: Avoid useless instrumentation --- devito/passes/iet/instrument.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devito/passes/iet/instrument.py b/devito/passes/iet/instrument.py index 4bc75b88b0..9bd1ce2134 100644 --- a/devito/passes/iet/instrument.py +++ b/devito/passes/iet/instrument.py @@ -140,7 +140,7 @@ def sync_sections(iet, langbb=None, profiler=None, **kwargs): symbols = FindSymbols().visit(tl) queues = [i for i in symbols if isinstance(i, langbb.AsyncQueue)] - unnecessary = any(FindNodes(BusyWait).visit(tl)) + unnecessary = any(FindNodes((BusyWait, RemainderCall)).visit(tl)) if queues and not unnecessary: waits = tuple(sync(i) for i in queues) mapper[tl] = tl._rebuild(body=tl.body + waits) From 40c63b530a55056551116a533b88294bc4572079 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 5 May 2026 15:15:08 +0100 Subject: [PATCH 02/10] compiler: Patch MPI with overlap2 --- devito/mpi/halo_scheme.py | 1 - devito/passes/iet/mpi.py | 2 +- tests/test_mpi.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index fc41ba7ce2..33bcb12a9d 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -7,7 +7,6 @@ import sympy from sympy import Max, Min -from devito import configuration from devito.data import CENTER, CORE, LEFT, OWNED, RIGHT from devito.ir.support import Forward, Scope from devito.symbolics.manipulation import _uxreplace_registry diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 3a3d354905..25550e1915 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -510,7 +510,7 @@ def _is_mergeable(hsf0, hsf1, scope): return False # Ensure `hsf0` and `hsf1` are compatible - if hsf0.dimensions != hsf1.dimensions or \ + if not hsf0.dimensions.issubset(hsf1.dimensions) or \ not hsf0.functions & hsf1.functions: return False diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 7746e06155..3d5525afe4 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -2221,6 +2221,35 @@ def test_halo_inner_dim(self, mode): assert np.isclose(norm(e), 23484.863, rtol=0, atol=1e-1) + @pytest.mark.parallel(mode=[(1, 'overlap2')]) + def test_merge_subset_comms_w_overlap(self, mode): + grid = Grid(shape=(50, 50, 50)) + + f = Function(name='f', grid=grid, space_order=8) + g = f.func(name='g') + u = TimeFunction(name='u', grid=grid, time_order=2, space_order=8, + save=Buffer(2)) + + eqns = [ + Eq(f, u.dx), + Eq(g, u.dy), + Eq(u.forward, u + u.backward + u.laplace + f + g.dx), + ] + + op = Operator(eqns) + + op.cfunction + + # Check generated code -- expected one halo exchange for `u` before + # the first set of loops within `tloop`, and one halo exchange for `g` + # before the second set of loops + tloop = get_time_loop(op) + body = tloop.nodes[0].body[0].body + halo_update0 = body[0].body[0] + assert isinstance(halo_update0, HaloUpdateList) + halo_update1 = body[1].body[0] + assert isinstance(halo_update1, HaloUpdateList) + class TestOperatorAdvanced: From 7d6e930fbb71f743cf0a2aed2f3508ffa054ed8a Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 5 May 2026 15:58:49 +0100 Subject: [PATCH 03/10] compiler: Remove unused DualHaloExchangeBuilder --- devito/mpi/routines.py | 43 ++++++------------------------------------ 1 file changed, 6 insertions(+), 37 deletions(-) diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index 74c788fa0e..f861bf9e75 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -957,36 +957,6 @@ def _call_remainder(self, remainder): return remainder -class DualHaloExchangeBuilder(Overlap2HaloExchangeBuilder): - - """ - "Dual" of Overlap2HaloExchangeBuilder, as the "remainder" is now the first - thing getting computed. - - Generates: - - remainder() - haloupdate() - compute_core() - halowait() - """ - - def _make_body(self, callcompute, remainder, haloupdates, halowaits): - body = [] - - assert remainder is not None - body.append(self._call_remainder(remainder)) - - body.append(HaloUpdateList(body=haloupdates)) - - assert callcompute is not None - body.append(callcompute) - - body.append(HaloWaitList(body=halowaits)) - - return List(body=body) - - class FullHaloExchangeBuilder(Overlap2HaloExchangeBuilder): """ @@ -1058,7 +1028,6 @@ def _call_poke(self, poke): 'overlap': OverlapHaloExchangeBuilder, 'overlap2': Overlap2HaloExchangeBuilder, 'full': FullHaloExchangeBuilder, - 'dual': DualHaloExchangeBuilder } @@ -1117,6 +1086,12 @@ def __init__(self, arguments, **kwargs): super().__init__('MPI_Irecv', arguments) +class AllreduceCall(Call): + + def __init__(self, arguments, **kwargs): + super().__init__('MPI_Allreduce', arguments, **kwargs) + + class MPICall(Call): @property @@ -1426,12 +1401,6 @@ def _arg_values(self, args=None, **kwargs): return values -class AllreduceCall(Call): - - def __init__(self, arguments, **kwargs): - super().__init__('MPI_Allreduce', arguments, **kwargs) - - class ReductionBuilder: """ From 1477afe814b636b962483d6ad581a953c13082a4 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 6 May 2026 10:06:44 +0100 Subject: [PATCH 04/10] compiler: HaloScheme.build -> self._rebuild --- devito/ir/iet/visitors.py | 2 +- devito/mpi/halo_scheme.py | 30 +++++++++++++++++++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 2744da24d6..1eae6433e3 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1494,7 +1494,7 @@ def visit_ParallelTree(self, o): def visit_HaloSpot(self, o): hs = o.halo_scheme fmapper = {self.mapper.get(k, k): v for k, v in hs.fmapper.items()} - halo_scheme = hs.build(fmapper, hs.honored) + halo_scheme = hs._rebuild(fmapper=fmapper) body = self._visit(o.body) return o._rebuild(halo_scheme=halo_scheme, body=body) diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 33bcb12a9d..0b638c9140 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -136,11 +136,9 @@ def __init__(self, exprs, ispace): # Derive the halo exchanges self._mapper = frozendict(classify(exprs, ispace)) - # Track the IterationSpace offsets induced by SubDomains/SubDimensions. - # These should be honored in the derivation of the `omapper` + # Track the IterationSpace offsets induced by SubDomains/SubDimensions, + # which are honored in the derivation of the `omapper` self._honored = {} - # SubDimensions are not necessarily included directly in - # ispace.dimensions and hence we need to first utilize the `_defines` method dims = set().union(*[d._defines for d in ispace.dimensions if d._defines & self.dimensions]) subdims = [d for d in dims if d.is_Sub and not d.local] @@ -164,11 +162,21 @@ def __len__(self): def __hash__(self): return hash((self._mapper.__hash__(), self.honored.__hash__())) - @classmethod - def build(cls, fmapper, honored): + def _rebuild(self, fmapper=None, honored=None): + """ + Rebuild a HaloScheme from the provided `fmapper` and `honored`. Reuse + `self`'s values for the missing arguments. + """ obj = object.__new__(HaloScheme) + + if fmapper is None: + fmapper = self._mapper + if honored is None: + honored = self._honored + obj._mapper = frozendict(fmapper) obj._honored = frozendict(honored) + return obj @classmethod @@ -222,7 +230,7 @@ def union(self, halo_schemes): for d, v in i.honored.items(): honored[d] = honored.get(d, frozenset()) | v - return HaloScheme.build(fmapper, honored) + return i._rebuild(fmapper=fmapper, honored=honored) @property def honored(self): @@ -482,7 +490,7 @@ def project(self, functions): to the provided `functions`. """ fmapper = {f: v for f, v in self.fmapper.items() if f in as_tuple(functions)} - return HaloScheme.build(fmapper, self.honored) + return self._rebuild(fmapper=fmapper) def drop(self, functions): """ @@ -490,7 +498,7 @@ def drop(self, functions): corresponding to the provided `functions`. """ fmapper = {f: v for f, v in self.fmapper.items() if f not in as_tuple(functions)} - return HaloScheme.build(fmapper, self.honored) + return self._rebuild(fmapper=fmapper) def add(self, f, hse): """ @@ -502,7 +510,7 @@ def add(self, f, hse): if f in fmapper: hse = fmapper[f].union(hse) fmapper[f] = hse - return HaloScheme.build(fmapper, self.honored) + return self._rebuild(fmapper=fmapper) def merge(self, hs): """ @@ -511,7 +519,7 @@ def merge(self, hs): fmapper = dict(self.fmapper) for f, hse in hs.fmapper.items(): fmapper[f] = fmapper.get(f, hse).merge(hse) - return HaloScheme.build(fmapper, self.honored) + return self._rebuild(fmapper=fmapper) def classify(exprs, ispace): From cdd3b2ab7a52af6de73ff7c0f7f40055d3cfeb32 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 6 May 2026 10:32:14 +0100 Subject: [PATCH 05/10] compiler: Refactor _mark_overlappable --- devito/passes/iet/mpi.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 25550e1915..165fc90fd2 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -11,7 +11,9 @@ from devito.mpi.reduction_scheme import DistReduce from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder from devito.passes.iet.engine import iet_pass +from devito.symbolics import VectorAccess from devito.tools import generator +from devito.types import TensorMove __all__ = ['mpiize'] @@ -276,15 +278,21 @@ def _mark_overlappable(iet): # Analysis found = [] for hs in FindNodes(HaloSpot).visit(iet): + # Heuristic: avoid comp/comm overlap for sparse Iteration nests + iters = FindNodes(Iteration).visit(hs) + if any(i.dim._defines & set(hs.halo_scheme.distributed_aindices) and + not i.is_Affine for i in iters): + continue + + # Check legality. Comp/comm overlaps is legal only if the OWNED regions + # can grow arbitrarily, which means all of the dependencies must be + # carried along a non-halo Dimension expressions = FindNodes(Expression).visit(hs) if not expressions: continue scope = Scope(i.expr for i in expressions) - # Comp/comm overlaps is legal only if the OWNED regions can grow - # arbitrarily, which means all of the dependencies must be carried - # along a non-halo Dimension for dep in scope.d_all_gen(): if dep.function in hs.functions: cause = dep.cause & hs.dimensions @@ -295,20 +303,9 @@ def _mark_overlappable(iet): # ... = ... f[x, y-1] ... # for y # f[x, y] = ... - test = False break else: - test = True - - # Heuristic: avoid comp/comm overlap for sparse Iteration nests - if test: - for i in FindNodes(Iteration).visit(hs): - if i.dim._defines & set(hs.halo_scheme.distributed_aindices) and \ - not i.is_Affine: - test = False - break - - if test: + # All good! found.append(hs) # Transform the IET replacing HaloSpots with OverlappableHaloSpots From 86d0e63a1da9ba46db9380469c11086595a10b52 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 6 May 2026 11:51:17 +0100 Subject: [PATCH 06/10] compiler: Fix HaloScheme.omapper taking alignment into account --- devito/mpi/halo_scheme.py | 64 +++++++++++++++++++++++++----- devito/passes/iet/mpi.py | 29 +++++++++++--- devito/symbolics/extended_sympy.py | 10 +++++ devito/types/parallel.py | 6 +++ 4 files changed, 93 insertions(+), 16 deletions(-) diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 0b638c9140..132b917c77 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -9,6 +9,7 @@ from devito.data import CENTER, CORE, LEFT, OWNED, RIGHT from devito.ir.support import Forward, Scope +from devito.symbolics import IntDiv from devito.symbolics.manipulation import _uxreplace_registry from devito.tools import ( EnrichedTuple, Reconstructable, Tag, as_tuple, filter_ordered, filter_sorted, flatten, @@ -147,6 +148,10 @@ def __init__(self, exprs, ispace): self._honored[i.root] = frozenset([(ltk, rtk)]) self._honored = frozendict(self._honored) + # Further constraints on the `omapper` derivation. At construction time + # there's none, but lowering passes may change this + self._alignment = None + def __repr__(self): fnames = ",".join(i.name for i in set(self._mapper)) return f"HaloScheme<{fnames}>" @@ -162,7 +167,7 @@ def __len__(self): def __hash__(self): return hash((self._mapper.__hash__(), self.honored.__hash__())) - def _rebuild(self, fmapper=None, honored=None): + def _rebuild(self, fmapper=None, honored=None, alignment=None): """ Rebuild a HaloScheme from the provided `fmapper` and `honored`. Reuse `self`'s values for the missing arguments. @@ -176,6 +181,7 @@ def _rebuild(self, fmapper=None, honored=None): obj._mapper = frozendict(fmapper) obj._honored = frozendict(honored) + obj._alignment = alignment or self._alignment return obj @@ -248,10 +254,14 @@ def is_void(self): @cached_property def omapper(self): """ - Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions. + Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions, + "cumulative" over all DiscreteFunctions in the HaloScheme. + + The computed OMapper takes into account: - This is "cumulative" over all DiscreteFunctions in the HaloScheme; it also - takes into account IterationSpace offsets induced by SubDomains/SubDimensions. + * The offsets induced by SubDomains/SubDimensions ("thickness"); + * Any data alignment requirement of the underlying expressions + (`_alignment` attribute). Examples -------- @@ -373,28 +383,62 @@ def omapper(self): if s is CENTER: where.append((d, CORE, s)) - mapper[d] = (d.symbolic_min + osl, - d.symbolic_max - osr) + + mapper[d] = ( + d.symbolic_min + osl, + d.symbolic_max - osr + ) + if nl != 0: mapper[nl] = (Max(nl - osl, 0),) if nr != 0: mapper[nr] = (Max(nr - osr, 0),) else: where.append((d, OWNED, s)) + if s is LEFT: - mapper[d] = (d.symbolic_min, - Min(d.symbolic_min + osl - 1, d.symbolic_max - nr)) + mapper[d] = ( + d.symbolic_min, + Min(d.symbolic_min + osl - 1, d.symbolic_max - nr) + ) + if nl != 0: mapper[nl] = (nl,) mapper[nr] = (0,) else: - mapper[d] = (Max(d.symbolic_max - osr + 1, d.symbolic_min + nl), - d.symbolic_max) + mapper[d] = ( + Max(d.symbolic_max - osr + 1, d.symbolic_min + nl), + d.symbolic_max + ) + if nr != 0: mapper[nl] = (0,) mapper[nr] = (nr,) + processed.append((tuple(where), frozendict(mapper))) + # Apply the alignment constraints, if any + # First, get the fastest varying (contiguous) Dimension, which is the + # one that matters for alignment + if self._alignment: + fvds = {f.dimensions[-1] for f in self.fmapper} + if len(fvds) != 1: + raise HaloSchemeException( + "Unexpected contiguous Dimensions found while computing the " + f"`omapper`: {fvds}" + ) + fvd = fvds.pop() + + for i, (where, mapper) in enumerate(list(processed)): + try: + m, M = mapper[fvd] + except KeyError: + continue + + aligned_m = IntDiv(m, self._alignment) * self._alignment + + processed[i] = (where, frozendict({**mapper, fvd: (aligned_m, M)})) + _, core = processed.pop(0) owned = processed diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 165fc90fd2..bc1536d154 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -11,7 +11,7 @@ from devito.mpi.reduction_scheme import DistReduce from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder from devito.passes.iet.engine import iet_pass -from devito.symbolics import VectorAccess +from devito.symbolics import VectorAccess, search from devito.tools import generator from devito.types import TensorMove @@ -287,11 +287,11 @@ def _mark_overlappable(iet): # Check legality. Comp/comm overlaps is legal only if the OWNED regions # can grow arbitrarily, which means all of the dependencies must be # carried along a non-halo Dimension - expressions = FindNodes(Expression).visit(hs) - if not expressions: + exprs = FindNodes(Expression).visit(hs) + if not exprs: continue - scope = Scope(i.expr for i in expressions) + scope = Scope([n.expr for n in exprs]) for dep in scope.d_all_gen(): if dep.function in hs.functions: @@ -305,11 +305,28 @@ def _mark_overlappable(iet): # f[x, y] = ... break else: - # All good! + # All good -- we can perform comp/comm overlap! found.append(hs) + # The underlying `exprs` might have data alignment constraints due to the + # presence of objects such as VectorAccess or TensorMove, which expect the + # starting address of the data to be aligned to a certain value. Comp/comm + # overlap creates multiple iteration spaces (for the core and owned + # regions), which might break the alignment contract if we don't play safe + # -- imposing these regions start at a carefully rounded-down point, at the + # cost of potentially performing a bit of redundant compute + mapper = {} + for hs in found: + exprs = [n.expr for n in FindNodes(Expression).visit(hs)] + objs = search(exprs, (VectorAccess, TensorMove)) + alignment = max([i._expected_alignment for i in objs], default=None) + + hsf = hs.halo_scheme._rebuild(alignment=alignment) + hs1 = hs._rebuild(halo_scheme=hsf) + + mapper[hs] = OverlappableHaloSpot(**hs1.args) + # Transform the IET replacing HaloSpots with OverlappableHaloSpots - mapper = {hs: OverlappableHaloSpot(**hs.args) for hs in found} iet = Transformer(mapper, nested=True).visit(iet) return iet diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 46d2c5a43a..cd71e4ad4a 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -119,6 +119,10 @@ def __new__(cls, lhs, rhs, params=None): elif rhs == 1 or rhs is None: return lhs + if is_integer(lhs) and is_integer(rhs): + # Both sides are plain integers -- perform the division right away + return lhs // rhs + if not is_integer(rhs): # Perhaps it's a symbolic RHS -- but we wanna be sure it's of type int if not hasattr(rhs, 'dtype'): @@ -890,6 +894,12 @@ class VectorAccess(Expr, Pickable, BasicWrapperMixin): Represent a vector access operation at high-level. """ + _expected_alignment = 16 + """ + The expected alignment in bytes for the accessed vector. This must be + honored by the compiler for correctness. + """ + def __new__(cls, *args, **kwargs): return Expr.__new__(cls, *args) diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 9670c766a8..4b769780b2 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -421,6 +421,12 @@ class TensorMove(Expr, Reserved, Terminal): __rargs__ = ('base', 'tid0', 'coords') + _expected_alignment = 16 + """ + The expected alignment in bytes for the accessed vector. This must be + honored by the compiler for correctness. + """ + def __new__(cls, base, tid0, coords, **kwargs): return super().__new__(cls, base, tid0, coords) From 077d7a73dfc7f4311d39466ffa8e49e3bde620e9 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 7 May 2026 14:49:56 +0100 Subject: [PATCH 07/10] compiler: Fix MPI alignment by using alignment_elems --- devito/mpi/halo_scheme.py | 2 + devito/passes/iet/mpi.py | 2 +- devito/symbolics/extended_sympy.py | 71 +++++++++++++++++++++++------- devito/types/parallel.py | 20 ++------- 4 files changed, 61 insertions(+), 34 deletions(-) diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 132b917c77..d342cf6cb3 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -150,6 +150,8 @@ def __init__(self, exprs, ispace): # Further constraints on the `omapper` derivation. At construction time # there's none, but lowering passes may change this + # * `_alignment` may be a positive integer representing the alignment + # requirement, in number of *elements*, of the underlying expressions self._alignment = None def __repr__(self): diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index bc1536d154..2de99ee002 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -319,7 +319,7 @@ def _mark_overlappable(iet): for hs in found: exprs = [n.expr for n in FindNodes(Expression).visit(hs)] objs = search(exprs, (VectorAccess, TensorMove)) - alignment = max([i._expected_alignment for i in objs], default=None) + alignment = max([i._expected_alignment_elems for i in objs], default=None) hsf = hs.halo_scheme._rebuild(alignment=alignment) hs1 = hs._rebuild(halo_scheme=hsf) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index cd71e4ad4a..2f5761cd10 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -3,6 +3,7 @@ """ import re from contextlib import suppress +from functools import cached_property import numpy as np import sympy @@ -25,7 +26,7 @@ 'MathFunction', 'InlineIf', 'Reserved', 'ReservedWord', 'Keyword', 'String', 'Macro', 'Class', 'MacroArgument', 'RoundUp', 'Deref', 'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', - 'ValueLimit', 'VectorAccess'] + 'ValueLimit', 'AlignedAccess', 'VectorAccess'] class CondEq(sympy.Eq): @@ -888,16 +889,47 @@ def __str__(self): __repr__ = __str__ -class VectorAccess(Expr, Pickable, BasicWrapperMixin): +class AlignedAccess(Expr, Reserved, BasicWrapperMixin): """ - Represent a vector access operation at high-level. + Abstract base class for an aligned access operation, that is an access to a + memory location that is guaranteed to be aligned to a certain byte + boundary. """ - _expected_alignment = 16 + @property + def _expected_alignment(self): + """ + The expected alignment in bytes for the underlying LOAD/STORE operation. + + To be implemented by subclasses. + """ + raise NotImplementedError + + @property + def _expected_alignment_elems(self): + """ + The expected alignment in number of elements for the underlying + LOAD/STORE operation. + """ + return self._expected_alignment // self.dtype().itemsize + + @property + def base(self): + return self.args[0] + + func = Reserved._rebuild + + @cacheit + def sort_key(self, order=None): + # Ensure that the AlignedAccess is sorted as the base + return self.base.sort_key(order=order) + + +class VectorAccess(AlignedAccess): + """ - The expected alignment in bytes for the accessed vector. This must be - honored by the compiler for correctness. + Represent a vector access operation at high-level. """ def __new__(cls, *args, **kwargs): @@ -908,21 +940,28 @@ def __str__(self): __repr__ = __str__ - func = Pickable._rebuild - - @property - def base(self): - return self.args[0] + @cached_property + def _expected_alignment(self): + """ + The expected alignment in bytes for the underlying LOAD/STORE operation. + """ + mapper = { + # dtype==float => lowered with float4 => 4*4=16 bytes alignment; + np.float32: 16, + # dtype==half => lowered with float2 => 2*4=8 bytes alignment; + np.float16: 8 + } + try: + return mapper[self.function.dtype] + except KeyError: + raise ValueError( + f"Unsupported dtype `{self.function.dtype}` for VectorAccess" + ) @property def indices(self): return self.base.indices - @cacheit - def sort_key(self, order=None): - # Ensure that the VectorAccess is sorted as the base - return self.base.sort_key(order=order) - # Some other utility objects Null = Macro('NULL') diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 4b769780b2..09498df986 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -11,11 +11,10 @@ from functools import cached_property import numpy as np -from sympy import Expr from devito.exceptions import InvalidArgument from devito.parameters import configuration -from devito.symbolics import Reserved, Terminal, search +from devito.symbolics import AlignedAccess, Terminal, search from devito.tools import as_list, as_tuple, is_integer from devito.types.array import Array, ArrayObject from devito.types.basic import Scalar, Symbol @@ -403,7 +402,7 @@ def __init_finalize__(self, *args, **kwargs): super().__init_finalize__(*args, **kwargs) -class TensorMove(Expr, Reserved, Terminal): +class TensorMove(AlignedAccess, Terminal): """ Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher @@ -423,17 +422,12 @@ class TensorMove(Expr, Reserved, Terminal): _expected_alignment = 16 """ - The expected alignment in bytes for the accessed vector. This must be - honored by the compiler for correctness. + The expected alignment in bytes for the underlying LOAD/STORE operation. """ def __new__(cls, base, tid0, coords, **kwargs): return super().__new__(cls, base, tid0, coords) - @property - def base(self): - return self.args[0] - @property def tid0(self): return self.args[1] @@ -442,10 +436,6 @@ def tid0(self): def coords(self): return self.args[2] - @property - def function(self): - return self.base.function - @cached_property def indexed(self): return self.function[self.coords] @@ -454,9 +444,5 @@ def indexed(self): def ndim(self): return self.function.ndim - func = Reserved._rebuild - def _ccode(self, printer): return str(self) - - _sympystr = _ccode From cb03b523851bf2b3bc661cebe0cd7a8cbfc7c698 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 8 May 2026 13:31:06 +0100 Subject: [PATCH 08/10] compiler: Purge MPI dual mode --- devito/mpi/halo_scheme.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index d342cf6cb3..3b4bcb5bc8 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -573,12 +573,6 @@ def classify(exprs, ispace): Produce the mapper `Function -> HaloSchemeEntry`, which describes the necessary halo exchanges in the given Scope. """ - - # Some MPI modes require pulling the `loc_indices` from the reads, others - # from the writes. It essentially depends on whether the halo exchange is - # performed before (reads) or after (writes) the OWNED region is computed - loc_indices_from_reads = configuration['mpi'] not in ('dual',) - scope = Scope(exprs) mapper = {} @@ -618,7 +612,7 @@ def classify(exprs, ispace): else: v[(d, LEFT)] = STENCIL v[(d, RIGHT)] = STENCIL - elif loc_indices_from_reads: + else: v[(d, i[d])] = NONE # Does `i` actually require a halo exchange? @@ -626,7 +620,9 @@ def classify(exprs, ispace): continue # Derive diagonal halo exchanges from the previous analysis - combs = list(product([LEFT, CENTER, RIGHT], repeat=len(f._dist_dimensions))) + combs = list( + product([LEFT, CENTER, RIGHT], repeat=len(f._dist_dimensions)) + ) combs.remove((CENTER,)*len(f._dist_dimensions)) for c in combs: key = (f._dist_dimensions, c) @@ -651,13 +647,6 @@ def classify(exprs, ispace): if not halo_labels: continue - # Augment `halo_labels` with `loc_indices`-related information if necessary - if not loc_indices_from_reads: - for i in scope.writes.get(f, []): - for d in i.findices: - if not f.grid.is_distributed(d): - halo_labels[(d, i[d])].add(NONE) - # Separate halo-exchange Dimensions from `loc_indices` raw_loc_indices, halos = defaultdict(list), [] for (d, s), hl in halo_labels.items(): @@ -666,15 +655,18 @@ def classify(exprs, ispace): if not hl: continue elif len(hl) > 1: - raise HaloSchemeException("Inconsistency found while building a halo " - f"scheme for `{f}` along Dimension `{d}`") + raise HaloSchemeException( + "Inconsistency found while building a halo scheme for " + f"`{f}` along Dimension `{d}`") elif hl.pop() is STENCIL: halos.append(Halo(d, s)) elif d._defines & set(ispace.itdims): raw_loc_indices[d].append(s) - loc_indices, loc_dirs = process_loc_indices(raw_loc_indices, - ispace.directions) + loc_indices, loc_dirs = process_loc_indices( + raw_loc_indices, ispace.directions + ) + mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims) return mapper From 9a56cdf5c34650064bc206105384822a78275382 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 8 May 2026 14:06:20 +0100 Subject: [PATCH 09/10] compiler: Fix CustomTopology --- devito/mpi/distributed.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py index 7daeaeadd9..5af1a6f421 100644 --- a/devito/mpi/distributed.py +++ b/devito/mpi/distributed.py @@ -261,7 +261,7 @@ def nprocs_local(self): @property def topology(self): - return DimensionTuple(*self._topology, getters=self.dimensions) + return self._topology @property def topology_logical(self): @@ -353,7 +353,9 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None): self._topology = compute_dims(self._input_comm.size, len(shape)) else: # A custom topology may contain integers or the wildcard '*' - self._topology = CustomTopology(topology, self._input_comm) + self._topology = CustomTopology( + topology, self._input_comm, getters=dimensions + ) if self._input_comm is not input_comm: # By default, Devito arranges processes into a cartesian topology. @@ -896,7 +898,7 @@ def _arg_values(self, *args, **kwargs): return self._arg_defaults() -class CustomTopology(tuple): +class CustomTopology(DimensionTuple): """ The CustomTopology class provides a mechanism to describe parametric domain @@ -954,7 +956,7 @@ class CustomTopology(tuple): 'xy': ('*', '*', 1), } - def __new__(cls, items, input_comm): + def __new__(cls, items, input_comm, **kwargs): # Keep track of nstars and already defined decompositions nstars = items.count('*') @@ -992,11 +994,15 @@ def __new__(cls, items, input_comm): # Final check that topology matches the communicator size assert np.prod(processed) == input_comm.size - obj = super().__new__(cls, processed) + obj = super().__new__(cls, *processed, **kwargs) obj.logical = items return obj + def __repr__(self): + return (f"CustomTopology(logical={self.logical}, " + f"physical={super().__repr__()})") + def compute_dims(nprocs, ndim): # We don't do anything clever here. In fact, we do something very basic -- From 7d316b5971ab509f61276c324a455338012f69a0 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 11 May 2026 14:59:06 +0100 Subject: [PATCH 10/10] compiler: Tweak for ruff happiness --- devito/symbolics/extended_sympy.py | 4 ++-- tests/test_mpi.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 2f5761cd10..4523fbd2f4 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -953,8 +953,8 @@ def _expected_alignment(self): } try: return mapper[self.function.dtype] - except KeyError: - raise ValueError( + except KeyError as e: + raise ValueError from e( f"Unsupported dtype `{self.function.dtype}` for VectorAccess" ) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 3d5525afe4..7329cc31e3 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -2238,7 +2238,7 @@ def test_merge_subset_comms_w_overlap(self, mode): op = Operator(eqns) - op.cfunction + _ = op.cfunction # Check generated code -- expected one halo exchange for `u` before # the first set of loops within `tloop`, and one halo exchange for `g`