From 713eff6f85fc90f6631614bf17383519459e3e11 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr> Date: Fri, 17 Apr 2020 17:40:23 +0200 Subject: [PATCH] working graph gui --- hysop/core/graph/computational_graph.py | 14 +++++---- hysop/core/graph/graph.py | 40 ++++++++++++++++++++----- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/hysop/core/graph/computational_graph.py b/hysop/core/graph/computational_graph.py index cd822959e..13ec3a711 100644 --- a/hysop/core/graph/computational_graph.py +++ b/hysop/core/graph/computational_graph.py @@ -728,26 +728,30 @@ class ComputationalGraph(ComputationalGraphNode): net.write_html(path) @graph_built - def to_pyvis(self): + def to_pyvis(self, width=None, height=None, with_custom_nodes=True): """ Convert the graph to a pyvis network for vizualization. """ try: - import pyvis + import pyvis, matplotlib except ImportError: - msg='\nFATAL ERROR: Graph vizualization requires the pyvis module.\n' + msg='\nFATAL ERROR: Graph vizualization requires pyvis and matplotlib.\n' print(msg) raise + + width = first_not_None(width, 1920) + height = first_not_None(height, 1080) graph = self.reduced_graph - network = pyvis.network.Network(directed=True, width=800, height=600) + network = pyvis.network.Network(directed=True, width=width, height=height) known_nodes = set() def add_node(node): node_id = int(node) if node_id not in known_nodes: network.add_node(node_id, label=node.label, - title=node.title, color=node.color) + title=node.title, color=node.color, + shape=node.shape(with_custom_nodes)) known_nodes.add(node_id) def add_edge(from_node, to_node): diff --git a/hysop/core/graph/graph.py b/hysop/core/graph/graph.py index 409903422..21466bd2d 100644 --- a/hysop/core/graph/graph.py +++ b/hysop/core/graph/graph.py @@ -43,12 +43,15 @@ def new_edge(graph, u, v, *args, **kwds): return (u,v) def generate_vertex_colors(): - import matplotlib + try: + import matplotlib + except ImportError: + return None from matplotlib import cm c0 = cm.get_cmap('tab20c').colors c1 = cm.get_cmap('tab20b').colors colors = [] - for i in range(4): + for i in (1,2,3,0): colors += c0[i::4] + c1[i::4] colors = tuple(map(matplotlib.colors.to_hex, colors)) return colors @@ -102,11 +105,30 @@ class VertexAttributes(object): # pyvis attributes for display @property def label(self): - return '{}'.format(self.operator.pretty_name) + s = '{}'.format(self.operator.pretty_name) + if (self.op_ordering is not None): + s = '({})\n{}'.format(self.op_ordering, s) + return s + @property def title(self): return self.node_info().replace('\n','<br>') + def shape(self, with_custom_nodes=True): + from hysop.operator.base.transpose_operator import TransposeOperatorBase + from hysop.operator.base.redistribute_operator import RedistributeOperatorBase + from hysop.operator.base.memory_reordering import MemoryReorderingBase + special_shapes = { + RedistributeOperatorBase: 'box', + TransposeOperatorBase: 'box', + MemoryReorderingBase: 'box' + } + if with_custom_nodes: + for (op_type, shape) in special_shapes.iteritems(): + if isinstance(self.operator, op_type): + return shape + return 'circle' + @property def color(self): cq = self.command_queue @@ -198,7 +220,8 @@ class EdgeAttributes(object): suffix='</b>  ' ss = '<h2>Variable dependencies</h2>{}'.format('\n'.join( '{p}{}:{s}{}'.format(v.pretty_name, - ', '.join(v[t].short_description() for t in self.variables[v]), + ', '.join(v.pretty_name if (t is None) else + v[t].short_description() for t in self.variables[v]), p=prefix,s=suffix) for v in self.variables)) return ss.replace('\n','<br>') @@ -346,16 +369,19 @@ def op_apply(f): for dfield in sorted(op.input_discrete_fields.values(), key=lambda x: x.name): tag = 'pre_{}_{}'.format(op.name, dfield.name) kwds['debug_dumper'](it, t, tag, - tuple(df.sdata.get().handle[df.compute_slices] for df in dfield.dfields), description=description) + tuple(df.sdata.get().handle[df.compute_slices] + for df in dfield.dfields), description=description) ret = f(*args, **kwds) for param in sorted(op.output_params.values(), key=lambda x: x.name): tag = 'post_{}_{}'.format(op.name, param.name) kwds['debug_dumper'](it, t, tag, (param._value,), description=description) - for dfield in sorted(op.output_discrete_fields.values(), key=lambda x: x.name): + for dfield in sorted(op.output_discrete_fields.values(), + key=lambda x: x.name): tag = 'post_{}_{}'.format(op.name, dfield.name) kwds['debug_dumper'](it, t, tag, - tuple(df.sdata.get().handle[df.compute_slices] for df in dfield.dfields), description=description) + tuple(df.sdata.get().handle[df.compute_slices] + for df in dfield.dfields), description=description) return ret elif dbg: msg = inspect.getsourcefile(f) -- GitLab