From c2d1d5fc51d2fa68fa2c5cbfbc1a64383b549efd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franck=20P=C3=A9rignon?= <franck.perignon@imag.fr>
Date: Fri, 24 May 2013 12:23:55 +0000
Subject: [PATCH] Add methods to serialize the problem and restart a simulation
 from a file

---
 HySoP/hysop/fields/continuous.py         | 18 +++++++++++---
 HySoP/hysop/fields/discrete.py           | 16 ++++++++++---
 HySoP/hysop/fields/scalar.py             |  3 +++
 HySoP/hysop/fields/vector.py             |  4 ++++
 HySoP/hysop/operator/monitors/printer.py | 12 +++++++---
 HySoP/hysop/problem/problem.py           | 30 ++++++++++++++++++++++++
 HySoP/hysop/problem/simulation.py        | 30 ++++++++++++++++--------
 7 files changed, 94 insertions(+), 19 deletions(-)

diff --git a/HySoP/hysop/fields/continuous.py b/HySoP/hysop/fields/continuous.py
index 0fd1e9c64..427399cf8 100644
--- a/HySoP/hysop/fields/continuous.py
+++ b/HySoP/hysop/fields/continuous.py
@@ -169,20 +169,21 @@ class Field(object):
         """
         self.extraParameters = args
 
-    def dump(self, filename, topo=None):
+    def dump(self, filename, topo=None, mode=None):
         """
         Dump (serialize) the data of the field into filename.
         serialization process.
         @param filename : name of the file in which data are serialized
         @param topo : topology that identify a discrete field to be saved.
+        @param mode : set mode='append' to add data to an existing file.
         if None, the first discrete field in the list will be saved.
         """
         if topo is not None:
             assert topo in self.discreteFields.keys()
-            self.discreteFields[topo].dump(filename)
+            self.discreteFields[topo].dump(filename, mode)
         else:
             # dump all discr or only the first one?
-            self.discreteFields.values()[0].dump(filename)
+            self.discreteFields.values()[0].dump(filename, mode)
 
     def load(self, filename, topo=None):
         """
@@ -199,6 +200,17 @@ class Field(object):
             # dump all discr or only the first one?
             self.discreteFields.values()[0].load(filename)
 
+    def zero(self, topo=None):
+        """
+        reset to 0.0 all components of this field.
+        @param topo : if given, only discreteFields[topo] is
+        set to zero
+         """
+        if topo is not None:
+            self.discreteFields[topo].zero()
+        else:
+            for dfield in self.discreteFields.values():
+                dfield.zero()
 
 if __name__ == "__main__":
     print __doc__
diff --git a/HySoP/hysop/fields/discrete.py b/HySoP/hysop/fields/discrete.py
index 12ff4c1a5..da71fd982 100644
--- a/HySoP/hysop/fields/discrete.py
+++ b/HySoP/hysop/fields/discrete.py
@@ -100,7 +100,6 @@ class DiscreteField(object):
         self.isVector = None
         ## Field data numpy array.
         self.data = None
-        ## An id (optional) used to identify the field for
 
     @abstractmethod
     def __getitem__(self, i):
@@ -118,14 +117,21 @@ class DiscreteField(object):
     def get_data_method(self):
         pass
 
-    def dump(self, filename):
+    def dump(self, filename, mode=None):
         """
         Dump (serialize) the data of the field into filename.
         @param filename : name of the file in which data are serialized
+        @param mode : set mode='append' to add data to an existing file
         """
         filename += '_rk_'
         filename += str(main_rank)
-        db = NumPyDB_cPickle(filename, mode='store')
+        # create a new db
+        if mode is None:
+            db = NumPyDB_cPickle(filename, mode='store')
+        elif mode is 'append':
+            # use an existing db
+            db = NumPyDB_cPickle(filename, mode='load')
+
         #for dim in xrange(self.dimension):
         #    idd = self.name + '_' + str(dim)
         db.dump(self.data, self.name)
@@ -145,6 +151,10 @@ class DiscreteField(object):
         else:
             self.data = db.load(self.name)[0]
 
+    @abstractmethod
+    def zero(self):
+        """ set all components to zero"""
+
 if __name__ == "__main__":
     print __doc__
     print "- Provided class : Domain (abstract)."
diff --git a/HySoP/hysop/fields/scalar.py b/HySoP/hysop/fields/scalar.py
index be92d02e3..6b894baf3 100644
--- a/HySoP/hysop/fields/scalar.py
+++ b/HySoP/hysop/fields/scalar.py
@@ -102,6 +102,9 @@ class ScalarField(DiscreteField):
             s = str(self.data.shape) + "."
         return s + "\n"
 
+    def zero(self):
+        self.data[...] = 0.0
+
 if __name__ == "__main__":
     print __doc__
     print "- Provided class : Scalar"
diff --git a/HySoP/hysop/fields/vector.py b/HySoP/hysop/fields/vector.py
index aa0087f73..70a0bf0bd 100644
--- a/HySoP/hysop/fields/vector.py
+++ b/HySoP/hysop/fields/vector.py
@@ -144,6 +144,10 @@ class VectorField(DiscreteField):
             s = str(self.data[0].shape) + "."
         return s + "\n"
 
+    def zero(self):
+        for dim in xrange(self.dimension):
+            self.data[dim][...] = 0.0
+
 if __name__ == "__main__":
     print __doc__
     print "- Provided class : VectorField"
