From 56bf10c7644aa34296dae8e6c014885cdf31cc87 Mon Sep 17 00:00:00 2001
From: Jean-Matthieu Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Thu, 6 Jun 2019 17:48:40 +0200
Subject: [PATCH] Fix spatial filtering operator for mpi

---
 hysop/operator/base/spatial_filtering.py | 73 ++++++++++++------------
 1 file changed, 38 insertions(+), 35 deletions(-)

diff --git a/hysop/operator/base/spatial_filtering.py b/hysop/operator/base/spatial_filtering.py
index cb23b75f5..ff79da9c0 100644
--- a/hysop/operator/base/spatial_filtering.py
+++ b/hysop/operator/base/spatial_filtering.py
@@ -1,3 +1,5 @@
+# coding: utf-8
+
 """
 @file advection.py
 RestrictionFilter operator generator.
@@ -26,7 +28,7 @@ class SpatialFilterBase(object):
     """
     Common base implementation for lowpass spatial filtering: small grid -> coarse grid
     """
-    
+
     def __init__(self, input_field, output_field,
                        input_topo,  output_topo,
                        **kwds):
@@ -53,7 +55,7 @@ class SpatialFilterBase(object):
         self.dtype = find_common_dtype(Fin.dtype, Fout.dtype)
         self.iratio     = None # will be set in get_field_requirements
         self.grid_ratio = None # will be set in discretize
-    
+
     @debug
     def discretize(self):
         if self.discretized:
@@ -65,7 +67,7 @@ class SpatialFilterBase(object):
         self.dFin  = dFin
         self.dFout = dFout
         self.grid_ratio = grid_ratio
-    
+
     @classmethod
     def supports_multiple_field_topologies(cls):
         return True
@@ -73,7 +75,7 @@ class SpatialFilterBase(object):
     @classmethod
     def supports_mpi(cls):
         return True
-    
+
     def get_preserved_input_fields(self):
         return {self.Fin}
 
@@ -89,7 +91,7 @@ class RestrictionFilterBase(SpatialFilterBase):
             Fin_dx = Fin_topo.space_step
         except AttributeError:
             Fin_dx = Fin_topo.mesh.space_step
-        
+
         Fout_topo, Fout_requirements = requirements.get_output_requirement(self.Fout)
         try:
             Fout_dx = Fout_topo.space_step
@@ -103,10 +105,10 @@ class RestrictionFilterBase(SpatialFilterBase):
         iratio = ratio.astype(npw.int32)
         msg='Grid ratio is not an integer on at least one axis: {}'.format(ratio)
         assert (ratio==iratio).all(), msg
-        
+
         self.iratio = tuple(iratio.tolist())
         return requirements
-    
+
 
 class InterpolationFilterBase(SpatialFilterBase):
     @debug
@@ -119,7 +121,7 @@ class InterpolationFilterBase(SpatialFilterBase):
             Fin_dx = Fin_topo.space_step
         except AttributeError:
             Fin_dx = Fin_topo.mesh.space_step
-        
+
         Fout_topo, Fout_requirements = requirements.get_output_requirement(self.Fout)
         try:
             Fout_dx = Fout_topo.space_step
@@ -133,7 +135,7 @@ class InterpolationFilterBase(SpatialFilterBase):
         iratio = ratio.astype(npw.int32)
         msg='Grid ratio is not an integer on at least one axis: {}'.format(ratio)
         assert (ratio==iratio).all(), msg
-        
+
         self.iratio = tuple(iratio.tolist())
         return requirements
 
@@ -144,12 +146,12 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
     using the spectral method.
     """
     @debug
-    def __init__(self, plot_input_energy=None, 
-                       plot_output_energy=None, 
-                       **kwds): 
+    def __init__(self, plot_input_energy=None,
+                       plot_output_energy=None,
+                       **kwds):
         """
         Initialize a SpectralRestrictionFilterBase.
-        
+
         Parameters
         ----------
         plot_input_energy: IOParams, optional, defaults to None
@@ -166,7 +168,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
         """
         check_instance(plot_input_energy, IOParams, allow_none=True)
         check_instance(plot_output_energy, IOParams, allow_none=True)
-        
+
         super(SpectralRestrictionFilterBase, self).__init__(**kwds)
 
         Fin, Fout = self.Fin, self.Fout
@@ -181,7 +183,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
         # build spectral transforms
         tg_fine   = self.new_transform_group(mem_tag='FINE')
         tg_coarse = self.new_transform_group(mem_tag='COARSE')
-        
+
         Ft = tg_fine.require_forward_transform(Fin, custom_output_buffer='auto', plot_energy=plot_input_energy)
         Bt = tg_coarse.require_backward_transform(Fout, custom_input_buffer='B0', plot_energy=plot_output_energy)
 
@@ -200,10 +202,10 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
         msg = 'Compute resolution of coarse mesh {}::{} is greater than compute resolution of fine mesh {}::{}.'
         msg=msg.format(self.Fin.name, dFin.compute_resolution, self.Fout.name, dFout.compute_resolution)
         assert (dFin.compute_resolution >= dFout.compute_resolution).all(), msg
-     
+
     def setup(self, work):
         super(SpectralRestrictionFilterBase, self).setup(work)
-        self.FIN     = self.Ft.output_buffer 
+        self.FIN     = self.Ft.output_buffer
         self.FOUT    = self.Bt.input_buffer
         self.fslices = self._generate_filter_slices()
         self.scaling = self._compute_scaling_coefficient()
