Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,20 @@ class Injection(UnevaluatedSparseOperation):

__rargs__ = ('field', 'expr', 'implicit_dims') + UnevaluatedSparseOperation.__rargs__

def __new__(cls, field, expr, implicit_dims, interpolator):
def __new__(cls, field, expr, increment, implicit_dims, interpolator):
obj = super().__new__(cls, interpolator)

# TODO: unused now, but will be necessary to compute the adjoint
obj.field = field
obj.expr = expr
obj.increment = increment
obj.implicit_dims = implicit_dims

return obj

def operation(self, **kwargs):
return self.interpolator._inject(expr=self.expr, field=self.field,
increment=self.increment,
implicit_dims=self.implicit_dims)

def __repr__(self):
Expand Down Expand Up @@ -372,7 +374,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)

@check_radius
@check_coords
def inject(self, field, expr, implicit_dims=None):
def inject(self, field, expr, increment=True, implicit_dims=None):
"""
Generate equations injecting an arbitrary expression into a field.

Expand All @@ -387,7 +389,7 @@ def inject(self, field, expr, implicit_dims=None):
injection expression, but that should be honored when constructing
the operator.
"""
return Injection(field, expr, implicit_dims, self)
return Injection(field, expr, increment, implicit_dims, self)

def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
"""
Expand Down Expand Up @@ -439,7 +441,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None

return temps + summands + last

def _inject(self, field, expr, implicit_dims=None):
def _inject(self, field, expr, increment=True, implicit_dims=None):
"""
Generate equations injecting an arbitrary expression into a field.

Expand Down Expand Up @@ -489,9 +491,10 @@ def _inject(self, field, expr, implicit_dims=None):
pos_only=variables, subdomain=subdomain)

# Substitute coordinate base symbols into the interpolation coefficients
eqns = [Inc(_field.xreplace(idx_subs),
(self._weights(subdomain=subdomain) * _expr).xreplace(idx_subs),
implicit_dims=implicit_dims)
ecls = Inc if increment else Eq
eqns = [ecls(_field.xreplace(idx_subs),
(self._weights(subdomain=subdomain) * _expr).xreplace(idx_subs),
implicit_dims=implicit_dims)
for (_field, _expr) in zip(fields, _exprs, strict=True)]

return temps + eqns
Expand Down
6 changes: 3 additions & 3 deletions examples/seismic/self_adjoint/sa_03_iso_correctness.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,9 @@
"output_type": "stream",
"text": [
"Operator `IsoFwdOperator` ran in 0.04 s\n",
"No source type defined, returning uninitiallized (zero) source\n",
"No source type defined, returning uninitialized (zero) source\n",
"Operator `IsoAdjOperator` ran in 0.03 s\n",
"No source type defined, returning uninitiallized (zero) source\n",
"No source type defined, returning uninitialized (zero) source\n",
"Operator `IsoAdjOperator` ran in 0.03 s\n"
]
},
Expand Down Expand Up @@ -639,7 +639,7 @@
"output_type": "stream",
"text": [
"Operator `IsoFwdOperator` ran in 0.03 s\n",
"No source type defined, returning uninitiallized (zero) source\n",
"No source type defined, returning uninitialized (zero) source\n",
"Operator `IsoAdjOperator` ran in 0.03 s\n"
]
},
Expand Down
2 changes: 1 addition & 1 deletion examples/seismic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def src(self):
def new_src(self, name='src', src_type='self', coordinates=None):
coords = coordinates or self.src_positions
if self.src_type is None or src_type is None:
warning("No source type defined, returning uninitiallized (zero) source")
warning("No source type defined, returning uninitialized (zero) source")
src = PointSource(name=name, grid=self.grid,
time_range=self.time_axis, npoint=self.nsrc,
coordinates=coords,
Expand Down
21 changes: 21 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,27 @@ def test_inject(shape, coords, result, npoints=19):
assert np.allclose(a.data[indices], result, rtol=1.e-5)


@pytest.mark.parametrize('shape, coords', [
((11, 11), [(.1, .9), (.4, .4)]),
((11, 11, 11), [(.1, .9), (.4, .4), (.4, .4)])
])
def test_inject_no_incr(shape, coords, npoints=9):
a = unit_box(shape=shape)
a.data[:] = 2.
p = points(a.grid, coords, npoints=npoints)

p.data[:] = 3.
expr = p.inject(a, p, increment=False)
op = Operator(expr, subs=a.grid.spacing_map)

op(a=a)

indices = [slice(4, 5, 1) for _ in coords]
indices[0] = slice(1, -1, 1)
# Should be 3 at the points
assert np.allclose(a.data[indices], 3, rtol=1.e-5)


@pytest.mark.parametrize('shape, coords, nexpr, result', [
((11, 11), [(.05, .95), (.45, .45)], 1, 1.),
((11, 11), [(.05, .95), (.45, .45)], 2, 1.),
Expand Down
Loading