From 28f7b2ac89467444a36977293484d3d3c4cade90 Mon Sep 17 00:00:00 2001
From: JM Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Mon, 30 Nov 2020 15:22:52 +0100
Subject: [PATCH] Fix task-overlapping issue in RedistributeInter

---
 hysop/core/graph/computational_graph.py |  4 ++--
 hysop/core/mpi/redistribute.py          | 25 ++-----------------------
 hysop/domain/domain.py                  | 10 +++++++++-
 3 files changed, 13 insertions(+), 26 deletions(-)

diff --git a/hysop/core/graph/computational_graph.py b/hysop/core/graph/computational_graph.py
index f48f3febf..16af5b5f0 100644
--- a/hysop/core/graph/computational_graph.py
+++ b/hysop/core/graph/computational_graph.py
@@ -536,12 +536,12 @@ class ComputationalGraph(ComputationalGraphNode):
         ops, already_printed = [], []
         for (i, pn, n, op) in __recurse_nodes(self.nodes, nprefix=self.name+"."):
             for k in self._profiler.all_data.keys():
-                if n == '.'.join(k.split('.')[:-1]) and not n in already_printed:
+                if n == '.'.join(k.split('.')[:-1]) and not k in already_printed:
                     strdata = (i, str(op.mpi_params.task_id), pn, type(op).__name__, k.split('.')[-1])
                     values = tuple(ff.format(self._profiler.all_data[k][_])
                                    for ff, _ in zip(("{:.5g}",  "{}", "{:.5g}", "{:.5g}", "{:.5g}", "{}"), (2, 1, 3, 4, 5, 0)))
                     ops += multiline_split(strdata + values, maxlen, split_sep, replace, newline_prefix)
-                    already_printed.append(n)
+                    already_printed.append(k)
 
         isize = max(strlen(s[0]) for s in ops)
         tasksize = max(max(strlen(s[1]) for s in ops), 6)
diff --git a/hysop/core/mpi/redistribute.py b/hysop/core/mpi/redistribute.py
index b5e82a817..1f0f7c9df 100644
--- a/hysop/core/mpi/redistribute.py
+++ b/hysop/core/mpi/redistribute.py
@@ -572,29 +572,8 @@ class RedistributeInter(RedistributeOperatorBase):
         if _is_target:
             assert all(target_topo.mesh.local_resolution == ofield.resolution)
 
-        # Compute if there is an overlap
-        src_overlap, dst_overlap = -1, -1
-        if _is_source:
-            src_overlap = source_topo.mpi_params.comm.allreduce(1 if _is_source and _is_target else 0)
-        if _is_target:
-            dst_overlap = target_topo.mpi_params.comm.allreduce(1 if _is_source and _is_target else 0)
-        if domain.task_rank() == 0 and (src_overlap == -1 or dst_overlap == -1):
-            rcv_overlap = domain.parent_comm.sendrecv(
-                src_overlap if dst_overlap == -1 else dst_overlap,
-                sendtag=self._other_task_id, recvtag=first_not_None((source_topo, target_topo)).mpi_params.task_id,
-                dest=domain.task_root_in_parent(self._other_task_id),
-                source=domain.task_root_in_parent(self._other_task_id))
-            src_overlap, dst_overlap = [rcv_overlap if _ == -1 else _ for _ in (src_overlap, dst_overlap)]
-        # ... then broadcast
-        if _is_source:
-            dst_overlap = source_topo.mpi_params.comm.bcast(dst_overlap, root=0)
-        if _is_target:
-            src_overlap = target_topo.mpi_params.comm.bcast(src_overlap, root=0)
-        assert (src_overlap+dst_overlap) % 2 == 0
-        self._has_overlap = src_overlap+dst_overlap > 0
-
         # Create bridges and store comm types and indices
-        if self._has_overlap:
+        if not domain.tasks_overlapping(source_id, target_id) is None:
             self.bridge = BridgeOverlap(source=source_topo, target=target_topo,
                                         source_id=source_id, target_id=target_id,
                                         dtype=self.dtype, order=get_mpi_order(first_not_None((ifield, ofield)).sdata))
@@ -638,7 +617,7 @@ class RedistributeInter(RedistributeOperatorBase):
     @op_apply
     def apply(self, **kwds):
         comm = self.bridge.comm
-        rank = self.mpi_params.rank
+        rank = comm.Get_rank()
         types = self._comm_types
         indices = self._comm_indices
         dFin, dFout = self.dFin, self.dFout
diff --git a/hysop/domain/domain.py b/hysop/domain/domain.py
index 0fb99eba5..09a423f65 100644
--- a/hysop/domain/domain.py
+++ b/hysop/domain/domain.py
@@ -190,6 +190,9 @@ class DomainView(TaggedObjectView):
         """Equivalent to self.long_description()"""
         return self.long_description()
 
+    def tasks_overlapping(self, ta, tb):  # overlapping-map entry for tasks (ta, tb); presumably the largest common task, or None if they do not overlap -- confirm
+        return self._domain._overlapping_map[ta][tb]  # nested lookup: outer key ta, inner key tb
+
     domain = property(_get_domain)
     dim = property(_get_dim)
     proc_tasks = property(_get_proc_tasks)
@@ -339,6 +342,8 @@ class Domain(RegisteredObject):
 
         # Create intercommunicators from current task to others
         task_intercomm = {}
+        # Task overlapping map: for each pair of distinct tasks, the largest task when the two overlap (None otherwise)
+        overlapping_map = dict([(_, dict([(__, None) for __ in all_tasks if _ != __])) for _ in all_tasks])
         # For all tasks the current rank is involved in
         my_tasks = tuple((_ for _ in all_tasks if is_task_matters(_, proc_tasks[parent_rank])))
         for tsource in my_tasks:
@@ -356,10 +361,12 @@ class Domain(RegisteredObject):
                     if any([all([t in _ for _ in proc_tasks]) for t in all_tasks]):
                         # TODO: review if nested tasks with differents ranks
                         # for the moment : ensure ranks are identical throw all local tasks
-                        assert all([all([t.values()[0] == _ for _ in t.values()]) for t in all_task_ranks])
+                        #assert all([all([t.values()[0] == _ for _ in t.values()]) for t in all_task_ranks]), all_task_ranks
                         # If nested tasks, we use the largest task communicator
                         largest_task = [t for t in all_tasks if all([t in _ for _ in proc_tasks])][0]
                         intercomm = task_comm[largest_task]
+                        overlapping_map[tsource][tdest] = largest_task
+                        overlapping_map[tdest][tsource] = largest_task
                     else:
                         raise NotImplementedError()
                 task_intercomm[tdest] = intercomm
@@ -387,6 +394,7 @@ class Domain(RegisteredObject):
         self._proc_tasks = proc_tasks
         self._registered_topologies = {}
         self._frame = SymbolicFrame(dim=dim)
+        self._overlapping_map = overlapping_map
 
     def register_topology(self, topo):
         """Register a new topology on this domain.
-- 
GitLab