From 9e06cef16d51dd07cbe6b555c624d3d0b91ededc Mon Sep 17 00:00:00 2001
From: Jean-Matthieu Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Thu, 19 Sep 2024 16:49:25 +0200
Subject: [PATCH] fixup tasks

---
 hysop/core/graph/graph_builder.py | 161 ++++++++++++++++--------------
 hysop/domain/domain.py            |   2 +-
 2 files changed, 86 insertions(+), 77 deletions(-)

diff --git a/hysop/core/graph/graph_builder.py b/hysop/core/graph/graph_builder.py
index 8675df701..accc7b19f 100644
--- a/hysop/core/graph/graph_builder.py
+++ b/hysop/core/graph/graph_builder.py
@@ -435,8 +435,7 @@ class GraphBuilder:
                 ]
             )
             comm = domain.parent_comm
-            tcomm = domain.task_comm
-            current_task = domain.current_task()
+            current_tasks = domain.current_task_list()
 
             def _name_to_key(n, d):
                 var = [_ for _ in d.keys() if isinstance(_, str) and _ == n]
@@ -454,94 +453,104 @@ class GraphBuilder:
             } - self._intertasks_exchanged
             mgs = "  >[IT] Current task ({}) {} parameters and fields : {}"
             gprint(
-                mgs.format(current_task, "can communicate", ", ".join(available_names))
+                mgs.format(current_tasks, "can communicate", ", ".join(available_names))
             )
-            gprint(mgs.format(current_task, "needs", ", ".join(needed_names)))
+            gprint(mgs.format(current_tasks, "needs", ", ".join(needed_names)))
 
             # Inter-task matching is performed on root process
             available_names = {_: None for _ in available_names}  # value is dest task
             needed_names = {_: None for _ in needed_names}  # value is src task
-            if domain.task_rank() == 0:
-                msg = ""
-                # loop over other tasks
-                for ot in (_ for _ in domain.all_tasks if _ != current_task):
-                    if domain.task_root_in_parent(ot) == domain.parent_rank:
-                        ot_needs = []
-                        for _n in needed_names.keys():
-                            _ntopo = needed_elems[_name_to_key(_n, needed_elems)]
-                            if (
-                                hasattr(_ntopo, "task_id")
-                                and _ntopo.task_id == current_task
-                            ):
-                                continue
-                            else:
-                                ot_needs.append(_n)
-                        ot_provide = can_provide = [
-                            _ for _ in ot_needs if _ in available_names
-                        ]
-                        for prov in can_provide:
-                            available_names[prov] = needed_elems[
-                                _name_to_key(prov, needed_elems)
-                            ].task_id
-                        for _op in ot_provide:
-                            needed_names[_op] = available_elems[
-                                _name_to_key(_op, available_elems)
-                            ].task_id
-                    else:
-                        comm.isend(
-                            list(needed_names.keys()),
-                            dest=domain.task_root_in_parent(ot),
-                            tag=4321,
-                        )
-                        ot_needs = comm.recv(
-                            source=domain.task_root_in_parent(ot), tag=4321
-                        )
-                        can_provide = [_ for _ in ot_needs if _ in available_names]
-                        for prov in can_provide:
-                            available_names[prov] = ot
-                            assert (
-                                ot
-                                != available_elems[
-                                    _name_to_key(prov, available_elems)
+
+            for current_task in current_tasks:
+                if domain.task_root_in_parent(current_task) == domain.parent_rank:
+                    msg = ""
+                    # loop over other tasks
+                    for ot in (_ for _ in domain.all_tasks if _ != current_task):
+                        if domain.task_root_in_parent(ot) == domain.parent_rank:
+                            ot_needs = []
+                            for _n in needed_names.keys():
+                                _ntopo = needed_elems[_name_to_key(_n, needed_elems)]
+                                if (
+                                    hasattr(_ntopo, "task_id")
+                                    and _ntopo.task_id == current_task
+                                ):
+                                    continue
+                                else:
+                                    ot_needs.append(_n)
+                            ot_provide = can_provide = [
+                                _ for _ in ot_needs if _ in available_names
+                            ]
+                            for prov in can_provide:
+                                available_names[prov] = needed_elems[
+                                    _name_to_key(prov, needed_elems)
+                                ].task_id
+                            for _op in ot_provide:
+                                needed_names[_op] = available_elems[
+                                    _name_to_key(_op, available_elems)
                                 ].task_id
+                        else:
+                            comm.isend(
+                                list(needed_names.keys()),
+                                dest=domain.task_root_in_parent(ot),
+                                tag=4321,
                             )
-                        comm.isend(
-                            can_provide, dest=domain.task_root_in_parent(ot), tag=1234
-                        )
-                        ot_provide = comm.recv(
-                            source=domain.task_root_in_parent(ot), tag=1234
-                        )
-                        for _op in ot_provide:
-                            needed_names[_op] = ot
-                            assert (
-                                ot
-                                != needed_elems[_name_to_key(_op, needed_elems)].task_id
+                            ot_needs = comm.recv(
+                                source=domain.task_root_in_parent(ot), tag=4321
                             )
-                    if len(ot_needs) > 0:
-                        msg += "\n   *Other task {} needs init for {}, we provide {}".format(
-                            ot,
-                            ot_needs,
-                            "nothing" if len(can_provide) == 0 else can_provide,
-                        )
-                if msg != "":
-                    gprint("  >[IT] Inter-tasks matching:" + msg)
+                            can_provide = [_ for _ in ot_needs if _ in available_names]
+                            for prov in can_provide:
+                                available_names[prov] = ot
+                                assert (
+                                    ot
+                                    != available_elems[
+                                        _name_to_key(prov, available_elems)
+                                    ].task_id
+                                )
+                            comm.isend(
+                                can_provide,
+                                dest=domain.task_root_in_parent(ot),
+                                tag=1234,
+                            )
+                            ot_provide = comm.recv(
+                                source=domain.task_root_in_parent(ot), tag=1234
+                            )
+                            for _op in ot_provide:
+                                needed_names[_op] = ot
+                                assert (
+                                    ot
+                                    != needed_elems[
+                                        _name_to_key(_op, needed_elems)
+                                    ].task_id
+                                )
+                        if len(ot_needs) > 0:
+                            msg += "\n   *Other task {} needs init for {}, we provide {}".format(
+                                ot,
+                                ot_needs,
+                                "nothing" if len(can_provide) == 0 else can_provide,
+                            )
+                    if msg != "":
+                        gprint("  >[IT] Inter-tasks matching:" + msg)
             needed_names = {p: t for (p, t) in needed_names.items() if t is not None}
             available_names = {
                 p: t for (p, t) in available_names.items() if t is not None
             }
-            needed_names = tcomm.bcast(needed_names, root=0)
-            available_names = tcomm.bcast(available_names, root=0)
-            if domain.task_rank() != 0:
-                needed_names = dict(
-                    (k, v)
-                    for k, v in needed_names.items()
-                    if not v in domain.current_task_list()
+
+            for current_task in current_tasks:
+                tcomm = domain.get_task_comm(current_task)
+                needed_names = tcomm.bcast(needed_names, root=0)
+                available_names = tcomm.bcast(available_names, root=0)
+
+            final_needed_names, final_available_names = {}, {}
+            for current_task in current_tasks:
+                _tmp_needed = dict(
+                    (k, v) for k, v in needed_names.items() if v != current_task
                 )
-                available_names = dict(
-                    (k, v)
-                    for k, v in available_names.items()
-                    if not v in domain.current_task_list()
+                _tmp_avail = dict(
+                    (k, v) for k, v in available_names.items() if v != current_task
                 )
+                final_needed_names.update(_tmp_needed)
+                final_available_names.update(_tmp_avail)
+            needed_names, available_names = final_needed_names, final_available_names
             gprint(
                 f"  >[IT] Inter-tasks will send:to {available_names} and recieve:from {needed_names}"
             )
diff --git a/hysop/domain/domain.py b/hysop/domain/domain.py
index b088ae323..808d3d865 100644
--- a/hysop/domain/domain.py
+++ b/hysop/domain/domain.py
@@ -147,7 +147,7 @@ class DomainView(TaggedObjectView, metaclass=ABCMeta):
         """Get task number of the current mpi process.
         Return always a tuple ot taks id"""
         t = self.task_on_proc(self._domain._parent_rank)
-        if isinstance(t, list) or isinstance(t, tuple):
+        if isinstance(t, list) or isinstance(t, tuple) or isinstance(t, np.ndarray):
             return t
         else:
             return [
-- 
GitLab