From dccb0373b8ad4cfd005a65e0d1b6ef652a6d8d54 Mon Sep 17 00:00:00 2001
From: Jean-Matthieu Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Mon, 16 Sep 2024 15:44:27 +0200
Subject: [PATCH] Improve graph building for several tasks

---
 hysop/core/graph/graph_builder.py        | 42 +++++++++++++++++++-----
 hysop/core/mpi/redistribute.py           | 19 ++++++-----
 hysop/domain/box.py                      |  4 +--
 hysop/domain/domain.py                   | 11 +++++++
 hysop/fields/cartesian_discrete_field.py |  4 ++-
 hysop/topology/cartesian_topology.py     | 39 ++++++++++++++++++----
 hysop_examples/examples/tasks/tasks.py   |  3 +-
 7 files changed, 93 insertions(+), 29 deletions(-)

diff --git a/hysop/core/graph/graph_builder.py b/hysop/core/graph/graph_builder.py
index 927370e90..8675df701 100644
--- a/hysop/core/graph/graph_builder.py
+++ b/hysop/core/graph/graph_builder.py
@@ -377,7 +377,7 @@ class GraphBuilder:
                             (
                                 " on an unknown topology"
                                 if (otopo is None)
-                                else f".{otopo.pretty_tag}"
+                                else f".{otopo.pretty_tag}  t{otopo.task_id}"
                             ),
                         )
                     )
@@ -452,9 +452,11 @@ class GraphBuilder:
             needed_names = {
                 _ if not hasattr(_, "name") else _.name for _ in needed_elems.keys()
             } - self._intertasks_exchanged
-            mgs = "  >[IT] Current task {} parameters and fields : {}"
-            gprint(mgs.format("can communicate", ", ".join(available_names)))
-            gprint(mgs.format("needs", ", ".join(needed_names)))
+            mgs = "  >[IT] Current task ({}) {} parameters and fields : {}"
+            gprint(
+                mgs.format(current_task, "can communicate", ", ".join(available_names))
+            )
+            gprint(mgs.format(current_task, "needs", ", ".join(needed_names)))
 
             # Inter-task matching is performed on root process
             available_names = {_: None for _ in available_names}  # value is dest task
@@ -464,7 +466,16 @@ class GraphBuilder:
                 # loop over other tasks
                 for ot in (_ for _ in domain.all_tasks if _ != current_task):
                     if domain.task_root_in_parent(ot) == domain.parent_rank:
-                        ot_needs = needed_names.keys()
+                        ot_needs = []
+                        for _n in needed_names.keys():
+                            _ntopo = needed_elems[_name_to_key(_n, needed_elems)]
+                            if (
+                                hasattr(_ntopo, "task_id")
+                                and _ntopo.task_id == current_task
+                            ):
+                                continue
+                            else:
+                                ot_needs.append(_n)
                         ot_provide = can_provide = [
                             _ for _ in ot_needs if _ in available_names
                         ]
@@ -520,6 +531,17 @@ class GraphBuilder:
             }
             needed_names = tcomm.bcast(needed_names, root=0)
             available_names = tcomm.bcast(available_names, root=0)
+            if domain.task_rank() != 0:
+                needed_names = dict(
+                    (k, v)
+                    for k, v in needed_names.items()
+                    if not v in domain.current_task_list()
+                )
+                available_names = dict(
+                    (k, v)
+                    for k, v in available_names.items()
+                    if not v in domain.current_task_list()
+                )
             gprint(
                 f"  >[IT] Inter-tasks will send:to {available_names} and recieve:from {needed_names}"
             )
@@ -619,7 +641,7 @@ class GraphBuilder:
                                 it_redistribute_kwargs[_]
                                 for _ in ("source_topo", "target_topo", "other_task_id")
                             )
-                        )
+                        ), str(it_redistribute_kwargs)
                         if op.fake_init:
                             op.__init__(**it_redistribute_kwargs)
                             # Recompute fields requirements since no fields were given in first fake operator creation
