From d68438718100fcb67af93593e31bfc9be7376b17 Mon Sep 17 00:00:00 2001
From: JM Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Thu, 4 Jun 2020 17:26:52 +0200
Subject: [PATCH] fix mpi bridge

---
 hysop/core/mpi/bridge.py     | 13 +++++++++----
 hysop/core/mpi/topo_tools.py | 21 ++++++++-------------
 hysop/problem.py             |  5 ++++-
 3 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/hysop/core/mpi/bridge.py b/hysop/core/mpi/bridge.py
index fdc200f5b..901e32886 100644
--- a/hysop/core/mpi/bridge.py
+++ b/hysop/core/mpi/bridge.py
@@ -376,10 +376,15 @@ class BridgeOverlap(Bridge):
     def _build_send_recv_dict(self):
         # Compute local intersections : i.e. find which grid points
         # are on both source and target mesh.
-        indices_source = TopoTools.gather_global_indices_overlap(
-            self._source, self.comm, self.domain)
-        indices_target = TopoTools.gather_global_indices_overlap(
-            self._target, self.comm, self.domain)
+        # Filter out the empty slices (due to none topologies)
+        indices_source = dict([(rk, sl) for rk, sl in
+                               TopoTools.gather_global_indices_overlap(
+                                   self._source, self.comm, self.domain).iteritems()
+                               if not all([_ == slice(0, 0) for _ in sl])])
+        indices_target = dict([(rk, sl) for rk, sl in
+                               TopoTools.gather_global_indices_overlap(
+                                   self._target, self.comm, self.domain).iteritems()
+                               if not all([_ == slice(0, 0) for _ in sl])])
 
         # From now on, we have indices_source[rk] = global indices (slice)
         # of grid points of the source on process number rk in parent.
diff --git a/hysop/core/mpi/topo_tools.py b/hysop/core/mpi/topo_tools.py
index 31e605dc7..6867528c0 100644
--- a/hysop/core/mpi/topo_tools.py
+++ b/hysop/core/mpi/topo_tools.py
@@ -74,9 +74,7 @@ class TopoTools(object):
         if root is None:
             comm.Allgather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT])
         else:
-            comm.Gather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT],
-                        root=root)
-
+            comm.Gather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT], root=root)
         if toslice:
             return Utils.array_to_dict(iglob_res)
         else:
@@ -133,19 +131,18 @@ class TopoTools(object):
             size = comm.Get_size()
             rank = comm.Get_rank()
             dimension = dom.dim
-            iglob = npw.integer_zeros((size, dimension * 2))
-            iglob_res = npw.integer_zeros((size, dimension * 2))
+            iglob = npw.integer_zeros((dimension * 2, size), order='F')
+            iglob_res = npw.integer_zeros((dimension * 2, size), order='F')
+            iglob[0::2, rank] = 0
             iglob[1::2, rank] = -1
             if root is None:
-                comm.Allgather([iglob[rank, :], MPI.INT], [iglob_res, MPI.INT])
+                comm.Allgather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT])
             else:
-                comm.Gather([iglob[rank, :], MPI.INT], [iglob_res, MPI.INT],
-                            root=root)
+                comm.Gather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT], root=root)
             if toslice:
-                return Utils.array_to_dict(iglob_res.T)
+                return Utils.array_to_dict(iglob_res)
             else:
-                return iglob_res.T
-
+                return iglob_res
         else:
             return TopoTools.gather_global_indices(topo, toslice, root, comm)
 
@@ -286,8 +283,6 @@ class TopoTools(object):
         basetype = dtype_to_mpi_type(dtype)
         subtype = basetype.Create_subarray(shape, subshape, substart, order=order)
         subtype.Commit()
-        # print 'MPI_Create_subarray(shape={}, subshape={}, substart={}, order={})'.format(
-        # shape, subshape, substart, 'C' if order is MPI.ORDER_C else 'F')
         return subtype
 
     @staticmethod
diff --git a/hysop/problem.py b/hysop/problem.py
index 7970a071f..840bf45e2 100644
--- a/hysop/problem.py
+++ b/hysop/problem.py
@@ -100,7 +100,7 @@ class Problem(ComputationalGraph):
                         msg += '\n If this is required, override check_unique_clenv().'
                         raise RuntimeError(msg)
 
-    def initialize_field(self, field, **kwds):
+    def initialize_field(self, field, mpi_params=None, **kwds):
         """Initialize a field on all its input and output topologies."""
         initialized = set()
         for op in self.nodes:
@@ -113,6 +113,9 @@ class Problem(ComputationalGraph):
                     if all((df in initialized) for df in dfield.discrete_fields()):
                         # all contained scalar fields were already initialized
                         continue
+                    elif mpi_params and mpi_params.task_id != dfield.topology.task_id:
+                        # Topology task does not matches given mpi_params task
+                        continue
                     else:
                         components = ()
                         for (component, scalar_dfield) in dfield.nd_iter():
-- 
GitLab