From 56bf10c7644aa34296dae8e6c014885cdf31cc87 Mon Sep 17 00:00:00 2001 From: Jean-Matthieu Etancelin <jean-matthieu.etancelin@univ-pau.fr> Date: Thu, 6 Jun 2019 17:48:40 +0200 Subject: [PATCH] Fix spatial filtering operator for mpi --- hysop/operator/base/spatial_filtering.py | 73 ++++++++++++------------ 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/hysop/operator/base/spatial_filtering.py b/hysop/operator/base/spatial_filtering.py index cb23b75f5..ff79da9c0 100644 --- a/hysop/operator/base/spatial_filtering.py +++ b/hysop/operator/base/spatial_filtering.py @@ -1,3 +1,5 @@ +# coding: utf-8 + """ @file advection.py RestrictionFilter operator generator. @@ -26,7 +28,7 @@ class SpatialFilterBase(object): """ Common base implementation for lowpass spatial filtering: small grid -> coarse grid """ - + def __init__(self, input_field, output_field, input_topo, output_topo, **kwds): @@ -53,7 +55,7 @@ class SpatialFilterBase(object): self.dtype = find_common_dtype(Fin.dtype, Fout.dtype) self.iratio = None # will be set in get_field_requirements self.grid_ratio = None # will be set in discretize - + @debug def discretize(self): if self.discretized: @@ -65,7 +67,7 @@ class SpatialFilterBase(object): self.dFin = dFin self.dFout = dFout self.grid_ratio = grid_ratio - + @classmethod def supports_multiple_field_topologies(cls): return True @@ -73,7 +75,7 @@ class SpatialFilterBase(object): @classmethod def supports_mpi(cls): return True - + def get_preserved_input_fields(self): return {self.Fin} @@ -89,7 +91,7 @@ class RestrictionFilterBase(SpatialFilterBase): Fin_dx = Fin_topo.space_step except AttributeError: Fin_dx = Fin_topo.mesh.space_step - + Fout_topo, Fout_requirements = requirements.get_output_requirement(self.Fout) try: Fout_dx = Fout_topo.space_step @@ -103,10 +105,10 @@ class RestrictionFilterBase(SpatialFilterBase): iratio = ratio.astype(npw.int32) msg='Grid ratio is not an integer on at least one axis: {}'.format(ratio) assert (ratio==iratio).all(), msg - + self.iratio = tuple(iratio.tolist()) return requirements - + class InterpolationFilterBase(SpatialFilterBase): @debug @@ -119,7 +121,7 @@ class InterpolationFilterBase(SpatialFilterBase): Fin_dx = Fin_topo.space_step except AttributeError: Fin_dx = Fin_topo.mesh.space_step - + Fout_topo, Fout_requirements = requirements.get_output_requirement(self.Fout) try: Fout_dx = Fout_topo.space_step @@ -133,7 +135,7 @@ class InterpolationFilterBase(SpatialFilterBase): iratio = ratio.astype(npw.int32) msg='Grid ratio is not an integer on at least one axis: {}'.format(ratio) assert (ratio==iratio).all(), msg - + self.iratio = tuple(iratio.tolist()) return requirements @@ -144,12 +146,12 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase) using the spectral method. """ @debug - def __init__(self, plot_input_energy=None, - plot_output_energy=None, - **kwds): + def __init__(self, plot_input_energy=None, + plot_output_energy=None, + **kwds): """ Initialize a SpectralRestrictionFilterBase. - + Parameters ---------- plot_input_energy: IOParams, optional, defaults to None @@ -166,7 +168,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase) """ check_instance(plot_input_energy, IOParams, allow_none=True) check_instance(plot_output_energy, IOParams, allow_none=True) - + super(SpectralRestrictionFilterBase, self).__init__(**kwds) Fin, Fout = self.Fin, self.Fout @@ -181,7 +183,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase) # build spectral transforms tg_fine = self.new_transform_group(mem_tag='FINE') tg_coarse = self.new_transform_group(mem_tag='COARSE') - + Ft = tg_fine.require_forward_transform(Fin, custom_output_buffer='auto', plot_energy=plot_input_energy) Bt = tg_coarse.require_backward_transform(Fout, custom_input_buffer='B0', plot_energy=plot_output_energy) @@ -200,10 +202,10 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase) msg = 'Compute resolution of coarse mesh {}::{} is greater than compute resolution of fine mesh {}::{}.' msg=msg.format(self.Fin.name, dFin.compute_resolution, self.Fout.name, dFout.compute_resolution) assert (dFin.compute_resolution >= dFout.compute_resolution).all(), msg - + def setup(self, work): super(SpectralRestrictionFilterBase, self).setup(work) - self.FIN = self.Ft.output_buffer + self.FIN = self.Ft.output_buffer self.FOUT = self.Bt.input_buffer self.fslices = self._generate_filter_slices() self.scaling = self._compute_scaling_coefficient() @@ -214,7 +216,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase) transforms = tuple(self.Ft.transforms[i] for i in self.Ft.output_axes) for (N,n,tr) in zip(self.FIN.shape, self.FOUT.shape, transforms): - assert len(src_slices) == len(dst_slices) + assert len(src_slices) == len(dst_slices) assert n <= N if SpectralTransformUtils.is_C2C(tr): left_src_slices = [l[:] for l in src_slices] @@ -225,7 +227,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase) lslc.append(lsrc) rslc.append(rsrc) src_slices = left_src_slices + right_src_slices - + left_dst_slices = [l[:] for l in dst_slices] right_dst_slices = [l[:] for l in dst_slices] ldst = slice(0, (n+1)//2, 1) @@ -243,7 +245,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase) src_slices = tuple( tuple(_) for _ in src_slices ) dst_slices = tuple( tuple(_) for _ in dst_slices ) return (src_slices, dst_slices) - + def _compute_scaling_coefficient(self): # scaling can depend on the fft backend so we bruteforce it # in every backend @@ -256,7 +258,7 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase): Base implementation for lowpass spatial filtering: small grid -> coarse grid using remeshing kernels. """ - + __default_method = { Remesh: Remesh.L2_1, } @@ -293,29 +295,31 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase): assert remesh_kernel.n % 2 == 0, 'Odd remeshing kernel moments.' min_ghosts = int(remesh_kernel.n//2)+1 return min_ghosts - + @debug def get_field_requirements(self): requirements = super(RemeshRestrictionFilterBase, self).get_field_requirements() iratio = self.iratio remesh_ghosts = self.remesh_ghosts(self.remesh_kernel) - fine_grid_ghosts = iratio*remesh_ghosts - 1 + fine_grid_ghosts = tuple(np.multiply(iratio, remesh_ghosts) - 1) + Fin_topo, Fin_requirements = requirements.get_input_requirement(self.Fin) Fin_requirements.min_ghosts = fine_grid_ghosts - + self.remesh_ghosts = remesh_ghosts self.fine_grid_ghosts = fine_grid_ghosts return requirements def compute_weights(self, iratio, product=True): - assert (iratio>=1).all() + iratio_np = np.asarray(iratio) + assert (iratio_np>=1).all() remesh_kernel = self.remesh_kernel p = remesh_kernel.n//2 + 1 - shape = 2*p*iratio-1 + shape = 2*p*iratio_np-1 weights = npw.zeros(dtype=npw.float64, shape=shape) nz_weights = {} for idx in npw.ndindex(*shape): - X = (npw.asarray(idx, dtype=npw.float64)+1) / iratio - p + X = (npw.asarray(idx, dtype=npw.float64)+1) / iratio_np - p if product: W = npw.prod(remesh_kernel(X)) else: @@ -326,11 +330,11 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase): if (W!=0): nz_weights[idx] = W Ws = weights.sum() - weights = weights / Ws + weights = weights / Ws nz_weights = {k: v/Ws for (k,v) in nz_weights.iteritems()} assert abs(weights.sum() - 1.0) < 1e-8, weights.sum() - assert abs(npw.sum(nz_weights.values()) - 1.0) < 1e-8, npw.sum(nz_weights.values()) + assert abs(npw.sum(nz_weights.values()) - 1.0) < 1e-8, npw.sum(nz_weights.values()) self.weights = weights self.nz_weights = nz_weights @@ -341,12 +345,12 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase): return super(RemeshRestrictionFilterBase, self).discretize() dFin, dFout = self.dFin, self.dFout - + grid_ratio = self.grid_ratio self.compute_weights(grid_ratio) - + remesh_ghosts = self.remesh_ghosts - fine_grid_ghosts = grid_ratio*remesh_ghosts - 1 + fine_grid_ghosts = np.multiply(grid_ratio, remesh_ghosts) - 1 fin = dFin.sdata[dFin.local_slices(ghosts=fine_grid_ghosts)] fout = dFout.compute_buffers[0] @@ -358,7 +362,7 @@ class SubgridRestrictionFilterBase(RestrictionFilterBase): Base implementation for lowpass spatial filtering: small grid -> coarse grid using subgrid """ - + @debug def discretize(self): if self.discretized: @@ -368,7 +372,7 @@ class SubgridRestrictionFilterBase(RestrictionFilterBase): grid_ratio = self.grid_ratio view = tuple(slice(None,None,r) for r in grid_ratio) - + fin = dFin.compute_buffers[0][view] fout = dFout.compute_buffers[0] @@ -435,4 +439,3 @@ class PolynomialRestrictionFilterBase(PolynomialInterpolationMethod, Restriction self.fin = dFin.sdata[dFin.local_slices(ghosts=ghosts)].handle self.fout = dFout.sdata[dFout.compute_slices].handle self.iter_shape = self.dFout.compute_resolution - -- GitLab