diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index f2bbce6b3f..b26fc4c04e 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -231,8 +231,15 @@ def _break_for_parallelism(self, scope, dim, timestamp): if any(dep.is_carried(i) for i in candidates): test0 = dep.is_flow and dep.is_lex_negative test1 = dep.is_anti and dep.is_lex_positive + if test0: + # If the same access pair is not a flow under logical distance, + # the dep is a buffer/modulo-aliasing artifact and fission is OK + ldist = dep.source.distance(dep.sink, logical=True) + real_flow = (ldist > 0) or \ + (ldist == 0 and dep.sink.lex_ge(dep.source)) + if not real_flow: + test0 = real_flow if test0 or test1: - # Would break a data dependence return False test = test or (bool(dep.cause & candidates) and not dep.is_lex_equal) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 7a841be671..ea2abf2cf7 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -1105,6 +1105,7 @@ def d_flow_gen(self): continue distance = dependence.distance + try: is_flow = distance > 0 or (r.lex_ge(w) and distance == 0) except TypeError: diff --git a/tests/test_dse.py b/tests/test_dse.py index 8399261563..b2aee06f51 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -9,10 +9,10 @@ get_params, skipif ) from devito import ( # noqa - NODE, Abs, ConditionalDimension, Constant, DefaultDimension, Derivative, Dimension, - Eq, Function, Ge, Grid, Inc, Lt, Operator, SparseTimeFunction, SubDimension, - TimeFunction, configuration, cos, dimensions, div, exp, first_derivative, floor, grad, - norm, sin, solve, sqrt, switchconfig, transpose + NODE, Abs, Buffer, ConditionalDimension, Constant, DefaultDimension, Derivative, + Dimension, Eq, Function, Ge, Grid, Inc, Lt, Operator, SparseTimeFunction, + SubDimension, TimeFunction, configuration, cos, dimensions, div, exp, + first_derivative, floor, grad, norm, sin, solve, sqrt, switchconfig, transpose ) from devito.exceptions import InvalidArgument, InvalidOperator from devito.ir import ( @@ -58,6 +58,26 @@ def test_scheduling_after_rewrite(): assert all(trees[1].root.dim is tree.root.dim for tree in trees[1:]) +def test_scheduling_no_deriv(): + grid = Grid((11, 11, 11)) + x, y, z = grid.dimensions + + image_vs = Function(name='image_vs', grid=grid, space_order=1, staggered=NODE) + p_back_xy = TimeFunction(name='p_back_xy', grid=grid, staggered=(x, y), + space_order=4, time_order=1, save=Buffer(1)) + + eqns = [Eq(image_vs, p_back_xy + image_vs), + Eq(p_back_xy.backward, p_back_xy)] + + op = Operator(eqns) + + assert_structure( + op, + ['t,x0_blk0,y0_blk0,x,y,z', 't,x1_blk0,y1_blk0,x,y,z'], + 'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz' + ) + + @pytest.mark.parametrize('expr,expected', [ ('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'), ('fa[x]**2', 'fa[x]*fa[x]'),