Skip to content
Snippets Groups Projects
Commit c42abbc8 authored by EXT Jean-Matthieu Etancelin's avatar EXT Jean-Matthieu Etancelin
Browse files

Fix missing matplotlib import in parameter plotter

parent 69fb568e
No related branches found
No related tags found
1 merge request!16MPI operators
Pipeline #24706 failed
......@@ -17,18 +17,18 @@ class PlottingOperator(ComputationalGraphOperator):
def __init__(self, name=None,
dump_dir=None,
update_frequency=1,
save_frequency=100,
axes_shape=(1,),
update_frequency=1,
save_frequency=100,
axes_shape=(1,),
figsize=(30,18),
visu_rank=0,
fig=None,
fig=None,
axes=None,
**kwds):
import matplotlib
import matplotlib.pyplot as plt
check_instance(name, str)
check_instance(update_frequency, int, minval=0)
check_instance(save_frequency, int, minval=0)
......@@ -39,15 +39,15 @@ class PlottingOperator(ComputationalGraphOperator):
if (fig is None) ^ (axes is None):
msg='figure and axes should be specified at the same time.'
raise RuntimeError(msg)
dump_dir = first_not_None(dump_dir, IO.default_path())
imgpath = '{}/{}.png'.format(dump_dir, name)
if (fig is None):
fig, axes = plt.subplots(*axes_shape, figsize=figsize)
fig.canvas.mpl_connect('key_press_event', self.on_key_press)
fig.canvas.mpl_connect('close_event', self.on_close)
axes = npw.asarray(axes).reshape(axes_shape)
self.fig = fig
......@@ -79,7 +79,7 @@ class PlottingOperator(ComputationalGraphOperator):
def _save(self, simulation, **kwds):
if simulation.should_dump(frequency=self.save_frequency, with_last=True):
self.save(simulation=simulation, **kwds)
@abstractmethod
def update(self, **kwds):
pass
......@@ -87,7 +87,7 @@ class PlottingOperator(ComputationalGraphOperator):
def save(self, **kwds):
self.fig.savefig(self.imgpath, dpi=self.fig.dpi,
bbox_inches='tight')
def on_close(self, event):
self.running = False
......@@ -100,14 +100,15 @@ class PlottingOperator(ComputationalGraphOperator):
class ParameterPlotter(PlottingOperator):
"""
Base operator to plot parameters during runtime.
Base operator to plot parameters during runtime.
"""
def __init__(self, name, parameters, alloc_size=128,
def __init__(self, name, parameters, alloc_size=128,
fig=None, axes=None, shape=None, **kwds):
input_params = {}
if (fig is not None) and (axes is not None):
import matplotlib
custom_axes = True
axes_shape=None
check_instance(parameters, dict, keys=matplotlib.axes.Axes, values=dict)
......@@ -127,7 +128,7 @@ class ParameterPlotter(PlottingOperator):
else:
raise TypeError(type(parameters))
check_instance(_parameters, dict, keys=(int,tuple,list), values=(TensorParameter,list,tuple,dict))
parameters = {}
axes_shape = (1,)*2
for (pos,params) in _parameters.iteritems():
......@@ -158,7 +159,7 @@ class ParameterPlotter(PlottingOperator):
_params[_pname] = _p
parameters[pos] = _params
super(ParameterPlotter, self).__init__(name=name, input_params=input_params,
super(ParameterPlotter, self).__init__(name=name, input_params=input_params,
axes_shape=axes_shape, axes=axes, fig=fig, **kwds)
self.custom_axes = custom_axes
......@@ -198,7 +199,7 @@ class ParameterPlotter(PlottingOperator):
return self.axes[i]
else:
return self.axes.flatten()[i]
def update(self, simulation, **kwds):
# expand memory if required
if (self.counter+1>self.times.size):
......@@ -219,4 +220,3 @@ class ParameterPlotter(PlottingOperator):
lines[pos][p].set_xdata(times[:self.counter])
lines[pos][p].set_ydata(data[pos][p][:self.counter])
self.counter += 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment