From a5b0d6c7b57a7068ef4cf654c2c190026290f538 Mon Sep 17 00:00:00 2001
From: Jean-Matthieu Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Wed, 24 Mar 2021 13:01:27 +0100
Subject: [PATCH] fixup

---
 hysop/core/graph/computational_graph.py | 78 ++++++++++++-------------
 1 file changed, 39 insertions(+), 39 deletions(-)

diff --git a/hysop/core/graph/computational_graph.py b/hysop/core/graph/computational_graph.py
index e20bc9402..e2c942aeb 100644
--- a/hysop/core/graph/computational_graph.py
+++ b/hysop/core/graph/computational_graph.py
@@ -29,10 +29,10 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
 
     @debug
     def __new__(cls, candidate_input_tensors=None,
-                 candidate_output_tensors=None,
-                 **kwds):
+                candidate_output_tensors=None,
+                **kwds):
         return super(ComputationalGraph, cls).__new__(cls,
-                input_fields=None, output_fields=None, **kwds)
+                                                      input_fields=None, output_fields=None, **kwds)
 
     @debug
     def __init__(self, candidate_input_tensors=None,
@@ -183,19 +183,19 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
         titles = [[('OPERATOR', 'FIELD', 'DISCRETIZATION', 'GHOSTS',
                     'MEMORY ORDER', 'CAN_SPLIT', 'TSTATES')]]
         vals = tuple(sinputs.values()) + tuple(soutputs.values()) + tuple(titles)
-        name_size     = max(len(s[0]) for ss in vals for s in ss)
-        field_size    = max(len(s[1]) for ss in vals for s in ss)
-        discr_size    = max(len(s[2]) for ss in vals for s in ss)
-        ghosts_size   = max(len(s[3]) for ss in vals for s in ss)
-        order_size    = max(len(s[4]) for ss in vals for s in ss)
+        name_size = max(len(s[0]) for ss in vals for s in ss)
+        field_size = max(len(s[1]) for ss in vals for s in ss)
+        discr_size = max(len(s[2]) for ss in vals for s in ss)
+        ghosts_size = max(len(s[3]) for ss in vals for s in ss)
+        order_size = max(len(s[4]) for ss in vals for s in ss)
         cansplit_size = max(len(s[5]) for ss in vals for s in ss)
-        tstates_size  = max(len(s[6]) for ss in vals for s in ss)
+        tstates_size = max(len(s[6]) for ss in vals for s in ss)
 
         template = '\n   {:<{name_size}}   {:^{field_size}}     {:^{discr_size}}      {:^{ghosts_size}}      {:^{order_size}}      {:^{cansplit_size}}      {:^{tstates_size}}'
 
         ss = '>INPUTS:'
         if sinputs:
-            for (td, sreqs) in sorted(sinputs.items()):
+            for (td, sreqs) in sorted(sinputs.items(), key=lambda _: _[1][0]):
                 if isinstance(td, Topology):
                     ss += '\n {}'.format(td.short_description())
                 else:
@@ -216,7 +216,7 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
             ss += ' None'
         ss += '\n>OUTPUTS:'
         if soutputs:
-            for (td, sreqs) in sorted(soutputs.items()):
+            for (td, sreqs) in sorted(soutputs.items(), key=lambda _: _[1][0]):
                 if isinstance(td, Topology):
                     ss += '\n {}'.format(td.short_description())
                 else:
@@ -254,13 +254,13 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
                 continue
             for op in sorted(operators, key=lambda x: x.pretty_name):
                 finputs = ','.join(sorted([f.pretty_name
-                                            for f in op.iter_input_fields() if f.domain is domain]))
+                                           for f in op.iter_input_fields() if f.domain is domain]))
                 foutputs = ','.join(sorted([f.pretty_name
-                                             for f in op.iter_output_fields() if f.domain is domain]))
+                                            for f in op.iter_output_fields() if f.domain is domain]))
                 pinputs = ','.join(sorted([p.pretty_name
-                                            for p in op.input_params.values()]))
+                                           for p in op.input_params.values()]))
                 poutputs = ','.join(sorted([p.pretty_name
-                                             for p in op.output_params.values()]))
+                                            for p in op.output_params.values()]))
                 infields = '[{}]'.format(finputs) if finputs else ''
                 outfields = '[{}]'.format(foutputs) if foutputs else ''
                 inparams = '[{}]'.format(pinputs) if pinputs else ''
@@ -287,9 +287,9 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
             operators = domains[None]
             for op in sorted(operators, key=lambda x: x.pretty_name):
                 pinputs = ','.join(sorted([p.pretty_name
-                                            for p in op.input_params.values()]))
+                                           for p in op.input_params.values()]))
                 poutputs = ','.join(sorted([p.pretty_name
-                                             for p in op.output_params.values()]))
+                                            for p in op.output_params.values()]))
                 inparams = '[{}]'.format(pinputs) if pinputs else ''
                 outparams = '[{}]'.format(poutputs) if poutputs else ''
 
@@ -306,11 +306,11 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
                 op_data = ops.setdefault(None, [])
                 op_data += multiline_split(strdata, maxlen, split_sep, replace, newline_prefix)
 
-        name_size  = max(strlen(s[0]) for ss in ops.values() for s in ss)
-        in_size    = max(strlen(s[1]) for ss in ops.values() for s in ss)
+        name_size = max(strlen(s[0]) for ss in ops.values() for s in ss)
+        in_size = max(strlen(s[1]) for ss in ops.values() for s in ss)
         arrow_size = max(strlen(s[2]) for ss in ops.values() for s in ss)
-        out_size   = max(strlen(s[3]) for ss in ops.values() for s in ss)
-        type_size  = max(strlen(s[4]) for ss in ops.values() for s in ss)
+        out_size = max(strlen(s[3]) for ss in ops.values() for s in ss)
+        type_size = max(strlen(s[4]) for ss in ops.values() for s in ss)
 
         ss = ''
         for (domain, dops) in ops.items():
@@ -350,7 +350,7 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
         for (backend, topologies) in self.get_topologies().items():
             ss += '\n {}:'.format(backend.short_description())
             ss += '\n  *'+'\n  *'.join(t.short_description()
-                                        for t in sorted(topologies, key=lambda x: x.id))
+                                       for t in sorted(topologies, key=lambda x: x.id))
         title = 'ComputationalGraph {} topology report '.format(self.pretty_name)
         return '\n{}\n'.format(framed_str(title=title, msg=ss[1:]))
 
@@ -394,7 +394,7 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
         titles = [[('BACKEND', 'TOPOLOGY', 'OPERATORS')]]
         vals = tuple(topologies.values()) + tuple(titles)
         backend_size = max(len(s[0]) for ss in vals for s in ss)
-        topo_size    = max(len(s[1]) for ss in vals for s in ss)
+        topo_size = max(len(s[1]) for ss in vals for s in ss)
         template = '\n   {:<{backend_size}}   {:<{topo_size}}   {}'
         sizes = {'backend_size': backend_size,
                  'topo_size': topo_size}
@@ -428,30 +428,30 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
                 t0 = node.input_fields[f0]
                 if all((node.input_fields[fi] is t0) for fi in f.fields):
                     finputs.append('{}.{}'.format(f.pretty_name,
-                                                   t0.pretty_tag))
+                                                  t0.pretty_tag))
                     handled_inputs += f.fields
             for f in node.output_tensor_fields:
                 f0 = f.fields[0]
                 t0 = node.output_fields[f0]
                 if all((node.output_fields[fi] is t0) for fi in f.fields):
                     foutputs.append('{}.{}'.format(f.pretty_name,
-                                                    t0.pretty_tag))
+                                                   t0.pretty_tag))
                     handled_outputs += f.fields
             finputs += ['{}.{}'.format(f.pretty_name,
-                                        t.pretty_tag)
+                                       t.pretty_tag)
                         for (f, t) in node.input_fields.items()
                         if f not in handled_inputs]
             foutputs += ['{}.{}'.format(f.pretty_name,
-                                         t.pretty_tag)
+                                        t.pretty_tag)
                          for (f, t) in node.output_fields.items()
                          if f not in handled_outputs]
             finputs = ','.join(sorted(finputs))
             foutputs = ','.join(sorted(foutputs))
 
             pinputs = ','.join(sorted([p.pretty_name
-                                        for p in node.input_params.values()]))
+                                       for p in node.input_params.values()]))
             poutputs = ','.join(sorted([p.pretty_name
-                                         for p in node.output_params.values()]))
+                                        for p in node.output_params.values()]))
 
             infields = '[{}]'.format(finputs) if finputs else ''
             outfields = '[{}]'.format(foutputs) if foutputs else ''
@@ -738,13 +738,14 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
         Convert the graph to a pyvis network for vizualization.
         """
         try:
-            import pyvis, matplotlib
+            import pyvis
+            import matplotlib
         except ImportError:
-            msg='\nGraph vizualization requires pyvis and matplotlib, which are not present on your system.\n'
+            msg = '\nGraph vizualization requires pyvis and matplotlib, which are not present on your system.\n'
             print(msg)
             return
 
-        width  = first_not_None(width,  1920)
+        width = first_not_None(width,  1920)
         height = first_not_None(height, 1080)
 
         graph = self.reduced_graph
@@ -755,16 +756,16 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
             node_id = int(node)
             if node_id not in known_nodes:
                 network.add_node(node_id, label=node.label,
-                        title=node.title, color=node.color,
-                        shape=node.shape(with_custom_nodes))
+                                 title=node.title, color=node.color,
+                                 shape=node.shape(with_custom_nodes))
                 known_nodes.add(node_id)
 
         def add_edge(from_node, to_node):
             from_node_id = int(from_node)
-            to_node_id   = int(to_node)
+            to_node_id = int(to_node)
             edge = graph[from_node][to_node]
             network.add_edge(from_node_id, to_node_id,
-                    title=str(edge.get('data', 'no edge data')))
+                             title=str(edge.get('data', 'no edge data')))
 
         for node in graph:
             add_node(node)
@@ -774,7 +775,6 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
 
         return network
 
-
     @debug
     @graph_built
     def discretize(self):
@@ -820,10 +820,10 @@ class ComputationalGraph(ComputationalGraphNode, metaclass=ABCMeta):
                 output_discrete_tensor_fields[tfield] = tdfield
 
             discrete_fields = tuple(set(input_discrete_fields.values()).union(
-                                        output_discrete_fields.values()))
+                output_discrete_fields.values()))
 
             discrete_tensor_fields = tuple(set(input_discrete_tensor_fields.values()).union(
-                                               output_discrete_tensor_fields.values()))
+                output_discrete_tensor_fields.values()))
 
         else:
             input_discrete_fields = None
-- 
GitLab