From f54781892447f9844a800c012fabd4cbcf09fff3 Mon Sep 17 00:00:00 2001
From: JM Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Mon, 8 Jun 2020 14:49:43 +0200
Subject: [PATCH] Improve graph building in the case of a field being used as
 input with several topologies.

---
 hysop/core/graph/computational_graph.py       |   2 +-
 hysop/core/graph/graph_builder.py             | 100 ++++++++++++------
 .../tests/test_spectral_derivative.py         |  40 +++----
 3 files changed, 91 insertions(+), 51 deletions(-)

diff --git a/hysop/core/graph/computational_graph.py b/hysop/core/graph/computational_graph.py
index 9e522c135..02e2a52fd 100644
--- a/hysop/core/graph/computational_graph.py
+++ b/hysop/core/graph/computational_graph.py
@@ -716,7 +716,7 @@ class ComputationalGraph(ComputationalGraphNode):
         if (self.is_root and __VERBOSE__) or __DEBUG__ or self.__FORCE_REPORTS__:
             print self.node_requirements_report(requirements)
         for node in self.nodes:
-            if node.mpi_params.on_task:
+            if node.mpi_params is None or node.mpi_params.on_task:
                 node_requirements = node.get_and_set_field_requirements()
                 requirements.update(node_requirements)
         if (self.is_root and __VERBOSE__) or __DEBUG__ or self.__FORCE_REPORTS__:
diff --git a/hysop/core/graph/graph_builder.py b/hysop/core/graph/graph_builder.py
index 9d575b16b..904a20782 100644
--- a/hysop/core/graph/graph_builder.py
+++ b/hysop/core/graph/graph_builder.py
@@ -133,6 +133,7 @@ class GraphBuilder(object):
         op_output_topology_states = self.op_output_topology_states
 
         deferred_operators = []
+        double_check_inputs = {}
 
         # check that all target nodes are unique to prevent conflicts
         if len(set(target_node.nodes)) != len(target_node.nodes):
@@ -288,6 +289,10 @@ class GraphBuilder(object):
                             input_fields[ifield] = itopo
                             input_topology_states[ifield] = (ifreqs, dstate)
 
+                        if ifield not in double_check_inputs:
+                            double_check_inputs[ifield] = {}
+                        double_check_inputs[ifield].update({itopo: (ifreqs, dstate)})
+
                 # iterate over subgraph operator output fields
                 output_states = {}
                 if ofields:
@@ -456,36 +461,38 @@ class GraphBuilder(object):
 
         # On level=0 we print a summary (if asked) for input and output fields and
         # their topology.
-        if current_level == 0:
-            msg = '\nComputationalGraph {} inputs:\n'.format(target_node.name)
-            if not (input_fields and input_params):
+        def _print_io_fields_params_summary(comment=''):
+            msg = '\nComputationalGraph {} inputs {}:\n'.format(target_node.name, comment)
+            if not (self.input_fields or self.input_params):  # 'no inputs' only when fields AND params are both empty
                 msg += '  no inputs\n'
             else:
-                if input_fields:
-                    for ifield in input_fields:
-                        itopo = input_fields[ifield]
-                        _, ireqs = input_topology_states[ifield]
-                        msg += '  *Field {} on topo {}: {}\n'.format(
-                            ifield.name, itopo.id, ireqs)
-                if len(input_params) > 0:
-                    for iparam in input_params:
+                if self.input_fields:
+                    for ifield in sorted(self.input_fields, key=lambda x: x.name):
+                        itopo = self.input_fields[ifield]
+                        _, ireqs = self.input_topology_states[ifield]
+                        msg += '  *Field {} on topo {}{}\n'.format(
+                            ifield.name, itopo.id, ": {}".format(ireqs) if GRAPH_BUILDER_DEBUG_LEVEL == 2 else '')
+                if len(self.input_params) > 0:
+                    for iparam in sorted(self.input_params):
                         msg += '  *Parameter {}\n'.format(iparam)
-            msg += 'ComputationalGraph {} outputs:\n'.format(target_node.name)
-            if not (output_fields and output_params):
+            msg += 'ComputationalGraph {} outputs {}:\n'.format(target_node.name, comment)
+            if not (self.output_fields or self.output_params):  # 'no outputs' only when fields AND params are both empty
                 msg += '  no outputs\n'
             else:
