From 39048ab56acd3aeccf38f6032c07a753f6c1dfa3 Mon Sep 17 00:00:00 2001
From: Jean-Matthieu Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Fri, 19 Jun 2020 11:09:39 +0200
Subject: [PATCH] Automatic disjoint tasks redistribute inserts.

---
 hysop/core/graph/computational_graph.py       |   6 +-
 .../core/graph/computational_node_frontend.py |   4 +-
 hysop/core/graph/graph_builder.py             | 726 +++++++++++-------
 hysop/core/graph/node_generator.py            |  33 +
 hysop/core/mpi/redistribute.py                |  16 +-
 5 files changed, 481 insertions(+), 304 deletions(-)

diff --git a/hysop/core/graph/computational_graph.py b/hysop/core/graph/computational_graph.py
index 759118f99..01b6e53e1 100644
--- a/hysop/core/graph/computational_graph.py
+++ b/hysop/core/graph/computational_graph.py
@@ -10,7 +10,7 @@ from hysop.core.graph.graph import not_implemented, initialized, discretized, \
 from hysop.core.graph.graph import ComputationalGraphNodeData
 from hysop.core.graph.computational_node import ComputationalGraphNode
 from hysop.core.graph.computational_operator import ComputationalGraphOperator
-from hysop.core.graph.node_generator import ComputationalGraphNodeGenerator
+from hysop.core.graph.node_generator import ComputationalGraphNodeGenerator, HiddenOperator
 from hysop.core.graph.node_requirements import NodeRequirements, OperatorRequirements
 from hysop.core.memory.memory_request import MultipleOperatorMemoryRequests
 from hysop.fields.field_requirements import MultiFieldRequirements
@@ -622,6 +622,10 @@ class ComputationalGraph(ComputationalGraphNode):
                 nodes = node.generate()
                 assert (nodes is not None), node
                 self.push_nodes(*nodes)
+            elif isinstance(node, HiddenOperator):
+                # Fuse al HiddenOperators
+                if len(self.nodes) == 0 or not isinstance(self.nodes[-1], HiddenOperator):
+                    self.nodes.append(node)
             else:
                 msg = 'Given node is not an instance of ComputationalGraphNode (got a {}).'
                 raise ValueError(msg.format(node.__class__))
diff --git a/hysop/core/graph/computational_node_frontend.py b/hysop/core/graph/computational_node_frontend.py
index cfa0fea67..55bc7d4f0 100644
--- a/hysop/core/graph/computational_node_frontend.py
+++ b/hysop/core/graph/computational_node_frontend.py
@@ -1,7 +1,7 @@
 from hysop.constants import Implementation, Backend, implementation_to_backend
 from hysop.tools.decorators import debug
 from hysop.tools.types import check_instance, first_not_None
-from hysop.core.graph.node_generator import ComputationalGraphNodeGenerator
+from hysop.core.graph.node_generator import ComputationalGraphNodeGenerator, HiddenOperator
 from hysop.core.graph.graph import not_implemented
 
 from hysop.fields.continuous_field import Field
@@ -102,7 +102,7 @@ class ComputationalGraphNodeFrontend(ComputationalGraphNodeGenerator):
             # Skip not on-task operators very early
             if 'mpi_params' in self.impl_kwds.keys():
                 if not self.impl_kwds['mpi_params'].on_task:
