Skip to content
Snippets Groups Projects
Commit 56bf10c7 authored by EXT Jean-Matthieu Etancelin's avatar EXT Jean-Matthieu Etancelin
Browse files

Fix spatial filtering operator for mpi

parent 797f2edd
No related branches found
No related tags found
1 merge request!16MPI operators
Pipeline #24694 failed
# coding: utf-8
"""
@file advection.py
RestrictionFilter operator generator.
......@@ -26,7 +28,7 @@ class SpatialFilterBase(object):
"""
Common base implementation for lowpass spatial filtering: small grid -> coarse grid
"""
def __init__(self, input_field, output_field,
input_topo, output_topo,
**kwds):
......@@ -53,7 +55,7 @@ class SpatialFilterBase(object):
self.dtype = find_common_dtype(Fin.dtype, Fout.dtype)
self.iratio = None # will be set in get_field_requirements
self.grid_ratio = None # will be set in discretize
@debug
def discretize(self):
if self.discretized:
......@@ -65,7 +67,7 @@ class SpatialFilterBase(object):
self.dFin = dFin
self.dFout = dFout
self.grid_ratio = grid_ratio
@classmethod
def supports_multiple_field_topologies(cls):
return True
......@@ -73,7 +75,7 @@ class SpatialFilterBase(object):
@classmethod
def supports_mpi(cls):
return True
def get_preserved_input_fields(self):
return {self.Fin}
......@@ -89,7 +91,7 @@ class RestrictionFilterBase(SpatialFilterBase):
Fin_dx = Fin_topo.space_step
except AttributeError:
Fin_dx = Fin_topo.mesh.space_step
Fout_topo, Fout_requirements = requirements.get_output_requirement(self.Fout)
try:
Fout_dx = Fout_topo.space_step
......@@ -103,10 +105,10 @@ class RestrictionFilterBase(SpatialFilterBase):
iratio = ratio.astype(npw.int32)
msg='Grid ratio is not an integer on at least one axis: {}'.format(ratio)
assert (ratio==iratio).all(), msg
self.iratio = tuple(iratio.tolist())
return requirements
class InterpolationFilterBase(SpatialFilterBase):
@debug
......@@ -119,7 +121,7 @@ class InterpolationFilterBase(SpatialFilterBase):
Fin_dx = Fin_topo.space_step
except AttributeError:
Fin_dx = Fin_topo.mesh.space_step
Fout_topo, Fout_requirements = requirements.get_output_requirement(self.Fout)
try:
Fout_dx = Fout_topo.space_step
......@@ -133,7 +135,7 @@ class InterpolationFilterBase(SpatialFilterBase):
iratio = ratio.astype(npw.int32)
msg='Grid ratio is not an integer on at least one axis: {}'.format(ratio)
assert (ratio==iratio).all(), msg
self.iratio = tuple(iratio.tolist())
return requirements
......@@ -144,12 +146,12 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
using the spectral method.
"""
@debug
def __init__(self, plot_input_energy=None,
plot_output_energy=None,
**kwds):
def __init__(self, plot_input_energy=None,
plot_output_energy=None,
**kwds):
"""
Initialize a SpectralRestrictionFilterBase.
Parameters
----------
plot_input_energy: IOParams, optional, defaults to None
......@@ -166,7 +168,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
"""
check_instance(plot_input_energy, IOParams, allow_none=True)
check_instance(plot_output_energy, IOParams, allow_none=True)
super(SpectralRestrictionFilterBase, self).__init__(**kwds)
Fin, Fout = self.Fin, self.Fout
......@@ -181,7 +183,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
# build spectral transforms
tg_fine = self.new_transform_group(mem_tag='FINE')
tg_coarse = self.new_transform_group(mem_tag='COARSE')
Ft = tg_fine.require_forward_transform(Fin, custom_output_buffer='auto', plot_energy=plot_input_energy)
Bt = tg_coarse.require_backward_transform(Fout, custom_input_buffer='B0', plot_energy=plot_output_energy)
......@@ -200,10 +202,10 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
msg = 'Compute resolution of coarse mesh {}::{} is greater than compute resolution of fine mesh {}::{}.'
msg=msg.format(self.Fin.name, dFin.compute_resolution, self.Fout.name, dFout.compute_resolution)
assert (dFin.compute_resolution >= dFout.compute_resolution).all(), msg
def setup(self, work):
super(SpectralRestrictionFilterBase, self).setup(work)
self.FIN = self.Ft.output_buffer
self.FIN = self.Ft.output_buffer
self.FOUT = self.Bt.input_buffer
self.fslices = self._generate_filter_slices()
self.scaling = self._compute_scaling_coefficient()
......@@ -214,7 +216,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
transforms = tuple(self.Ft.transforms[i] for i in self.Ft.output_axes)
for (N,n,tr) in zip(self.FIN.shape, self.FOUT.shape, transforms):
assert len(src_slices) == len(dst_slices)
assert len(src_slices) == len(dst_slices)
assert n <= N
if SpectralTransformUtils.is_C2C(tr):
left_src_slices = [l[:] for l in src_slices]
......@@ -225,7 +227,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
lslc.append(lsrc)
rslc.append(rsrc)
src_slices = left_src_slices + right_src_slices
left_dst_slices = [l[:] for l in dst_slices]
right_dst_slices = [l[:] for l in dst_slices]
ldst = slice(0, (n+1)//2, 1)
......@@ -243,7 +245,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
src_slices = tuple( tuple(_) for _ in src_slices )
dst_slices = tuple( tuple(_) for _ in dst_slices )
return (src_slices, dst_slices)
def _compute_scaling_coefficient(self):
# scaling can depend on the fft backend so we bruteforce it
# in every backend
......@@ -256,7 +258,7 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase):
Base implementation for lowpass spatial filtering: small grid -> coarse grid
using remeshing kernels.
"""
__default_method = {
Remesh: Remesh.L2_1,
}
......@@ -293,29 +295,31 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase):
assert remesh_kernel.n % 2 == 0, 'Odd remeshing kernel moments.'
min_ghosts = int(remesh_kernel.n//2)+1
return min_ghosts
@debug
def get_field_requirements(self):
requirements = super(RemeshRestrictionFilterBase, self).get_field_requirements()
iratio = self.iratio
remesh_ghosts = self.remesh_ghosts(self.remesh_kernel)
fine_grid_ghosts = iratio*remesh_ghosts - 1
fine_grid_ghosts = tuple(np.multiply(iratio, remesh_ghosts) - 1)
Fin_topo, Fin_requirements = requirements.get_input_requirement(self.Fin)
Fin_requirements.min_ghosts = fine_grid_ghosts
self.remesh_ghosts = remesh_ghosts
self.fine_grid_ghosts = fine_grid_ghosts
return requirements
def compute_weights(self, iratio, product=True):
assert (iratio>=1).all()
iratio_np = np.asarray(iratio)
assert (iratio_np>=1).all()
remesh_kernel = self.remesh_kernel
p = remesh_kernel.n//2 + 1
shape = 2*p*iratio-1
shape = 2*p*iratio_np-1
weights = npw.zeros(dtype=npw.float64, shape=shape)
nz_weights = {}
for idx in npw.ndindex(*shape):
X = (npw.asarray(idx, dtype=npw.float64)+1) / iratio - p
X = (npw.asarray(idx, dtype=npw.float64)+1) / iratio_np - p
if product:
W = npw.prod(remesh_kernel(X))
else:
......@@ -326,11 +330,11 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase):
if (W!=0):
nz_weights[idx] = W
Ws = weights.sum()
weights = weights / Ws
weights = weights / Ws
nz_weights = {k: v/Ws for (k,v) in nz_weights.iteritems()}
assert abs(weights.sum() - 1.0) < 1e-8, weights.sum()
assert abs(npw.sum(nz_weights.values()) - 1.0) < 1e-8, npw.sum(nz_weights.values())
assert abs(npw.sum(nz_weights.values()) - 1.0) < 1e-8, npw.sum(nz_weights.values())
self.weights = weights
self.nz_weights = nz_weights
......@@ -341,12 +345,12 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase):
return
super(RemeshRestrictionFilterBase, self).discretize()
dFin, dFout = self.dFin, self.dFout
grid_ratio = self.grid_ratio
self.compute_weights(grid_ratio)
remesh_ghosts = self.remesh_ghosts
fine_grid_ghosts = grid_ratio*remesh_ghosts - 1
fine_grid_ghosts = np.multiply(grid_ratio, remesh_ghosts) - 1
fin = dFin.sdata[dFin.local_slices(ghosts=fine_grid_ghosts)]
fout = dFout.compute_buffers[0]
......@@ -358,7 +362,7 @@ class SubgridRestrictionFilterBase(RestrictionFilterBase):
Base implementation for lowpass spatial filtering: small grid -> coarse grid
using subgrid
"""
@debug
def discretize(self):
if self.discretized:
......@@ -368,7 +372,7 @@ class SubgridRestrictionFilterBase(RestrictionFilterBase):
grid_ratio = self.grid_ratio
view = tuple(slice(None,None,r) for r in grid_ratio)
fin = dFin.compute_buffers[0][view]
fout = dFout.compute_buffers[0]
......@@ -435,4 +439,3 @@ class PolynomialRestrictionFilterBase(PolynomialInterpolationMethod, Restriction
self.fin = dFin.sdata[dFin.local_slices(ghosts=ghosts)].handle
self.fout = dFout.sdata[dFout.compute_slices].handle
self.iter_shape = self.dFout.compute_resolution
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment