Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def skipif(items, whole_module=False):
accepted.update({'device', 'device-C', 'device-openmp', 'device-openacc',
'device-aomp', 'cpu64-icc', 'cpu64-icx', 'cpu64-nvc',
'noadvisor', 'cpu64-arm', 'cpu64-icpx', 'chkpnt'})
accepted.update({'nodevice', 'noomp'})
accepted.update({'nodevice', 'noomp', 'nointel'})
unknown = sorted(set(items) - accepted)
if unknown:
raise ValueError(f"Illegal skipif argument(s) `{unknown}`")
Expand Down Expand Up @@ -93,6 +93,11 @@ def skipif(items, whole_module=False):
if i == 'noomp' and 'openmp' not in configuration['language']:
skipit = "Must use openmp"
break
# Skip if not using an Intel compiler
if i == 'nointel' and \
not isinstance(configuration['compiler'], (IntelCompiler, OneapiCompiler)):
skipit = "Must use an Intel compiler"
break
# Skip if it won't run on Arm
if i == 'cpu64-arm' and isinstance(configuration['platform'], Arm):
skipit = "Arm doesn't support x86-specific instructions"
Expand Down
14 changes: 13 additions & 1 deletion devito/passes/iet/languages/openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from sympy import And, Ne, Not

from devito.arch import AMDGPUX, INTELGPUX, NVIDIAX, PVC
from devito.arch.compiler import CustomCompiler, GNUCompiler, NvidiaCompiler
from devito.arch.compiler import (
CustomCompiler, GNUCompiler, IntelCompiler, NvidiaCompiler, OneapiCompiler
)
from devito.ir import (
Call, Conditional, DeviceCall, FindSymbols, List, ParallelBlock, PointerCast, Pragma,
Prodder, While
Expand Down Expand Up @@ -276,6 +278,16 @@ def _support_complex_reduction(cls, compiler):
# Gcc doesn't supports complex reduction
return not isinstance(compiler, GNUCompiler)

@classmethod
def _support_nested_parallelism(cls, compiler):
# In case we have a CustomCompiler
if isinstance(compiler, CustomCompiler):
compiler = compiler._base()
if isinstance(compiler, (IntelCompiler, OneapiCompiler)): # noqa: SIM103
return True
else:
return False


class Ompizer(AbstractOmpizer):
langbb = OmpBB
Expand Down
16 changes: 15 additions & 1 deletion devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def _support_array_reduction(cls, compiler):
def _support_complex_reduction(cls, compiler):
return False

@classmethod
def _support_nested_parallelism(cls, compiler):
return False

@property
def simd_reg_nbytes(self):
return self.platform.simd_reg_nbytes
Expand Down Expand Up @@ -342,6 +346,15 @@ def _make_parregion(self, partree, parrays):
def _make_guard(self, parregion):
return parregion

def _support_uindices(self, uindices):
if not uindices:
# No secondary indices, so we can apply nested parallelism
return True
else:
# Compiler supports nested parallelism with multiple indices
# such as for(int i = 0, j=1; ...)
return self._support_nested_parallelism(self.compiler)

def _make_nested_partree(self, partree):
# Apply heuristic
if self.nhyperthreads <= self.nested:
Expand All @@ -366,7 +379,8 @@ def _make_nested_partree(self, partree):
# within a block)
candidates = []
for i in inner:
if self.key(i) and any((j.dim.root is i.dim.root) for j in outer):
if self.key(i) and any((j.dim.root is i.dim.root) for j in outer) and \
self._support_uindices(i.uindices):
candidates.append(i)
elif candidates:
# If there's at least one candidate but `i` doesn't honor the
Expand Down
41 changes: 41 additions & 0 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
PrecomputedSparseTimeFunction, ReduceMax, ReduceMin, ReduceMinMax, SpaceDimension,
SparseTimeFunction, SubDimension, TimeFunction, configuration, cos, dimensions, info
)
from devito.arch.compiler import IntelCompiler, OneapiCompiler
from devito.exceptions import InvalidArgument
from devito.ir.iet import (
Expression, FindNodes, IsPerfectIteration, Iteration, retrieve_iteration_tree
Expand Down Expand Up @@ -1297,6 +1298,7 @@ def test_collapsing(self):
('omp parallel for collapse(2) schedule(dynamic,1) '
'num_threads(nthreads_nested)')

@skipif('nointel')
def test_multiple_subnests_v0(self):
grid = Grid(shape=(3, 3, 3))
x, y, z = grid.dimensions
Expand Down Expand Up @@ -1329,6 +1331,7 @@ def test_multiple_subnests_v0(self):
('omp parallel for collapse(2) schedule(dynamic,1) '
'num_threads(nthreads_nested)')

@skipif('nointel')
def test_multiple_subnests_v1(self):
"""
Unlike ``test_multiple_subnestes_v0``, now we use the ``cire-rotate=True``
Expand Down Expand Up @@ -1461,3 +1464,41 @@ def test_collapsing_w_wo_halo(self, exprs, collapsed, scheduling):

assert iterations[1].pragmas[0].ccode.value ==\
"".join([ompfor_string, scheduling_string])

def test_nested_parallelism_support(self):
grid = Grid(shape=(10, 10, 10))

f = Function(name='f', grid=grid, space_order=4)
v = TimeFunction(name="v", grid=grid, space_order=4)
v1 = TimeFunction(name="v1", grid=grid, space_order=4)

f.data_with_halo[:] = 0.5
v.data_with_halo[:] = 1.
v1.data_with_halo[:] = 1.

eqn = Eq(v.forward, (v.dx * (1 + 2*f) * f).dx)
op = Operator(eqn, opt=('advanced', {'openmp': True,
'par-collapse-ncores': 1,
'par-nested': 0}))

bns, _ = assert_blocking(op, {'x0_blk0'})
trees = retrieve_iteration_tree(bns['x0_blk0'])
assert len(trees) == 2

# Check omp pargams
assert trees[0][0].pragmas[0].ccode.value == \
'omp for collapse(2) schedule(dynamic,1)'
if isinstance(configuration['compiler'], (IntelCompiler, OneapiCompiler)):
# Supports nested parallelism
assert trees[0][2].pragmas[0].ccode.value == \
'omp parallel for collapse(2) schedule(dynamic,1)'\
' num_threads(nthreads_nested)'
assert trees[1][2].pragmas[0].ccode.value == \
trees[0][2].pragmas[0].ccode.value
else:
# Most compiler don't support nested parallelism
assert not trees[0][2].pragmas
assert not trees[1][2].pragmas

# Should compile properly
op.cfunction # noqa: B018
Loading