-                if output_fields:
-                    for ofield in output_fields:
-                        otopo = output_fields[ofield]
-                        _, oreqs = output_topology_states[ofield]
-                        msg += '  *Field {} on topo {}: {}\n'.format(
-                            ofield.name, otopo.id, oreqs)
-                if len(output_params) > 0:
-                    for oparam in output_params:
-                        msg += '  *Parameter {}\n'.format(iparam)
+                if self.output_fields:
+                    for ofield in sorted(self.output_fields, key=lambda x: x.name):
+                        otopo = self.output_fields[ofield]
+                        _, oreqs = self.output_topology_states[ofield]
+                        msg += '  *Field {} on topo {}{}\n'.format(
+                            ofield.name, otopo.id, ": {}".format(oreqs) if GRAPH_BUILDER_DEBUG_LEVEL == 2 else '')
+                if len(self.output_params) > 0:
+                    for oparam in sorted(self.output_params):
+                        msg += '  *Parameter {}\n'.format(oparam)
 
             msg += '\n'
             gprint(msg)
+        if current_level == 0:
+            _print_io_fields_params_summary()
 
         # add output field Writer if necessary
         if (target_node._output_fields_to_dump is not None):
@@ -525,23 +532,52 @@ class GraphBuilder(object):
         # Alter states such that output topology states match input topology states
         # this is only done if required (outputs_are_inputs) and if we are
         # processing the top level (root) graph
+        is_graph_updated = False
         if (current_level == 0) and outputs_are_inputs:
-            # identify variables that needs a closure
-            redistribute_fields = set(input_fields.keys())
-
-            for field in sorted(redistribute_fields, key=lambda x: x.name):
-                assert field in input_topology_states
-                target_topo = input_fields[field]
-                input_dfield_requirements, input_topology_state = \
-                    input_topology_states[field]
+            def _closure(field, itopo, itopostate, cstate):
+                target_topo = itopo
+                input_dfield_requirements, input_topology_state = itopostate
 
                 requirements = input_dfield_requirements.copy()
                 requirements.axes = (input_topology_state.axes,)
                 requirements.memory_order = input_topology_state.memory_order
 
-                cstate = self.topology_states[field]
                 cstate.output_as_input(target_topo, requirements, graph)
 
+                # Update (on closure) to have output as close as possible to inputs
+                if field in output_fields:
+                    orig_topo, orig_state = output_fields[field], cstate.discrete_topology_states[output_fields[field]]
+                    final_topo, final_state = target_topo, cstate.discrete_topology_states[target_topo]
+                    kept_topo_and_state = orig_topo == final_topo and orig_state == final_state
+                    if not kept_topo_and_state:
+                        msg = "      > Update graph outputs from topology {}{} to {}{}".format(
+                            orig_topo.tag, ":{}".format(orig_state) if GRAPH_BUILDER_DEBUG_LEVEL == 2 else '',
+                            final_topo.tag, ":{}".format(final_state) if GRAPH_BUILDER_DEBUG_LEVEL == 2 else '')
+                        gprint(msg)
+                        self.output_fields[field] = target_topo
+                        self.output_topology_states[field] = (None, cstate.discrete_topology_states[target_topo])
+                        return True
+                return False
+
+            # identify variables that need a closure; the flag is accumulated so any single update is remembered
+            redistribute_fields = set(input_fields.keys())
+            for field in sorted(redistribute_fields, key=lambda x: x.name):
+                assert field in input_topology_states
+                is_graph_updated = _closure(field, input_fields[field], input_topology_states[field], self.topology_states[field]) or is_graph_updated
+
+            for f in sorted(double_check_inputs.keys(), key=lambda x: x.name):
+                # every field used as input must have been written on each topology it is read on, or never written
+                written_topos = set(self.topology_states[f].write_nodes.keys())
+                read_topos = set(double_check_inputs[f].keys())
+                diff = read_topos-written_topos
+                if len(written_topos) > 0 and len(diff) > 0:
+                    for t in diff:
+                        assert f in input_topology_states
+                        is_graph_updated = _closure(f, t, double_check_inputs[f][t], self.topology_states[f]) or is_graph_updated  # _closure is the left operand, so it always runs
+
+        if current_level == 0 and is_graph_updated:
+            _print_io_fields_params_summary('After closure and output dumping')
+
         # Check that the generated graph is a directed acyclic graph
         if not is_directed_acyclic_graph(graph):
             msg = '\nGenerated operator graph is not acyclic.'
@@ -607,7 +643,7 @@ class GraphBuilder(object):
         subgraph = None
         from_subgraph = False
 
-        if node.mpi_params.on_task:
+        if node.mpi_params is None or node.mpi_params.on_task:
             if isinstance(node, Problem):
                 assert node.graph_built, "Sub-problem should be already built"
                 assert node.initialized, "Sub-problem should be already initialized"
