From 713eff6f85fc90f6631614bf17383519459e3e11 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr>
Date: Fri, 17 Apr 2020 17:40:23 +0200
Subject: [PATCH] working graph gui

---
 hysop/core/graph/computational_graph.py | 14 +++++----
 hysop/core/graph/graph.py               | 40 ++++++++++++++++++++-----
 2 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/hysop/core/graph/computational_graph.py b/hysop/core/graph/computational_graph.py
index cd822959e..13ec3a711 100644
--- a/hysop/core/graph/computational_graph.py
+++ b/hysop/core/graph/computational_graph.py
@@ -728,26 +728,30 @@ class ComputationalGraph(ComputationalGraphNode):
         net.write_html(path)
 
     @graph_built
-    def to_pyvis(self):
+    def to_pyvis(self, width=None, height=None, with_custom_nodes=True):
         """
         Convert the graph to a pyvis network for vizualization.
         """
         try:
-            import pyvis
+            import pyvis, matplotlib
         except ImportError:
-            msg='\nFATAL ERROR: Graph vizualization requires the pyvis module.\n'
+            msg='\nFATAL ERROR: Graph vizualization requires pyvis and matplotlib.\n'
             print(msg)
             raise
+        
+        width  = first_not_None(width,  1920)
+        height = first_not_None(height, 1080)
 
         graph = self.reduced_graph
-        network = pyvis.network.Network(directed=True, width=800, height=600)
+        network = pyvis.network.Network(directed=True, width=width, height=height)
         known_nodes = set()
 
         def add_node(node):
             node_id = int(node)
             if node_id not in known_nodes:
                 network.add_node(node_id, label=node.label, 
-                        title=node.title, color=node.color)
+                        title=node.title, color=node.color,
+                        shape=node.shape(with_custom_nodes))
                 known_nodes.add(node_id)
         
         def add_edge(from_node, to_node):
diff --git a/hysop/core/graph/graph.py b/hysop/core/graph/graph.py
index 409903422..21466bd2d 100644
--- a/hysop/core/graph/graph.py
+++ b/hysop/core/graph/graph.py
@@ -43,12 +43,15 @@ def new_edge(graph, u, v, *args, **kwds):
     return (u,v)
 
 def generate_vertex_colors():
-    import matplotlib
+    try:
+        import matplotlib
+    except ImportError:
+        return None
     from matplotlib import cm
     c0 = cm.get_cmap('tab20c').colors
     c1 = cm.get_cmap('tab20b').colors
     colors = []
-    for i in range(4):
+    for i in (1,2,3,0):
         colors += c0[i::4] + c1[i::4]
     colors = tuple(map(matplotlib.colors.to_hex, colors))
     return colors
@@ -102,11 +105,30 @@ class VertexAttributes(object):
     # pyvis attributes for display
     @property
     def label(self):
-        return '{}'.format(self.operator.pretty_name)
+        s = '{}'.format(self.operator.pretty_name)
+        if (self.op_ordering is not None):
+            s = '({})\n{}'.format(self.op_ordering, s)
+        return s
+    
     @property
     def title(self):
         return self.node_info().replace('\n','<br>')
     
+    def shape(self, with_custom_nodes=True):
+        from hysop.operator.base.transpose_operator    import TransposeOperatorBase
+        from hysop.operator.base.redistribute_operator import RedistributeOperatorBase
+        from hysop.operator.base.memory_reordering     import MemoryReorderingBase
+        special_shapes = {
+                RedistributeOperatorBase: 'box',
+                TransposeOperatorBase:    'box',
+                MemoryReorderingBase:     'box'
+        }
+        if with_custom_nodes:
+            for (op_type, shape) in special_shapes.iteritems():
+                if isinstance(self.operator, op_type):
+                    return shape
+        return 'circle'
+    
     @property
     def color(self):
         cq = self.command_queue
@@ -198,7 +220,8 @@ class EdgeAttributes(object):
         suffix='</b>&nbsp;&nbsp'
         ss = '<h2>Variable dependencies</h2>{}'.format('\n'.join(
                 '{p}{}:{s}{}'.format(v.pretty_name, 
-                    ', '.join(v[t].short_description() for t in self.variables[v]),
+                    ', '.join(v.pretty_name if (t is None) else 
+                                v[t].short_description() for t in self.variables[v]),
                     p=prefix,s=suffix) for v in self.variables))
         return ss.replace('\n','<br>')
 
@@ -346,16 +369,19 @@ def op_apply(f):
             for dfield in sorted(op.input_discrete_fields.values(), key=lambda x: x.name):
                 tag = 'pre_{}_{}'.format(op.name, dfield.name)
                 kwds['debug_dumper'](it, t, tag,
-                                     tuple(df.sdata.get().handle[df.compute_slices] for df in dfield.dfields), description=description)
+                    tuple(df.sdata.get().handle[df.compute_slices] 
+                        for df in dfield.dfields), description=description)
             ret = f(*args, **kwds)
             for param in sorted(op.output_params.values(), key=lambda x: x.name):
                 tag = 'post_{}_{}'.format(op.name, param.name)
                 kwds['debug_dumper'](it, t, tag,
                                      (param._value,), description=description)
-            for dfield in sorted(op.output_discrete_fields.values(), key=lambda x: x.name):
+            for dfield in sorted(op.output_discrete_fields.values(), 
+                                                    key=lambda x: x.name):
                 tag = 'post_{}_{}'.format(op.name, dfield.name)
                 kwds['debug_dumper'](it, t, tag,
-                                     tuple(df.sdata.get().handle[df.compute_slices] for df in dfield.dfields), description=description)
+                            tuple(df.sdata.get().handle[df.compute_slices] 
+                                for df in dfield.dfields), description=description)
             return ret
         elif dbg:
             msg = inspect.getsourcefile(f)
-- 
GitLab