-                    return tuple()
+                    return (HiddenOperator(mpi_params=self.impl_kwds['mpi_params']), )
             op = self.impl(**self.impl_kwds)
         except:
             sargs = ['*{} = {}'.format(k, v.__class__)
diff --git a/hysop/core/graph/graph_builder.py b/hysop/core/graph/graph_builder.py
index cd797aed5..baa9113d2 100644
--- a/hysop/core/graph/graph_builder.py
+++ b/hysop/core/graph/graph_builder.py
@@ -1,6 +1,6 @@
 from hysop import vprint, dprint, Problem
 from hysop.deps import np, __builtin__, print_function
-from hysop.tools.types import check_instance
+from hysop.tools.types import check_instance, first_not_None
 from hysop.tools.io_utils import IOParams
 
 from hysop.tools.transposition_states import TranspositionState
@@ -15,6 +15,7 @@ from hysop.core.graph.graph import (new_directed_graph, new_vertex, new_edge,
 from hysop.core.graph.computational_graph import ComputationalGraph
 from hysop.core.graph.computational_node import ComputationalGraphNode
 from hysop.core.graph.computational_operator import ComputationalGraphOperator
+from hysop.core.graph.node_generator import HiddenOperator
 
 from hysop.fields.field_requirements import (DiscreteFieldRequirements,
                                              MultiFieldRequirements)
@@ -151,7 +152,175 @@ class GraphBuilder(object):
             msg = msg.format(target_node.name)
             raise RuntimeError(msg)
 
-        def __handle_node(node_id, node):
+        def __handle_node(node_id, node, subgraph, node_ops, node_vertices, from_subgraph, opvertex, op, opnode):
+            gprint('  *{} ({})'.format(op.name, type(op)))
+            opname = op.name
+            oppname = op.pretty_name
+            iparams = op.input_params
+            oparams = op.output_params
+            ifields = op.input_fields
+            ofields = op.output_fields
+            field_requirements = op._field_requirements
+            if field_requirements is None:
+                op.get_and_set_field_requirements()
+                field_requirements = op._field_requirements
+
+            if not isinstance(op, Problem) and not isinstance(op, RedistributeInter):
+                # try to fill in undertermined topologies (experimental feature)
+                backends = op.supported_backends()
+                for (ifield, itopo) in sorted(ifields.iteritems(),
+                                              key=lambda x: x[0].name):
+                    if (itopo is not None):
+                        continue
+                    # look for ifield usage untill now
+                    if ((ifield in ofields) and (ofields[ifield] is not None)
+                            and (ofields[ifield].backend.kind in backends)):
+                        ifields[ifield] = ofields[ifield]
+                    elif (ifield not in self.topology_states):
+                        if outputs_are_inputs:
+                            # we can try to push this operator after we're done
+                            deferred_operators.append((op, opnode))
+                        else:
+                            msg = ('\nGraphBuilder {} could not automatically '
+                                   'determine the topology of input field {} in '
+                                   'operator {}.\nTry to set a non empty '
+                                   'TopologyDescriptor when passing the variable '
+                                   'parameters, when creating the operator.'
+                                   '\nAutomatic topology detection is an '
+                                   'experimental feature.')
+                            msg = msg.format(target_node.name, ifield.name, op.name)
+                            raise RuntimeError(msg)
+                    else:
+                        cstate = self.topology_states[ifield]
+                        (itopo, dstate, node, ireqs) = cstate.first_topology_and_dstate
+                        field_requirements.update_inputs({ifield: ireqs})
+                        if (itopo.backend.kind not in backends):
+                            backend = itopo.backend.any_backend_from_kind(*backends)
+                            itopo = itopo.topology_like(backend=backend)
+                        ifields[ifield] = itopo
+                for (ofield, otopo) in sorted(ofields.iteritems(),
+                                              key=lambda x: x[0].name):
+                    if (otopo is not None):
+                        continue
+                    if (ofield in ifields) and (ifields[ofield] is not None):
+                        ofields[ofield] = ifields[ofield]
+                    elif (ofield not in self.topology_states):
+                        msg = ('\nGraphBuilder {} could not automatically determine '
+                               'the topology of input field {} in operator {}.'
+                               '\nTry to set a non empty TopologyDescriptor when '
+                               'passing the variable parameters, when creating the '
+                               'operator.\nAutomatic topology detection is an '
+                               'experimental feature.')
+                        msg = msg.format(target_node.name, ofield.name, op.name)
+                        raise RuntimeError(msg)
+                    else:
+                        cstate = self.topology_states[ofield]
+                        (otopo, dstate, node, oreqs) = cstate.first_topology_and_dstate
+                        field_requirements.update_outputs({ofield: oreqs})
+                        ofields[ofield] = otopo
+
+            # iterate over subgraph operator input parameters
+            if iparams:
+                gprint('   >Input parameters')
+                for iparam in sorted(iparams.values(), key=lambda x: x.name):
+                    gprint('     *{}'.format(iparam.short_description()))
+                    parameter_handler.handle_input_parameter(iparam, opnode)
+                    if (iparam.name not in output_params):
+                        input_params[iparam.name] = iparam
+
+            # iterate over subgraph operator output parameters
+            if oparams:
+                gprint('   >Output parameters')
+                for oparam in sorted(oparams.values(), key=lambda x: x.name):
+                    gprint('     *{}'.format(oparam.short_description()))
+                    parameter_handler.handle_output_parameter(oparam, opnode)
+                    output_params[oparam.name] = oparam
+
+            # iterate over subgraph operator input fields
+            input_states = {}
+            if ifields:
+                gprint('   >Input fields')
+                for (ifield, itopo) in sorted(ifields.iteritems(),
+                                              key=lambda x: x[0].name, reverse=True):
+                    gprint('     *{}{}'.format(ifield.name, ' on an unknown topology'
+                                               if (itopo is None) else '.{}'.format(itopo.pretty_tag)))
+                    if (itopo is None):
+                        assert isinstance(op, RedistributeInter)
+                        continue
+                    if isinstance(op, Problem):
+                        if ifield in op.initial_input_topology_states.keys():
+                            ifreqs = op.initial_input_topology_states[ifield][0]
+                        else:
+                            ifreqs = None
+                    else:
+                        if (current_level != 0 or isinstance(op, Problem)):
+                            ifreqs = None
+                        else:
+                            ifreqs = \
+                                field_requirements.get_input_requirement(ifield)[1]
+                    if (ifield not in self.topology_states):
+                        cstate = self.new_topology_state(ifield)
+                        self.topology_states[ifield] = cstate
+                        is_new = True
+                    else:
+                        cstate = self.topology_states[ifield]
+                        is_new = False
+                    dstate = cstate.handle_input(opnode, itopo, ifreqs,
+                                                 graph, is_new)
+                    input_states[ifield] = dstate
+                    if is_new:
+                        input_fields[ifield] = itopo
+                        input_topology_states[ifield] = (ifreqs, dstate)
+
+                    if ifield not in double_check_inputs:
+                        double_check_inputs[ifield] = {}
+                    double_check_inputs[ifield].update({itopo: (ifreqs, dstate)})
+
+            # iterate over subgraph operator output fields
+            output_states = {}
+            if ofields:
+                gprint('   >Output fields')
+                for (ofield, otopo) in sorted(ofields.iteritems(),
+                                              key=lambda x: x[0].name, reverse=True):
+                    gprint('     *{}{}'.format(ofield.name, ' on an unknown topology'
+                                               if (otopo is None) else '.{}'.format(otopo.pretty_tag)))
+                    if (otopo is None):
+                        assert isinstance(op, RedistributeInter)
+                        continue
+                    if isinstance(op, Problem):
+                        if ofield in op.final_output_topology_states.keys():
+                            ofreqs = op.final_output_topology_states[ofield][0]
+                        else:
+                            ofreqs = None
+                    else:
+                        ofreqs = None if (current_level != 0) \
+                            else field_requirements.get_output_requirement(ofield)[1]
+                    istates = None if (current_level != 0) else input_states
+                    cstate = self.topology_states.setdefault(ofield,
+                                                             self.new_topology_state(ofield))
+                    invalidate_field = (ofield not in
+                                        op.get_preserved_input_fields())
+                    dstate = cstate.handle_output(opnode, otopo, ofreqs,
+                                                  op, istates, invalidate_field, graph,
+                                                  node_list=target_node.nodes)
+                    output_fields[ofield] = otopo
+                    output_states[ofield] = dstate
+                    output_topology_states[ofield] = (None, dstate)
+
+            if (current_level == 0) and ((op, opnode) not in deferred_operators):
+                opnode.set_op_info(op, input_states, output_states)
+
+            op_input_topology_states[op] = input_states
+            op_output_topology_states[op] = output_states
+
+        for (node_id, node) in enumerate(target_node.nodes):
+            if isinstance(node, HiddenOperator):
+                target_node.nodes[node_id] = RedistributeInter()
+
+        redistribute_inter = []
+        intertasks_exchanged = set()
+        # iterate over ComputationalNodes
+        for (node_id, node) in enumerate(target_node.nodes):
             gprint(' >Handling node {}: {} {}'.format(
                 node_id, node.name, node.__class__))
 
@@ -162,175 +331,230 @@ class GraphBuilder(object):
             # current node operator.
             subgraph, node_ops, node_vertices, from_subgraph = \
                 self.build_subgraph(node, current_level)
-
             # iterate over subgraph operators
             for (opvertex, op) in zip(node_vertices, node_ops):
-                gprint('  *{} ({})'.format(op.name, type(op)))
-                opname = op.name
-                oppname = op.pretty_name
-                iparams = op.input_params
-                oparams = op.output_params
-                ifields = op.input_fields
-                ofields = op.output_fields
-                field_requirements = op._field_requirements
-                if field_requirements is None:
-                    op.get_and_set_field_requirements()
-                    field_requirements = op._field_requirements
-
                 # add operator node and fill vertex properties
-                opnode = self.new_node(op, subgraph,
-                                       current_level, node, node_id, opvertex)
-
-                if not isinstance(op, Problem) and not isinstance(op, RedistributeInter):
-                    # try to fill in undertermined topologies (experimental feature)
-                    backends = op.supported_backends()
-                    for (ifield, itopo) in sorted(ifields.iteritems(),
-                                                  key=lambda x: x[0].name):
-                        if (itopo is not None):
-                            continue
-                        # look for ifield usage untill now
-                        if ((ifield in ofields) and (ofields[ifield] is not None)
-                                and (ofields[ifield].backend.kind in backends)):
-                            ifields[ifield] = ofields[ifield]
-                        elif (ifield not in self.topology_states):
-                            if outputs_are_inputs:
-                                # we can try to push this operator after we're done
-                                deferred_operators.append((op, opnode))
-                            else:
-                                msg = ('\nGraphBuilder {} could not automatically '
-                                       'determine the topology of input field {} in '
-                                       'operator {}.\nTry to set a non empty '
-                                       'TopologyDescriptor when passing the variable '
-                                       'parameters, when creating the operator.'
-                                       '\nAutomatic topology detection is an '
-                                       'experimental feature.')
-                                msg = msg.format(target_node.name, ifield.name, op.name)
-                                raise RuntimeError(msg)
-                        else:
-                            cstate = self.topology_states[ifield]
-                            (itopo, dstate, node, ireqs) = cstate.first_topology_and_dstate
-                            field_requirements.update_inputs({ifield: ireqs})
-                            if (itopo.backend.kind not in backends):
-                                backend = itopo.backend.any_backend_from_kind(*backends)
-                                itopo = itopo.topology_like(backend=backend)
-                            ifields[ifield] = itopo
-                    for (ofield, otopo) in sorted(ofields.iteritems(),
-                                                  key=lambda x: x[0].name):
-                        if (otopo is not None):
-                            continue
-                        if (ofield in ifields) and (ifields[ofield] is not None):
-                            ofields[ofield] = ifields[ofield]
-                        elif (ofield not in self.topology_states):
-                            msg = ('\nGraphBuilder {} could not automatically determine '
-                                   'the topology of input field {} in operator {}.'
-                                   '\nTry to set a non empty TopologyDescriptor when '
-                                   'passing the variable parameters, when creating the '
-                                   'operator.\nAutomatic topology detection is an '
-                                   'experimental feature.')
-                            msg = msg.format(target_node.name, ofield.name, op.name)
-                            raise RuntimeError(msg)
-                        else:
-                            cstate = self.topology_states[ofield]
-                            (otopo, dstate, node, oreqs) = cstate.first_topology_and_dstate
-                            field_requirements.update_outputs({ofield: oreqs})
-                            ofields[ofield] = otopo
-
-                # iterate over subgraph operator input parameters
-                if iparams:
-                    gprint('   >Input parameters')
-                    for iparam in sorted(iparams.values(), key=lambda x: x.name):
-                        gprint('     *{}'.format(iparam.short_description()))
-                        parameter_handler.handle_input_parameter(iparam, opnode)
-                        if (iparam.name not in output_params):
-                            input_params[iparam.name] = iparam
-
-                # iterate over subgraph operator output parameters
-                if oparams:
-                    gprint('   >Output parameters')
-                    for oparam in sorted(oparams.values(), key=lambda x: x.name):
-                        gprint('     *{}'.format(oparam.short_description()))
-                        parameter_handler.handle_output_parameter(oparam, opnode)
-                        output_params[oparam.name] = oparam
-
-                # iterate over subgraph operator input fields
-                input_states = {}
-                if ifields:
-                    gprint('   >Input fields')
-                    for (ifield, itopo) in sorted(ifields.iteritems(),
-                                                  key=lambda x: x[0].name, reverse=True):
-                        gprint('     *{}{}'.format(ifield.name, ' on an unknown topology'
-                                                   if (itopo is None) else '.{}'.format(itopo.pretty_tag)))
-                        if (itopo is None):
-                            assert isinstance(op, RedistributeInter)
-                            continue
-                        if isinstance(op, Problem):
-                            if ifield in op.initial_input_topology_states.keys():
-                                ifreqs = op.initial_input_topology_states[ifield][0]
-                            else:
-                                ifreqs = None
-                        else:
-                            if (current_level != 0 or isinstance(op, Problem)):
-                                ifreqs = None
-                            else:
-                                ifreqs = \
-                                    field_requirements.get_input_requirement(ifield)[1]
-                        if (ifield not in self.topology_states):
-                            cstate = self.new_topology_state(ifield)
-                            self.topology_states[ifield] = cstate
-                            is_new = True
-                        else:
-                            cstate = self.topology_states[ifield]
-                            is_new = False
-                        dstate = cstate.handle_input(opnode, itopo, ifreqs,
-                                                     graph, is_new)
-                        input_states[ifield] = dstate
-                        if is_new:
-                            input_fields[ifield] = itopo
-                            input_topology_states[ifield] = (ifreqs, dstate)
-
-                        if ifield not in double_check_inputs:
-                            double_check_inputs[ifield] = {}
-                        double_check_inputs[ifield].update({itopo: (ifreqs, dstate)})
-
-                # iterate over subgraph operator output fields
-                output_states = {}
-                if ofields:
-                    gprint('   >Output fields')
-                    for (ofield, otopo) in sorted(ofields.iteritems(),
-                                                  key=lambda x: x[0].name, reverse=True):
-                        gprint('     *{}.{}'.format(ofield.name, ' on an unknown topology'
-                                                    if (otopo is None) else '.{}'.format(otopo.pretty_tag)))
-                        if (otopo is None):
-                            assert isinstance(op, RedistributeInter)
-                            continue
-                        if isinstance(op, Problem):
-                            if ofield in op.final_output_topology_states.keys():
-                                ofreqs = op.final_output_topology_states[ofield][0]
-                            else:
-                                ofreqs = None
-                        else:
-                            ofreqs = None if (current_level != 0) \
-                                else field_requirements.get_output_requirement(ofield)[1]
-                        istates = None if (current_level != 0) else input_states
-                        cstate = self.topology_states.setdefault(ofield,
-                                                                 self.new_topology_state(ofield))
-                        invalidate_field = (ofield not in
-                                            op.get_preserved_input_fields())
-                        dstate = cstate.handle_output(opnode, otopo, ofreqs,
-                                                      op, istates, invalidate_field, graph)
-                        output_fields[ofield] = otopo
-                        output_states[ofield] = dstate
-                        output_topology_states[ofield] = (None, dstate)
-
-                if (current_level == 0) and ((op, opnode) not in deferred_operators):
-                    opnode.set_op_info(op, input_states, output_states)
-
-                op_input_topology_states[op] = input_states
-                op_output_topology_states[op] = output_states
+                opnode = self.new_node(op, subgraph, current_level, node, node_id, opvertex)
+                if isinstance(node, RedistributeInter):
+                    # Save graph building state for filling next nodes topologies after graph completion
+                    redistribute_inter.append((node_id, node, subgraph, node_ops, node_vertices, from_subgraph, opvertex, op, opnode))
+                    gprint("  *Will be handled later")
+                else:
+                    __handle_node(node_id, node, subgraph, node_ops, node_vertices, from_subgraph, opvertex, op, opnode)
+
+        # def __handle_intertask(input_elems, output_elems):
+            # If tasks are used, we also search for and insert  needed RedistributeInter operators.
+            # InterComm redistribute (for field or parameter) is inserted if an input
+            # of a task is found in other task output.
+            # Note: RedistributeInter are added at the end of graph (after fields being used)
+            # search_it_ops = self.search_intertasks_ops
+            #     search_it_ops = search_it_ops and (len(input_elems)+len(output_elems) > 0)
+            # # do        main = None
+            # if search_it_ops and input_fields:
+            #     domain = [_.domain for _ in input_fields.values()][0]
+            #     search_it_ops = search_it_ops and domain.has_tasks
+            # if search_it_ops and output_fields:
+            #     domain = [_.domain for _ in output_fields.values()][0]
+            #     search_it_ops = search_it_ops and domain.has_tasks
+            # if search_it_ops:
+            #     from hysop.operators import InterTaskParamComm
+            #     comm = domain.parent_comm
+            #     tcomm = domain.task_comm()
+            #     current_task = domain.current_task()
+            #     gprint('>[IT] Handling for inter-tasks redistributes (current task={}).'.format(current_task))
+
+            #     # Uninitialized fields or parameters
+            #     inputs_names = set([_.name for _ in input_elems.keys()])
+            #     outputs_names = set([_.name for _ in output_elems.keys()])
+            #     gprint(" >[IT] Output parameters and fields : " + ", ".join(outputs_names))
+            #     gprint(" >[IT] Input parameters and fields : " + ", ".join(inputs_names))
+
+                #     # Inter-task matching is performed on root process
+                #     to_provide = {}
+                #     inputs_names = dict((_, None) for _ in inputs_names)
+                #     if domain.task_rank() == 0:
+                #         msg = ''
+                #         # loop over other tasks
+                #         for ot in (_ for _ in domain.all_tasks if _ != current_task):
+                #             comm.isend(inputs_names.keys(),
+                #                        dest=domain.task_root_in_parent(ot), tag=ot)
+                #             ot_needs = comm.recv(source=domain.task_root_in_parent(ot), tag=current_task)
+                #             can_provide = [_ for _ in ot_needs if _ in outputs_names]
+                #             for prov in can_provide:
+                #                 to_provide[prov] = ot
+                #             if len(ot_needs) > 0:
+                #                 msg += "\n  *Other task {} needs init for {}, we provide {}".format(
+                #                     ot, ot_needs, "nothing" if len(can_provide) == 0 else can_provide)
+                #             comm.isend(can_provide, dest=domain.task_root_in_parent(ot), tag=1234)
+                #             ot_provide = comm.recv(source=domain.task_root_in_parent(ot), tag=1234)
+                #             for op in ot_provide:
+                #                 inputs_names[op] = ot
+                #         if msg != '':
+                #             gprint(" >[IT] Inter-tasks matching:" + msg)
+                #     provided = {p: t for (p, t) in inputs_names.items() if t is not None}
+                #     provided = tcomm.bcast(provided, root=0)
+                #     to_provide = tcomm.bcast(to_provide, root=0)
+                #     gprint(" >[IT] Inter-tasks providing {} provided {}".format(to_provide, provided))
+
+                #     def __add_param_redistribute(l, p, src, dest):
+                #         gprint("  >[IT] Added InterTaskParamComm for parameter {} {}".format(
+                #             p.name, "provided for task "+str(dest) if src == current_task
+                #             else " initialized from task "+str(src)))
+                #         l.append(InterTaskParamComm(parameter=(p, ), domain=domain,
+                #                                     source_task=src, dest_task=dest,
+                #                                     mpi_params=target_node.mpi_params))
+
+                #     def __add_field_redistribute(l, f, topo, src, dest):
+                #         msg = "  >[IT] Added RedistributeInter for field {} ".format(f.name)
+                #         if src == current_task:
+                #             msg += "provided for topology {} on task {}".format(topo.full_pretty_tag, t)
+                #             comm_dir = "src"
+                #         elif dest == current_task:
+                #             msg += "initialized from topology {} on task {}".format(topo.full_pretty_tag, t)
+                #             comm_dir = "dst"
+                #         else:
+                #             raise RuntimeError()
+                #         gprint(msg)
+                #         l.append(Redistribute(variables=f, other_task_id=t, mpi_params=topo.mpi_params,
+                #                               name='RI{}_{}_{}'.format(comm_dir, topo.id, f.name),
+                #                               pretty_name=u'RI{}_{}{}{}_{}'.format(
+                #                                   comm_dir,
+                #                                   subscript(topo.id) if src == current_task else '',
+                #                                   u'\u2192',
+                #                                   subscript(topo.id) if dest == current_task else '',
+                #                                   f.pretty_name.decode('utf-8')),
+                #                               source_topos=topo if src == current_task else None,
+                #                               target_topo=topo if dest == current_task else None))
+
+                #     # Build new Inter-tasks operators
+                #     new_ops = []
+                #     for p, t in sorted(to_provide.items() + provided.items(), key=lambda _: _[0]):
+                #         of_key = [_ for _ in self.output_fields.keys() if _.name == p]
+                #         if_key = [_ for _ in self.input_fields.keys() if _.name == p]
+                #         if p in provided:
+                #             if p in input_params.keys():
+                #                 __add_param_redistribute(new_ops, input_params[p], current_task, t)
+                #             else:
+                #                 __add_field_redistribute(
+                #                     new_ops, if_key[0], self.input_fields[if_key[0]], current_task, t)
+                #         elif p in to_provide:
+                #             if p in output_params.keys():
+                #                 __add_param_redistribute(new_ops, output_params[p], t, current_task)
+                #             else:
+                #                 __add_field_redistribute(
+                #                     new_ops, of_key[0], self.output_fields[of_key[0]], t, current_task)
+                #         else:
+                #             raise RuntimeError("Field {} not found either in inputs {} or outputs {}".format(
+                #                 p, [_.name for _ in self.input_fields.keys()], [_.name for _ in self.output_fields.keys()]))
+
+                #     # Handle new nodes
+                #     shift = len(target_node.nodes)
+                #     target_node.push_nodes(*tuple(new_ops))
+                #     for (node_id, node) in enumerate(_ for __ in new_ops for _ in __.nodes):
+                #         node.initialize()
+                #         __handle_node(shift + node_id, node)
+        def __find_elements_to_redistribute(available_elems, needed_elems, intertasks_exchanged):
+            # The algorithm is to extract level0 input fields and topologies as needs
+            # and meet with output fields and topologies as provided. The key feature is that
+            # these informations are distributed across distinct tasks (sub-communicators).
+            # Same algorithm is also used for parameters.
+            domain = first_not_None([_.domain for _ in available_elems.values() + needed_elems.values()])
+            comm = domain.parent_comm
+            tcomm = domain.task_comm()
+            current_task = domain.current_task()
+            gprint('>[IT] Handling for inter-tasks redistributes (current task={}).'.format(current_task))
 
-        # iterate over ComputationalNodes
-        for (node_id, node) in enumerate(target_node.nodes):
-            __handle_node(node_id, node)
+            # Find redistribute candidates
+            available_names = set([_.name for _ in available_elems.keys()]) - intertasks_exchanged
+            needed_names = set([_.name for _ in needed_elems.keys()]) - intertasks_exchanged
+            mgs = " >[IT] Current task {} parameters and fields : {}"
+            gprint(mgs.format("can communicate", ", ".join(available_names)))
+            gprint(mgs.format("needs", ", ".join(needed_names)))
+
+            # Inter-task matching is performed on root process
+            available_names = dict((_, None) for _ in available_names)  # value is dest task
+            needed_names = dict((_, None) for _ in needed_names)  # value is src task
+            if domain.task_rank() == 0:
+                msg = ''
+                # loop over other tasks
+                for ot in (_ for _ in domain.all_tasks if _ != current_task):
+                    comm.isend(needed_names.keys(),
+                               dest=domain.task_root_in_parent(ot), tag=ot)
+                    ot_needs = comm.recv(source=domain.task_root_in_parent(ot), tag=current_task)
+                    can_provide = [_ for _ in ot_needs if _ in available_names]
+                    for prov in can_provide:
+                        available_names[prov] = ot
+                    if len(ot_needs) > 0:
+                        msg += "\n  *Other task {} needs init for {}, we provide {}".format(
+                            ot, ot_needs, "nothing" if len(can_provide) == 0 else can_provide)
+                    comm.isend(can_provide, dest=domain.task_root_in_parent(ot), tag=1234)
+                    ot_provide = comm.recv(source=domain.task_root_in_parent(ot), tag=1234)
+                    for _op in ot_provide:
+                        needed_names[_op] = ot
+                if msg != '':
+                    gprint(" >[IT] Inter-tasks matching:" + msg)
+            needed_names = {p: t for (p, t) in needed_names.items() if t is not None}
+            available_names = {p: t for (p, t) in available_names.items() if t is not None}
+            needed_names = tcomm.bcast(needed_names, root=0)
+            available_names = tcomm.bcast(available_names, root=0)
+            gprint(" >[IT] Inter-tasks will send {} and recieve {}".format(available_names, needed_names))
+            intertasks_exchanged = intertasks_exchanged.union(set(available_names.keys() + needed_names.keys()))
+            assert len(available_names.items() + needed_names.items()) == 1, \
+                "Redistributes work only for single variables for the moment"
+            # Get back the actual field or parameter
+            kwargs = {}
+            for p, t in sorted(available_names.items() + needed_names.items(), key=lambda _: _[0]):
+                kwargs.update({'other_task_id': t})
+                s_topo, r_topo, comm_dir = (None, )*3
+                if p in available_names.keys():
+                    var = [_ for _ in available_elems.keys() if _.name == p][0]
+                    topo = available_elems[var]
+                    comm_dir = 'src'
+                    s_topo = topo
+                if p in needed_names.keys():
+                    var = [_ for _ in needed_elems.keys() if _.name == p][0]
+                    topo = needed_elems[var]
+                    comm_dir = 'dest'
+                    r_topo = topo
+                assert not comm_dir is None
+                # Finalize init call
+                kwargs.update({'variable': var,
+                               'mpi_params': topo.mpi_params,
+                               'name': 'RI{}_{}_{}'.format(comm_dir, topo.id, var.name),
+                               'pretty_name': u'RI{}_{}{}{}_{}'.format(
+                                   comm_dir,
+                                   '' if s_topo is None else subscript(s_topo.id),
+                                   u'\u2192',
+                                   '' if r_topo is None else subscript(r_topo.id),
+                                   var.pretty_name.decode('utf-8')),
+                               'source_topo': s_topo,
+                               'target_topo': r_topo,
+                               })
+            return kwargs, intertasks_exchanged
+
+        # Iterate redistribute inter-tasks (Resume for these nodes)
+        for _handle_node_args in redistribute_inter:
+            node_id, node, subgraph, node_ops, node_vertices, from_subgraph, opvertex, op, opnode = _handle_node_args
+            # Fix the redistribute base initialization
+            available_elems, needed_elems = {}, {}
+            for _node in target_node.nodes:
+                if _node is node:
+                    break
+                available_elems.update(_node.output_fields)
+                available_elems.update(_node.output_params)
+            for _node in target_node.nodes[::-1]:
+                if _node is node:
+                    break
+                needed_elems.update(_node.input_fields)
+                needed_elems.update(_node.input_params)
+            it_redistribute_kwargs, intertasks_exchanged = __find_elements_to_redistribute(
+                available_elems, needed_elems, intertasks_exchanged)
+            assert RedistributeInter.can_redistribute(
+                *tuple([it_redistribute_kwargs[_]
+                        for _ in ('source_topo', 'target_topo', 'other_task_id')]))
+            op.__init__(**it_redistribute_kwargs)
+            op.initialize(topgraph_method=self.target_node.method)
+            __handle_node(*_handle_node_args)
 
         # iterate deferred nodes
         for (op, opnode) in deferred_operators:
@@ -361,123 +585,6 @@ class GraphBuilder(object):
             if current_level == 0:
                 opnode.set_op_info(op, input_states, output_states)
 
-        # If tasks are used, we also search for and insert  needed RedistributeInter operators.
-        # The algorithm is to extract level0 input fields and topologies as needs
-        # and meet with output fields and topologies as provided. The key feature is that
-        # these informations are distributed across distinct tasks (sub-communicators).
-        # Same algorithm is also used for parameters.
-        # InterComm redistribute (for field or parameter) is inserted if an input
-        # of a task is found in other task output.
-        # Note: RedistributeInter are added at the end of graph (after fields being used)
-        search_it_ops = self.search_intertasks_ops
-        search_it_ops = search_it_ops and (len(input_fields)+len(output_fields) > 0)
-        domain = None
-        if search_it_ops and input_fields:
-            domain = [_.domain for _ in input_fields.values()][0]
-            search_it_ops = search_it_ops and domain.has_tasks
-        if search_it_ops and output_fields:
-            domain = [_.domain for _ in output_fields.values()][0]
-            search_it_ops = search_it_ops and domain.has_tasks
-        if search_it_ops:
-            from hysop.operators import InterTaskParamComm
-            comm = domain.parent_comm
-            tcomm = domain.task_comm()
-            current_task = domain.current_task()
-            gprint('>[IT] Handling for inter-tasks redistributes (current task={}).'.format(current_task))
-
-            # Uninitialized fields or parameters
-            inputs_names = set([_.name for _ in input_fields.keys()] + input_params.keys())
-            # inputs_names -= set(output_params.keys())
-            # inputs_names -= set(_.name for _ in output_fields.keys())
-            outputs_names = output_params.keys() + [_.name for _ in output_fields.keys()]
-            gprint(" >[IT] Output parameters and fields : " + ", ".join(outputs_names))
-            gprint(" >[IT] Input parameters and fields : " + ", ".join(inputs_names))
-
-            # Inter-task matching is performed on root process
-            to_provide = {}
-            inputs_names = dict((_, None) for _ in inputs_names)
-            if domain.task_rank() == 0:
-                msg = ''
-                # loop over other tasks
-                for ot in (_ for _ in domain.all_tasks if _ != current_task):
-                    comm.isend(inputs_names.keys(),
-                               dest=domain.task_root_in_parent(ot), tag=ot)
-                    ot_needs = comm.recv(source=domain.task_root_in_parent(ot), tag=current_task)
-                    can_provide = [_ for _ in ot_needs if _ in outputs_names]
-                    for prov in can_provide:
-                        to_provide[prov] = ot
-                    if len(ot_needs) > 0:
-                        msg += "\n  *Other task {} needs init for {}, we provide {}".format(
-                            ot, ot_needs, "nothing" if len(can_provide) == 0 else can_provide)
-                    comm.isend(can_provide, dest=domain.task_root_in_parent(ot), tag=1234)
-                    ot_provide = comm.recv(source=domain.task_root_in_parent(ot), tag=1234)
-                    for op in ot_provide:
-                        inputs_names[op] = ot
-                if msg != '':
-                    gprint(" >[IT] Inter-tasks matching:" + msg)
-            provided = {p: t for (p, t) in inputs_names.items() if t is not None}
-            provided = tcomm.bcast(provided, root=0)
-            to_provide = tcomm.bcast(to_provide, root=0)
-            gprint(" >[IT] Inter-tasks providing {} provided {}".format(to_provide, provided))
-
-            def __add_param_redistribute(l, p, src, dest):
-                gprint("  >[IT] Added InterTaskParamComm for parameter {} {}".format(
-                    p.name, "provided for task "+str(dest) if src == current_task
-                    else " initialized from task "+str(src)))
-                l.append(InterTaskParamComm(parameter=(p, ), domain=domain,
-                                            source_task=src, dest_task=dest,
-                                            mpi_params=target_node.mpi_params))
-
-            def __add_field_redistribute(l, f, topo, src, dest):
-                msg = "  >[IT] Added RedistributeInter for field {} ".format(f.name)
-                if src == current_task:
-                    msg += "provided for topology {} on task {}".format(topo.full_pretty_tag, t)
-                    comm_dir = "src"
-                elif dest == current_task:
-                    msg += "initialized from topology {} on task {}".format(topo.full_pretty_tag, t)
-                    comm_dir = "dst"
-                else:
-                    raise RuntimeError()
-                gprint(msg)
-                l.append(Redistribute(variables=f, other_task_id=t, mpi_params=topo.mpi_params,
-                                      name='RI{}_{}_{}'.format(comm_dir, topo.id, f.name),
-                                      pretty_name=u'RI{}_{}{}{}_{}'.format(
-                                          comm_dir,
-                                          subscript(topo.id) if src == current_task else '',
-                                          u'\u2192',
-                                          subscript(topo.id) if dest == current_task else '',
-                                          f.pretty_name.decode('utf-8')),
-                                      source_topos=topo if src == current_task else None,
-                                      target_topo=topo if dest == current_task else None))
-
-            # Build new Inter-tasks operators
-            new_ops = []
-            for p, t in sorted(to_provide.items() + provided.items(), key=lambda _: _[0]):
-                of_key = [_ for _ in self.output_fields.keys() if _.name == p]
-                if_key = [_ for _ in self.input_fields.keys() if _.name == p]
-                if p in provided:
-                    if p in input_params.keys():
-                        __add_param_redistribute(new_ops, input_params[p], current_task, t)
-                    else:
-                        __add_field_redistribute(
-                            new_ops, if_key[0], self.input_fields[if_key[0]], current_task, t)
-                elif p in to_provide:
-                    if p in output_params.keys():
-                        __add_param_redistribute(new_ops, output_params[p], t, current_task)
-                    else:
-                        __add_field_redistribute(
-                            new_ops, of_key[0], self.output_fields[of_key[0]], t, current_task)
-                else:
-                    raise RuntimeError("Field {} not found either in inputs {} or outputs {}".format(
-                        p, [_.name for _ in self.input_fields.keys()], [_.name for _ in self.output_fields.keys()]))
-
-            # Handle new nodes
-            shift = len(target_node.nodes)
-            target_node.push_nodes(*tuple(new_ops))
-            for (node_id, node) in enumerate(_ for __ in new_ops for _ in __.nodes):
-                node.initialize()
-                __handle_node(shift + node_id, node)
-
         # On level=0 we print a summary (if asked) for input and output fields and
         # their topology.
         def _print_io_fields_params_summary(comment=''):
@@ -512,6 +619,7 @@ class GraphBuilder(object):
             gprint(msg)
         if current_level == 0:
             _print_io_fields_params_summary()
+        is_graph_updated = False
 
         # add output field Writer if necessary
         if (target_node._output_fields_to_dump is not None):
@@ -547,11 +655,11 @@ class GraphBuilder(object):
                     self.op_output_topology_states[op] = output_states
                     if current_level == 0:
                         opnode.set_op_info(op, input_states, output_states)
+                    is_graph_updated = True
 
         # Alter states such that output topology states match input topology states
         # this is only done if required (outputs_are_inputs) and if we are
         # processing the top level (root) graph
-        is_graph_updated = False
         if (current_level == 0) and outputs_are_inputs:
             def _closure(field, itopo, itopostate, cstate):
                 target_topo = itopo
@@ -594,6 +702,28 @@ class GraphBuilder(object):
                         assert f in input_topology_states
                         is_graph_updated = _closure(f, t, double_check_inputs[f][t], self.topology_states[f])
 
+        # Final intertask redistributes as closure
+        if (current_level == 0) and outputs_are_inputs:
+            available_elems, needed_elems = {}, {}
+            needed_elems.update(self.input_fields)
+            available_elems.update(self.output_fields)
+            needed_elems.update(self.input_params)
+            available_elems.update(self.output_params)
+            it_redistribute_kwargs, intertasks_exchanged = __find_elements_to_redistribute(
+                available_elems, needed_elems, intertasks_exchanged)
+            node = RedistributeInter(**it_redistribute_kwargs)
+            node_id = len(target_node.nodes)
+            target_node.push_nodes(node)
+            node.initialize()
+            gprint(' >Handling node {}: {} {}'.format(
+                node_id, node.name, node.__class__))
+            subgraph, node_ops, node_vertices, from_subgraph = \
+                self.build_subgraph(node, current_level)
+            for (opvertex, op) in zip(node_vertices, node_ops):
+                opnode = self.new_node(op, subgraph, current_level, node, node_id, opvertex)
+                __handle_node(node_id, node, subgraph, node_ops, node_vertices, from_subgraph, opvertex, op, opnode)
+                is_graph_updated = True
+
         if current_level == 0 and is_graph_updated:
             _print_io_fields_params_summary('After closure and output dumping')
 
@@ -773,10 +903,13 @@ class GraphBuilder(object):
         def add_vertex(self, graph, operator):
             return new_vertex(graph, operator)
 
-        def add_edge(self, graph, src_node, dst_node, field, topology):
+        def add_edge(self, graph, src_node, dst_node, field, topology, reverse=False):
             if (src_node is not None) and (dst_node is not None) \
                     and (src_node != dst_node):
-                return new_edge(graph, src_node, dst_node, field, topology)
+                if reverse:
+                    return new_edge(graph, dst_node, src_node, field, topology)
+                else:
+                    return new_edge(graph, src_node, dst_node, field, topology)
             else:
                 return None
 
@@ -1029,10 +1162,8 @@ class GraphBuilder(object):
                             else:
                                 istate.axes = allowed_axes[0]
 
-                            allowed_memory_order = \
-                                target_dfield_requirements.memory_order
-                            default_memory_order = \
-                                self.discrete_topology_states[target_topo].memory_order
+                            allowed_memory_order = target_dfield_requirements.memory_order
+                            default_memory_order = self.discrete_topology_states[target_topo].memory_order
                             assert (default_memory_order is not MemoryOrdering.ANY)
                             if (allowed_memory_order is MemoryOrdering.ANY):
                                 istate.memory_order = default_memory_order
@@ -1140,7 +1271,7 @@ class GraphBuilder(object):
             return istate
 
         def handle_output(self, opnode, output_topo, oreqs, operator,
-                          input_topology_states, invalidate_field, graph):
+                          input_topology_states, invalidate_field, graph, node_list=[]):
 
             ofield = self.field
             write_nodes = self.write_nodes
@@ -1152,8 +1283,8 @@ class GraphBuilder(object):
             # add dependency to last node written to prevent
             # concurent write-writes.
             if (output_topo in write_nodes):
-                self.add_edge(graph, write_nodes[output_topo],
-                              opnode, ofield, output_topo)
+                self.add_edge(graph, write_nodes[output_topo], opnode, ofield, output_topo,
+                              reverse=node_list.index(write_nodes[output_topo].operator) > node_list.index(operator))
 
             if invalidate_field:
                 msg = '   >Invalidating output field {} on all topologies but {} '
@@ -1164,8 +1295,11 @@ class GraphBuilder(object):
                 # to prevent concurent read-writes.
                 if output_topo in read_nodes:
                     for ro_node in read_nodes[output_topo]:
-                        self.add_edge(graph, ro_node, opnode, ofield, output_topo)
-
+                        if not ro_node is None:
+                            self.add_edge(graph, ro_node, opnode, ofield, output_topo,
+                                          reverse=node_list.index(ro_node.operator) > node_list.index(operator))
+                        else:
+                            self.add_edge(graph, ro_node, opnode, ofield, output_topo)
                 # remove read/write dependencies and states
                 write_nodes.clear()
                 dtopology_states.clear()
diff --git a/hysop/core/graph/node_generator.py b/hysop/core/graph/node_generator.py
index fa2c5bc90..d195f9295 100644
--- a/hysop/core/graph/node_generator.py
+++ b/hysop/core/graph/node_generator.py
@@ -5,6 +5,37 @@ from hysop.tools.types import first_not_None
 from hysop.core.graph.computational_node import ComputationalGraphNode
 
 
+class HiddenOperator(object):
+    """Object that is inserted in node list in replacement of out-of-taks operators.
+    This object should pass the intialization through graph building where it is definitely removed.
+    This object is an helper to build a graph with appropriates Inter-Task redistributes."""
+
+    def __init__(self, mpi_params):
+        self.mpi_params = mpi_params
+        self.name = ''
+
+    def available_methods(self, *args, **kwargs):
+        return {}
+
+    def get_node_requirements(self, *args, **kwargs):
+        return tuple()
+
+    def available_methods(self, *args, **kwargs):
+        return {}
+
+    def get_domains(self, *args, **kwargs):
+        return {}
+
+    def pre_initialize(self, *args, **kwargs):
+        pass
+
+    def initialize(self, *args, **kwargs):
+        pass
+
+    def post_initialize(self, *args, **kwargs):
+        pass
+
+
 class ComputationalGraphNodeGenerator(object):
     """
     A class that can generate multiple hysop.core.graph.ComputationalGraphNode.
@@ -73,6 +104,8 @@ class ComputationalGraphNodeGenerator(object):
                             raise RuntimeError(msg)
                         self.candidate_input_tensors.update(op.candidate_input_tensors)
                         self.candidate_output_tensors.update(op.candidate_output_tensors)
+                    elif isinstance(op, HiddenOperator):
+                        nodes = (op, )
                     else:
                         msg = 'Unknown node type {}.'.format(op.__class__)
                         raise TypeError(msg)
diff --git a/hysop/core/mpi/redistribute.py b/hysop/core/mpi/redistribute.py
index 92b091eec..08baa4b29 100644
--- a/hysop/core/mpi/redistribute.py
+++ b/hysop/core/mpi/redistribute.py
@@ -323,11 +323,17 @@ class RedistributeInter(RedistributeOperatorBase):
           *be CartesianTopology topologies with the same global resolution
           *be defined on different communicators
         """
-
-        # Base class initialisation
-        super(RedistributeInter, self).__init__(**kwds)
-        self._other_task_id = kwds['other_task_id']
-        self._synchronize(kwds['source_topo'], kwds['target_topo'])
+        if kwds:
+            # Base class initialisation
+            super(RedistributeInter, self).__init__(**kwds)
+            self._other_task_id = kwds['other_task_id']
+            self._synchronize(kwds['source_topo'], kwds['target_topo'])
+        else:
+            # Fake init. Should be called again later
+            self.name = "TempName"
+            self.mpi_params = None
+            self._input_fields_to_dump = []
+            self._output_fields_to_dump = []
 
     def _synchronize(self, tin, tout):
         """Ensure that the two redistributes are operating on the same variable"""
-- 
GitLab