@@ -214,7 +216,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
 
         transforms = tuple(self.Ft.transforms[i] for i in self.Ft.output_axes)
         for (N,n,tr) in zip(self.FIN.shape, self.FOUT.shape, transforms):
-            assert len(src_slices) == len(dst_slices)    
+            assert len(src_slices) == len(dst_slices)
             assert n <= N
             if SpectralTransformUtils.is_C2C(tr):
                 left_src_slices  = [l[:] for l in src_slices]
@@ -225,7 +227,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
                     lslc.append(lsrc)
                     rslc.append(rsrc)
                 src_slices = left_src_slices + right_src_slices
-                
+
                 left_dst_slices  = [l[:] for l in dst_slices]
                 right_dst_slices = [l[:] for l in dst_slices]
                 ldst = slice(0, (n+1)//2, 1)
@@ -243,7 +245,7 @@ class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase)
         src_slices = tuple( tuple(_) for _ in src_slices )
         dst_slices = tuple( tuple(_) for _ in dst_slices )
         return (src_slices, dst_slices)
-    
+
     def _compute_scaling_coefficient(self):
         # scaling can depend on the fft backend so we bruteforce it
         # in every backend
@@ -256,7 +258,7 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase):
     Base implementation for lowpass spatial filtering: small grid -> coarse grid
     using remeshing kernels.
     """
-    
+
     __default_method = {
         Remesh: Remesh.L2_1,
     }
@@ -293,29 +295,31 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase):
             assert remesh_kernel.n % 2 == 0, 'Odd remeshing kernel moments.'
         min_ghosts = int(remesh_kernel.n//2)+1
         return min_ghosts
-    
+
     @debug
     def get_field_requirements(self):
         requirements = super(RemeshRestrictionFilterBase, self).get_field_requirements()
         iratio = self.iratio
         remesh_ghosts    = self.remesh_ghosts(self.remesh_kernel)
-        fine_grid_ghosts = iratio*remesh_ghosts - 1
+        fine_grid_ghosts = tuple(np.multiply(iratio, remesh_ghosts) - 1)
+        Fin_topo, Fin_requirements = requirements.get_input_requirement(self.Fin)
         Fin_requirements.min_ghosts = fine_grid_ghosts
-        
+
         self.remesh_ghosts    = remesh_ghosts
         self.fine_grid_ghosts = fine_grid_ghosts
 
         return requirements
 
     def compute_weights(self, iratio, product=True):
-        assert (iratio>=1).all()
+        iratio_np = np.asarray(iratio)
+        assert (iratio_np>=1).all()
         remesh_kernel = self.remesh_kernel
         p = remesh_kernel.n//2 + 1
-        shape = 2*p*iratio-1
+        shape = 2*p*iratio_np-1
         weights = npw.zeros(dtype=npw.float64, shape=shape)
         nz_weights = {}
         for idx in npw.ndindex(*shape):
-            X = (npw.asarray(idx, dtype=npw.float64)+1) / iratio - p
+            X = (npw.asarray(idx, dtype=npw.float64)+1) / iratio_np - p
             if product:
                 W = npw.prod(remesh_kernel(X))
             else:
@@ -326,11 +330,11 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase):
             if (W!=0):
                 nz_weights[idx] = W
         Ws = weights.sum()
-        weights = weights / Ws 
+        weights = weights / Ws
         nz_weights = {k: v/Ws for (k,v) in nz_weights.iteritems()}
 
         assert abs(weights.sum() - 1.0) < 1e-8, weights.sum()
-        assert abs(npw.sum(nz_weights.values()) - 1.0) < 1e-8, npw.sum(nz_weights.values()) 
+        assert abs(npw.sum(nz_weights.values()) - 1.0) < 1e-8, npw.sum(nz_weights.values())
 
         self.weights    = weights
         self.nz_weights = nz_weights
@@ -341,12 +345,12 @@ class RemeshRestrictionFilterBase(RestrictionFilterBase):
             return
         super(RemeshRestrictionFilterBase, self).discretize()
         dFin, dFout  = self.dFin, self.dFout
-    
+
         grid_ratio = self.grid_ratio
         self.compute_weights(grid_ratio)
-        
+
         remesh_ghosts    = self.remesh_ghosts
-        fine_grid_ghosts = grid_ratio*remesh_ghosts - 1
+        fine_grid_ghosts = np.multiply(grid_ratio, remesh_ghosts) - 1
         fin  = dFin.sdata[dFin.local_slices(ghosts=fine_grid_ghosts)]
         fout = dFout.compute_buffers[0]
 
@@ -358,7 +362,7 @@ class SubgridRestrictionFilterBase(RestrictionFilterBase):
     Base implementation for lowpass spatial filtering: small grid -> coarse grid
     using subgrid
     """
-    
+
     @debug
     def discretize(self):
         if self.discretized:
@@ -368,7 +372,7 @@ class SubgridRestrictionFilterBase(RestrictionFilterBase):
 
         grid_ratio = self.grid_ratio
         view = tuple(slice(None,None,r) for r in grid_ratio)
-        
+
         fin  = dFin.compute_buffers[0][view]
         fout = dFout.compute_buffers[0]
 
@@ -435,4 +439,3 @@ class PolynomialRestrictionFilterBase(PolynomialInterpolationMethod, Restriction
         self.fin  = dFin.sdata[dFin.local_slices(ghosts=ghosts)].handle
         self.fout = dFout.sdata[dFout.compute_slices].handle
         self.iter_shape = self.dFout.compute_resolution
-
-- 
GitLab