From 5a76bfa4ffe74e713fd63195bfbc6d40085f5fc4 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr>
Date: Fri, 22 Feb 2019 22:21:14 +0100
Subject: [PATCH] common method for dump

---
 hysop/numerics/fft/opencl_fft.py    |  4 ++++
 hysop/operator/hdf_io.py            |  5 +----
 hysop/operator/mean_field.py        |  5 +----
 hysop/operator/parameter_plotter.py | 21 ++++++---------------
 hysop/simulation.py                 |  8 ++++++++
 5 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/hysop/numerics/fft/opencl_fft.py b/hysop/numerics/fft/opencl_fft.py
index b67149686..a064ea929 100644
--- a/hysop/numerics/fft/opencl_fft.py
+++ b/hysop/numerics/fft/opencl_fft.py
@@ -141,6 +141,10 @@ class OpenClFFTI(FFTI):
             launcher += lnc
         return launcher
     
+    def plan_compute_energy(self, tg, src, dst, transforms, 
+            method='round', target=None):
+        raise NotImplementedError
+    
     @classmethod
     def ensure_buffer(cls, get_buffer):
         if callable(get_buffer):
diff --git a/hysop/operator/hdf_io.py b/hysop/operator/hdf_io.py
index dc6b59266..b65eb3f00 100755
--- a/hysop/operator/hdf_io.py
+++ b/hysop/operator/hdf_io.py
@@ -362,10 +362,7 @@ class HDF_Writer(HDF_IO):
     def apply(self, simulation=None, **kwds):
         if (simulation is None):
             raise ValueError("Missing simulation value for monitoring.")
-        ite = simulation.current_iteration
-        should_dump  = (self.io_params.frequency>0) and (ite % self.io_params.frequency == 0)
-        should_dump |= simulation.is_time_of_interest
-        if should_dump:
+        if simulation.should_dump(frequency=self.io_params.frequency):
             if (self._xmf_file is None):
                 self.createXMFFile()
             self.step(simulation)
diff --git a/hysop/operator/mean_field.py b/hysop/operator/mean_field.py
index ad203012e..f4c2cfc80 100644
--- a/hysop/operator/mean_field.py
+++ b/hysop/operator/mean_field.py
@@ -115,10 +115,7 @@ class ComputeMeanField(ComputationalGraphOperator):
     def apply(self, simulation, **kwds):
         if (simulation is None):
             raise ValueError("Missing simulation value for monitoring.")
-        ite = simulation.current_iteration
-        should_dump  = (self.io_params.frequency>0) and (ite % self.io_params.frequency == 0)
-        should_dump |= simulation.is_time_of_interest
-        if should_dump:
+        if simulation.should_dump(frequency=self.io_params.frequency):
             for (dfield, (view, axes)) in self.averaged_dfields.iteritems():
                 filename = self.filename(dfield, self.write_counter)
                 arrays = {}
diff --git a/hysop/operator/parameter_plotter.py b/hysop/operator/parameter_plotter.py
index 0a4b1ee4d..8a1abda64 100644
--- a/hysop/operator/parameter_plotter.py
+++ b/hysop/operator/parameter_plotter.py
@@ -63,23 +63,14 @@ class PlottingOperator(ComputationalGraphOperator):
         self._save(**kwds)
 
     def _update(self, simulation, **kwds):
-        if (self.update_frequency == 0):
-            return
-        self.update(simulation=simulation, **kwds)
-        if not simulation._next_is_last:
-            if (simulation.current_iteration>1) and \
-                ((simulation.current_iteration % self.update_frequency) != 0):
-                return
-        if self.should_draw:
-            self.draw()
+        if simulation.should_dump(frequency=self.update_frequency, with_last=True):
+            self.update(simulation=simulation, **kwds)
+            if self.should_draw:
+                self.draw()
 
     def _save(self, simulation, **kwds):
-        if not simulation._next_is_last:
-            if (self.save_frequency == 0): 
-                return
-            if ((simulation.current_iteration % self.save_frequency) != 0):
-                return
-        self.save(simulation=simulation, **kwds)
+        if simulation.should_dump(frequency=self.save_frequency, with_last=True):
+            self.save(simulation=simulation, **kwds)
     
     @abstractmethod
     def update(self, **kwds):
diff --git a/hysop/simulation.py b/hysop/simulation.py
index ba3f97397..0a3183ab8 100644
--- a/hysop/simulation.py
+++ b/hysop/simulation.py
@@ -396,4 +396,12 @@ class Simulation(object):
         s += str(self.current_iteration) + ', max number of iterations : '
         s += str(self.max_iter)
         return s
+    
+    def should_dump(self, frequency, with_last=False):
+        dump = (with_last and self._next_is_last)
+        if (frequency >= 0):
+            dump |= self.is_time_of_interest
+        if (frequency > 0):
+            dump |= ((self.current_iteration % frequency) == 0)
+        return dump
 
-- 
GitLab