diff --git a/HySoP/hysop/operator/monitors/printer.py b/HySoP/hysop/operator/monitors/printer.py
index 1ca7e7f75..c0c94b6ce 100644
--- a/HySoP/hysop/operator/monitors/printer.py
+++ b/HySoP/hysop/operator/monitors/printer.py
@@ -16,7 +16,7 @@ class Printer(Monitoring):
     Performs outputs in VTK images.
     """
 
-    def __init__(self, frequency=0, fields=[], prefix='./out_', ext='.vtk'):
+    def __init__(self, frequency=0, fields=[], prefix=None, ext=None):
         """
         Create a results printer for given fields, filename
         prefix (relative path) and an output frequency.
@@ -28,9 +28,15 @@ class Printer(Monitoring):
         """
         Monitoring.__init__(self, fields, frequency)
         ## output file name prefix
-        self.prefix = prefix + str(main_rank)
+        if prefix is None:
+            self.prefix = './out_' + str(main_rank)
+        else:
+            self.prefix = prefix + str(main_rank)
         ## Extension for filename
-        self.ext = ext
+        if ext is None:
+            self.ext = '.vtk'
+        else:
+            self.ext = ext
         ## Method to collect data in case of distributed data
         self.get_data_method = None
         if self.freq != 0:
diff --git a/HySoP/hysop/problem/problem.py b/HySoP/hysop/problem/problem.py
index 758bcbdda..2d7e841ac 100644
--- a/HySoP/hysop/problem/problem.py
+++ b/HySoP/hysop/problem/problem.py
@@ -7,6 +7,8 @@ from parmepy.constants import debug
 from parmepy import __VERBOSE__
 from parmepy.operator.monitors.monitoring import Monitoring
 from parmepy.operator.redistribute import Redistribute
+from parmepy.mpi import main_rank
+from scitools.NumPyDB import NumPyDB_cPickle
 
 
 class Problem(object):
@@ -156,6 +158,34 @@ class Problem(object):
 
         return s
 
+    def dump(self, filename):
+        """
+        Serialize some data of the problem to file
+        (only data required for a proper restart, namely fields in self.input
+        and simulation).
+        @param filename : prefix for output file. Real name = filename_rk_N,
+        N being current process number.
+        """
+        filedump = filename + '_rk_' + str(main_rank)
+        db = NumPyDB_cPickle(filedump, mode='store')
+        db.dump(self.simulation, 'simulation')
+        for v in self.input:
+            v.dump(filename, mode='append')
+
+    def restart(self, filename):
+        """
+        Load serialized data to restart from a previous state.
+        self.input variables and simulation are loaded.
+        @param  filename : prefix for downloaded file.
+        Real name = filename_rk_N, N being current process number.
+        """
+        filedump = filename + '_rk_' + str(main_rank)
+        db = NumPyDB_cPickle(filedump, mode='load')
+        self.simulation = db.load('simulation')[0]
+        self.simulation.reset()
+        for v in self.input:
+            v.load(filename)
+
 if __name__ == "__main__":
     print __doc__
     print "- Provided class : problem"
diff --git a/HySoP/hysop/problem/simulation.py b/HySoP/hysop/problem/simulation.py
index 80b55bb2f..9b3d48483 100644
--- a/HySoP/hysop/problem/simulation.py
+++ b/HySoP/hysop/problem/simulation.py
@@ -11,12 +11,12 @@ class Simulation(object):
     Setup for simulation parameters.
     """
 
-    def __init__(self, tinit=0.0, tend=1.0, nbiter=10):
+    def __init__(self, tinit=0.0, tend=1.0, timeStep=0.01, iterMax=100):
         """
         Creates a Timer.
 
         @param t_end : Simulation final time.
-        @param dt : Time step.
+        @param timeStep : Time step.
         @param t_init : Simulation starting time.
         """
         ## Simulation final time
@@ -29,16 +29,16 @@ class Simulation(object):
         self.isOver = False
         ## Iteration counter
         self.currentIteration = 0
-        ## Number of iteration
-        self.nbiter = nbiter
         ## Simulation time step
-        self.timeStep = (self.end - self.start) / float(self.nbiter)
+        self.timeStep = timeStep
+        ## Maximum number of iterations
+        self.iterMax = iterMax
 
     def advance(self):
         """
         Proceed to next time
         """
-        if self.currentIteration < self.nbiter:
+        if self.currentIteration < self.iterMax and self.time < self.end:
             self.currentIteration += 1
             self.time += self.timeStep
         else:
@@ -52,10 +52,20 @@ class Simulation(object):
             print "==== Iteration : {0:3d}   t={1:6.3f} ====".format(
                 self.currentIteration, self.time)
 
+    def reset(self):
+        """
+        set initial time to current and reset iteration counter.
+        Used to initialize simulation after a restart.
+        """
+        self.start = self.time
+        self.isOver = False
+        self.currentIteration = 0
+
     def __str__(self):
         s = "Simulation parameters : "
-        s += "start from " + str(self.start) + ' to ' + str(self.end)
-        s += ', time step = ' + str(self.timeStep)
-        s += ', current time : ' + str(self.time) + ' iteration number :'
-        s += str(self.currentIteration)
+        s += "from " + str(self.start) + ' to ' + str(self.end)
+        s += ', time step : ' + str(self.timeStep)
+        s += ', current time : ' + str(self.time) + ', iteration number : '
+        s += str(self.currentIteration) + ', max number of iterations : '
+        s += str(self.iterMax)
         return s
-- 
GitLab