diff --git a/hysop/operator/tests/test_spectral_derivative.py b/hysop/operator/tests/test_spectral_derivative.py
index a02ae7c9a..b093bf4a2 100644
--- a/hysop/operator/tests/test_spectral_derivative.py
+++ b/hysop/operator/tests/test_spectral_derivative.py
@@ -322,35 +322,39 @@ class TestSpectralDerivative(object):
 
     def test_1d_trigonometric_float32(self, **kwds):
         kwds.update({'max_derivative': 3})
-        self._test(dim=1, dtype=npw.float32, polynomial=False, **kwds)
+        if __ENABLE_LONG_TESTS__:
+            self._test(dim=1, dtype=npw.float32, polynomial=False, **kwds)
 
     def test_2d_trigonometric_float32(self, **kwds):
         kwds.update({'max_derivative': 1, 'max_runs': None})
-        self._test(dim=2, dtype=npw.float32, polynomial=False, **kwds)
+        if __ENABLE_LONG_TESTS__:
+            self._test(dim=2, dtype=npw.float32, polynomial=False, **kwds)
 
     def test_3d_trigonometric_float32(self, **kwds):
-        kwds.update({'max_derivative': 1, 'max_runs': 5})
-        if __ENABLE_LONG_TESTS__:
-            self._test(dim=3, dtype=npw.float32, polynomial=False, **kwds)
+        kwds.update({'max_derivative': 1, 'max_runs': 4})
+        self._test(dim=3, dtype=npw.float32, polynomial=False, **kwds)
 
     def test_1d_trigonometric_float64(self, **kwds):
         kwds.update({'max_derivative': 3})
-        self._test(dim=1, dtype=npw.float64, polynomial=False, **kwds)
+        if __ENABLE_LONG_TESTS__:
+            self._test(dim=1, dtype=npw.float64, polynomial=False, **kwds)
 
     def test_2d_trigonometric_float64(self, **kwds):
         kwds.update({'max_derivative': 1, 'max_runs': None})
-        self._test(dim=2, dtype=npw.float64, polynomial=False, **kwds)
+        if __ENABLE_LONG_TESTS__:
+            self._test(dim=2, dtype=npw.float64, polynomial=False, **kwds)
 
     def test_3d_trigonometric_float64(self, **kwds):
-        kwds.update({'max_derivative': 1, 'max_runs': 5})
-        if __ENABLE_LONG_TESTS__:
-            self._test(dim=3, dtype=npw.float64, polynomial=False, **kwds)
+        kwds.update({'max_derivative': 1, 'max_runs': 4})
+        self._test(dim=3, dtype=npw.float64, polynomial=False, **kwds)
 
     def test_1d_polynomial_float32(self, **kwds):
-        self._test(dim=1, dtype=npw.float32, polynomial=True, **kwds)
+        if __ENABLE_LONG_TESTS__:
+            self._test(dim=1, dtype=npw.float32, polynomial=True, **kwds)
 
     def test_2d_polynomial_float32(self, **kwds):
-        self._test(dim=2, dtype=npw.float32, polynomial=True, **kwds)
+        if __ENABLE_LONG_TESTS__:
+            self._test(dim=2, dtype=npw.float32, polynomial=True, **kwds)
 
     def test_3d_polynomial_float32(self, **kwds):
         kwds.update({'max_derivative': 1})
@@ -358,18 +362,18 @@ class TestSpectralDerivative(object):
             self._test(dim=3, dtype=npw.float32, polynomial=True, **kwds)
 
     def perform_tests(self):
-        self.test_1d_trigonometric_float32(max_derivative=3)
-        self.test_2d_trigonometric_float32(max_derivative=1, max_runs=None)
-        self.test_3d_trigonometric_float32(max_derivative=1, max_runs=5)
-        self.test_1d_trigonometric_float64(max_derivative=3)
-        self.test_2d_trigonometric_float64(max_derivative=1, max_runs=None)
-        self.test_3d_trigonometric_float64(max_derivative=1, max_runs=5)
+        self.test_3d_trigonometric_float32(max_derivative=1, max_runs=4)
+        self.test_3d_trigonometric_float64(max_derivative=1, max_runs=4)
 
         if __ENABLE_LONG_TESTS__:
+            self.test_1d_trigonometric_float64(max_derivative=3)
+            self.test_2d_trigonometric_float64(max_derivative=1, max_runs=None)
             # self.test_1d_trigonometric_float64(max_derivative=3)
             # self.test_2d_trigonometric_float64(max_derivative=2)
             self.test_3d_trigonometric_float64(max_derivative=1)
 
+            self.test_1d_trigonometric_float32(max_derivative=3)
+            self.test_2d_trigonometric_float32(max_derivative=1, max_runs=None)
             # self.test_1d_polynomial_float32(max_derivative=3)
             # self.test_2d_polynomial_float32(max_derivative=2)
             self.test_3d_polynomial_float32(max_derivative=1)
-- 
GitLab