@@ -630,11 +652,12 @@ class GraphBuilder:
                                 target_node.nodes.index(first_op), op
                             )
                             gprint(
-                                "\n >Handling node {}::{}: {} {}".format(
+                                "\n >Handling node {}::{}: {} {} :: {}".format(
                                     self.target_node.name,
                                     node_id,
                                     op.name,
                                     op.__class__,
+                                    it_redistribute_kwargs,
                                 )
                             )
                             subgraph, node_ops, node_vertices, from_subgraph = (
@@ -1389,6 +1412,7 @@ class GraphBuilder:
 
             is_root = target_dfield_requirements is not None
             dim = target_topo.domain.dim
+            tid = target_topo.task_id
 
             check_instance(
                 target_dfield_requirements, DiscreteFieldRequirements, allow_none=True
@@ -1423,7 +1447,7 @@ class GraphBuilder:
                         # adapt to this first operator
                         assert target_topo not in dtopology_states
                         istate = dtopology_states.setdefault(
-                            target_topo, CartesianTopologyState(dim)
+                            target_topo, CartesianTopologyState(dim, tid)
                         )
                         if target_dfield_requirements:
                             allowed_axes = target_dfield_requirements.axes
@@ -1447,7 +1471,7 @@ class GraphBuilder:
                             gprint2(f"       >Initial state set to {istate}")
                     else:
                         istate = dtopology_states.setdefault(
-                            target_topo, CartesianTopologyState(dim)
+                            target_topo, CartesianTopologyState(dim, tid)
                         )
                         gprint2(f"       >Input state is {istate}")
 
diff --git a/hysop/core/mpi/redistribute.py b/hysop/core/mpi/redistribute.py
index 81819333e..38e15aea3 100644
--- a/hysop/core/mpi/redistribute.py
+++ b/hysop/core/mpi/redistribute.py
@@ -493,14 +493,17 @@ class RedistributeInter(RedistributeOperatorBase):
         ) == set(self.input_fields.keys())
 
         if input_topology_states:
-            (ref_field, ref_state) = next(iter(input_topology_states.items()))
+            ref_field, _ = next(iter(input_topology_states.items()))
             ref_topo = self.input_fields[ref_field]
-        else:
-            ref_state = self.output_fields[output_field].topology_state
+        ref_state = self.output_fields[output_field].topology_state
 
         for ifield, istate in input_topology_states.items():
             itopo = self.input_fields[ifield]
-            if not istate.match(ref_state):
+            if not (
+                istate.dim == ref_state.dim
+                and istate.axes == ref_state.axes
+                and istate.memory_order == ref_state.memory_order
+            ):
                 msg = "\nInput topology state for field {} defined on topology {} does "
                 msg += "not match reference input topology state {} defined on topology {} "
                 msg += "for operator {}.\n"
@@ -742,7 +745,7 @@ class RedistributeInter(RedistributeOperatorBase):
             if not ifield.backend.kind is Backend.HOST:
                 self._need_copy_before = True
                 self._dFin_data = ifield.backend.host_array_backend.empty_like(
-                    ifield
+                    ifield.buffers[0]
                 ).handle
                 self._dFin_data[...] = 0.0
             else:
@@ -751,7 +754,7 @@ class RedistributeInter(RedistributeOperatorBase):
             if not ofield.backend.kind is Backend.HOST:
                 self._need_copy_after = True
                 self._dFout_data = ofield.backend.host_array_backend.empty_like(
-                    ofield
+                    ofield.buffers[0]
                 ).handle
                 self._dFout_data[...] = 0.0
             else:
@@ -789,8 +792,8 @@ class RedistributeInter(RedistributeOperatorBase):
                     _memcpy(
                         self.dFout.sdata,
                         self._dFout_data,
-                        target_indices=indices[rk],
-                        source_indices=indices[rk],
+                        target_indices=indices[self._target_id][rk],
+                        source_indices=indices[self._target_id][rk],
                         skind=Backend.HOST,
                         tkind=Backend.OPENCL,
                     )
diff --git a/hysop/domain/box.py b/hysop/domain/box.py
index daa930989..13acd2330 100644
--- a/hysop/domain/box.py
+++ b/hysop/domain/box.py
@@ -22,7 +22,7 @@
 import warnings
 import numpy as np
 
-from hysop.constants import BoxBoundaryCondition, HYSOP_REAL
+from hysop.constants import BoxBoundaryCondition, HYSOP_REAL, HYSOP_DEFAULT_TASK_ID
 from hysop.domain.domain import Domain, DomainView
 from hysop.tools.decorators import debug
 from hysop.tools.numpywrappers import npw
@@ -287,7 +287,7 @@ class Box(BoxView, Domain):
 
         npw.set_readonly(length, origin, lboundaries, rboundaries)
 
-        topology_state = CartesianTopologyState(dim)
+        topology_state = CartesianTopologyState(dim, HYSOP_DEFAULT_TASK_ID)
 
         obj = super().__new__(
             cls,
diff --git a/hysop/domain/domain.py b/hysop/domain/domain.py
index 5ea6f9459..b088ae323 100644
--- a/hysop/domain/domain.py
+++ b/hysop/domain/domain.py
@@ -143,6 +143,17 @@ class DomainView(TaggedObjectView, metaclass=ABCMeta):
         except IndexError:
             return t
 
+    def current_task_list(self):
+        """Get the task id(s) of the current mpi process.
+        Always return a list or tuple of task ids, even for a single task."""
+        t = self.task_on_proc(self._domain._parent_rank)
+        if isinstance(t, list) or isinstance(t, tuple):
+            return t
+        else:
+            return [
+                t,
+            ]
+
     def get_task_comm(self, task_id=None):
         """
         Return the communicator that owns the current process.
diff --git a/hysop/fields/cartesian_discrete_field.py b/hysop/fields/cartesian_discrete_field.py
index f56992e69..bdec567ac 100644
--- a/hysop/fields/cartesian_discrete_field.py
+++ b/hysop/fields/cartesian_discrete_field.py
@@ -1687,7 +1687,9 @@ class CartesianDiscreteScalarField(
         msg = "Multi-component fields have been deprecated (see DiscreteTensorField)."
         assert field.nb_components == 1, msg
 
-        init_state = init_topology_state or CartesianTopologyState(field.dim)
+        init_state = init_topology_state or CartesianTopologyState(
+            field.dim, topology.mpi_params.task_id
+        )
         obj = super().__new__(
             cls,
             field=field,
diff --git a/hysop/topology/cartesian_topology.py b/hysop/topology/cartesian_topology.py
index 624e95696..0db36b7b2 100644
--- a/hysop/topology/cartesian_topology.py
+++ b/hysop/topology/cartesian_topology.py
@@ -59,14 +59,18 @@ class CartesianTopologyState(TopologyState):
     contained in the linked Cartesian topology.
     """
 
-    __slots__ = ("_is_read_only", "_dim", "_axes", "_memory_order")
+    __slots__ = ("_is_read_only", "_dim", "_axes", "_memory_order", "_task_id")
 
     @debug
-    def __new__(cls, dim, axes=None, memory_order=None, is_read_only=False, **kwds):
+    def __new__(
+        cls, dim, task_id, axes=None, memory_order=None, is_read_only=False, **kwds
+    ):
         return super().__new__(cls, is_read_only=is_read_only, **kwds)
 
     @debug
-    def __init__(self, dim, axes=None, memory_order=None, is_read_only=False, **kwds):
+    def __init__(
+        self, dim, task_id, axes=None, memory_order=None, is_read_only=False, **kwds
+    ):
         """
         Initialize a CartesianState to given parameters.
 
@@ -86,6 +90,7 @@ class CartesianTopologyState(TopologyState):
         """
         super().__init__(is_read_only=is_read_only, **kwds)
         self._dim = int(dim)
+        self._task_id = int(task_id)
         self._set_axes(axes)
         self._set_memory_order(memory_order)
 
@@ -122,6 +127,10 @@ class CartesianTopologyState(TopologyState):
         """Return the dimension of the underlying topology domain."""
         return self._dim
 
+    def _get_task_id(self):
+        """Return the task identifier given at construction (stored as an int)."""
+        return self._task_id
+
     def _get_tstate(self):
         """Return the TranspositionState corresponding to current permutation axes."""
         return TranspositionState.axes_to_tstate(self._axes)
@@ -161,6 +170,7 @@ class CartesianTopologyState(TopologyState):
         axes = first_not_None(axes, self.axes)
         return CartesianTopologyState(
             dim=self.dim,
+            task_id=self.task_id,
             axes=axes,
             memory_order=memory_order,
             is_read_only=is_read_only,
@@ -168,18 +178,20 @@ class CartesianTopologyState(TopologyState):
 
     def short_description(self):
         """Return a short description of this CartesianTopologyState."""
-        s = "{}[order={}, axes=({}), ro={}]"
+        s = "{}[order={}, axes=({}), ro={}, task={}]"
         return s.format(
             self.full_tag,
             self.memory_order,
             ",".join(str(a) for a in self.axes),
             "1" if self.is_read_only else "0",
+            self.task_id,
         )
 
     def long_description(self):
         """Return a long description of this CartesianTopologyState."""
         s = """{}
                *dim:       {}
+               *task_id:   {}
                *order:     {}
                *axes:      ({})
                *tstate:    {}
@@ -188,6 +200,7 @@ class CartesianTopologyState(TopologyState):
         return s.format(
             self.full_tag,
             self.dim,
+            self.task_id,
             self.memory_order,
             ",".join([str(a) for a in self.axes]),
             self.tstate,
@@ -200,15 +213,23 @@ class CartesianTopologyState(TopologyState):
             return NotImplemented
         match = super().match(other, invert=False)
         match &= self._dim == other._dim
+        match &= self._task_id == other._task_id
         match &= self._axes == other._axes
         match &= self._memory_order == other._memory_order
         return not match if invert else match
 
     def __hash__(self):
         h = super().__hash__()
-        return h ^ hash(self._dim) ^ hash(self._axes) ^ hash(self._memory_order)
+        return (
+            h
+            ^ hash(self._dim)
+            ^ hash(self._task_id)
+            ^ hash(self._axes)
+            ^ hash(self._memory_order)
+        )
 
     dim = property(_get_dim)
+    task_id = property(_get_task_id)
     tstate = property(_get_tstate)
     axes = property(_get_axes, _set_axes)
     memory_order = property(_get_memory_order, _set_memory_order)
@@ -431,7 +452,9 @@ class CartesianTopologyView(TopologyView):
 
     def default_state(self):
         """Return the default topology state of this topology."""
-        return CartesianTopologyState(dim=self.domain.dim)
+        return CartesianTopologyState(
+            dim=self.domain.dim, task_id=self.mpi_params.task_id
+        )
 
     def short_description(self):
         """
@@ -722,7 +745,9 @@ class CartesianTopology(CartesianTopologyView, Topology):
 
         npw.set_readonly(proc_shape, is_periodic, is_distributed)
 
-        topology_state = CartesianTopologyState(dim=domain.dim)
+        topology_state = CartesianTopologyState(
+            dim=domain.dim, task_id=mpi_params.task_id
+        )
 
         obj = super().__new__(
             cls,
diff --git a/hysop_examples/examples/tasks/tasks.py b/hysop_examples/examples/tasks/tasks.py
index 35fd14189..c6fa055f3 100755
--- a/hysop_examples/examples/tasks/tasks.py
+++ b/hysop_examples/examples/tasks/tasks.py
@@ -72,7 +72,6 @@ def compute(args):
             mpi_params[tk] = MPIParams(
                 comm=box.get_task_comm(tk), task_id=tk, on_task=box.is_on_task(tk)
             )
-            # print("\n".join([" ** {}: {}".format(a, b) for a, b in mpi_params.items()]))
     else:
         mpi_params = MPIParams(comm=box.task_comm, task_id=HYSOP_DEFAULT_TASK_ID)
     impl = args.impl
@@ -252,7 +251,7 @@ if __name__ == "__main__":
                 container=tuple,
                 append=False,
                 dest="proc_tasks",
-                help="Domain proc_task parameter.",
+                help="Specify the tasks for each proc.",
             )
 
         def _check_main_args(self, args):
-- 
GitLab