From c79617fbf49c06b35de8a6aa95187f91b20e5c30 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Keck <jean-baptiste.keck@imag.fr>
Date: Tue, 13 Mar 2018 23:17:22 +0100
Subject: [PATCH] dump times implementation

---
 examples/example_utils.py           |  12 ++-
 examples/shear_layer/shear_layer.py |  10 ++-
 hysop/operator/hdf_io.py            |  20 +++--
 hysop/simulation.py                 | 112 +++++++++++++++++++++-------
 4 files changed, 111 insertions(+), 43 deletions(-)

diff --git a/examples/example_utils.py b/examples/example_utils.py
index 9075a5d95..6467ce32a 100644
--- a/examples/example_utils.py
+++ b/examples/example_utils.py
@@ -521,8 +521,8 @@ class HysopArgParser(argparse.ArgumentParser):
                 dest='dump_freq',
                 help=('HDF5 output frequency in terms of iterations (default=10).' 
                      +' Use 0 to disable frequency based dumping.'))
-        file_io.add_argument('--dump-times', type=str, default=None, 
-                action=self.split, container=tuple, append=True,
+        file_io.add_argument('--dump-times', type=str, default=None, convert=float,
+                action=self.split, container=tuple, append=False,
                 dest='dump_times',
                 help='Comma delimited list of additional HDF5 output times of interest.')
         file_io.add_argument('--cache-dir', type=str, default=None, 
@@ -552,13 +552,19 @@ class HysopArgParser(argparse.ArgumentParser):
             msg=msg.format(args.tstart, args.tend)
             self.error(msg)
 
+        args.dump_times = set(args.dump_times)
+
         msg='args.dump_times = {}\n'.format(args.dump_times)
         for dp in args.dump_times:
             if (dp<args.tstart):
                 msg+='Dump time of interest t={} happens before tstart={}.'
                 msg=msg.format(dp, args.tstart)
                 self.error(msg)
-            elif (dp>args.tend):
+            elif (dp==args.tend):
+                msg+='Dump time of interest t={} happens exactly at tend={}.'
+                msg=msg.format(dp, args.tend)
+                self.error(msg)
+            elif (dp>=args.tend):
                 msg+='Dump time of interest t={} happens after tend={}.'
                 msg=msg.format(dp, args.tend)
                 self.error(msg)
diff --git a/examples/shear_layer/shear_layer.py b/examples/shear_layer/shear_layer.py
index 8139c3d16..374fcb2ff 100644
--- a/examples/shear_layer/shear_layer.py
+++ b/examples/shear_layer/shear_layer.py
@@ -147,8 +147,8 @@ def compute(args):
     simu = Simulation(start=args.tstart, end=args.tend, 
                       nb_iter=args.nb_iter,
                       max_iter=args.max_iter,
-                      dt0=args.dt0, 
-                        t=t, dt=dt)
+                      dt0=args.dt0, times_of_interest=args.dump_times,
+                      t=t, dt=dt)
     simu.write_parameters(t, dt, filename='parameters.txt', precision=4)
     
     # Initialize only the vorticity
@@ -250,9 +250,11 @@ if __name__=='__main__':
 
     parser.set_defaults(impl='cl', ndim=2, npts=(257,),
                         box_origin=(0.0,), box_length=(1.0,), 
-                        tstart=0.0, tend=1.20, 
+                        tstart=0.0, tend=1.25, 
                         nb_iter=None, dt=None, 
                         dt0=1e-4, cfl=0.5, lcfl=0.125,
-                        case=1, dump_times=(0.8, 1.20))
+                        case=1, 
+                        dump_freq=0,
+                        dump_times=(0.8, 1.20))
 
     parser.run(compute)
diff --git a/hysop/operator/hdf_io.py b/hysop/operator/hdf_io.py
index b2534545a..992cd733b 100755
--- a/hysop/operator/hdf_io.py
+++ b/hysop/operator/hdf_io.py
@@ -258,14 +258,12 @@ class HDF_Writer(HDF_IO):
     def apply(self, simulation=None, **kwds):
         if simulation is None:
             raise ValueError("Missing simulation value for monitoring.")
-        if (self.io_params.frequency>0): 
-            ite = simulation.current_iteration
-            should_dump = (ite == -1)
-            should_dump |= (ite % self.io_params.frequency == 0)
-            should_dump |= simulation._next_is_last
-            if should_dump:
-                self.step(simulation)
-                self._count += 1
+        ite = simulation.current_iteration
+        should_dump  = (self.io_params.frequency>0) and (ite % self.io_params.frequency == 0)
+        should_dump |= simulation.is_time_of_interest
+        if should_dump:
+            self.step(simulation)
+            self._count += 1
 
     def _setup_grid_template(self):
         topo = self.topology
@@ -360,14 +358,14 @@ class HDF_Writer(HDF_IO):
         
         # Collect datas required to write the xdmf file
         # --> add tuples (counter, time).
-        if (simu.time == self._last_written_time):
+        if (simu.t() == self._last_written_time):
             msg = 'You cannot write two hdf files for the same '
             msg += '(time, var) set. '
             msg += 'If you want to save a field two times for '
             msg += 'a single time value, please use two hdf_writer operators.'
             raise RuntimeError(msg)
-        self._xdmf_data_files.append((self._count, simu.time))
-        self._last_written_time = simu.time
+        self._xdmf_data_files.append((self._count, simu.t()))
+        self._last_written_time = simu.t()
 
         self._hdf_file.close()
 
diff --git a/hysop/simulation.py b/hysop/simulation.py
index d163cc22b..4f62c4d06 100644
--- a/hysop/simulation.py
+++ b/hysop/simulation.py
@@ -33,7 +33,7 @@ from hysop import dprint, vprint
 from hysop.deps import sys, os
 from hysop.constants import HYSOP_REAL
 from hysop.parameters.scalar_parameter import ScalarParameter
-from hysop.tools.types import first_not_None
+from hysop.tools.types import first_not_None, to_set
 from hysop.tools.numpywrappers import npw
 from hysop.tools.io_utils import IO, IOParams
 
@@ -46,7 +46,8 @@ class Simulation(object):
     """
 
     def __init__(self, name=None, start=0.0, end=1.0, nb_iter=None, dt0=None,
-                 max_iter=None, t=None, dt=None, **kwds):
+                 max_iter=None, t=None, dt=None, times_of_interest=None,
+                 **kwds):
         """
         Parameters
         ----------
@@ -67,6 +68,13 @@ class Simulation(object):
         max_iter : int, optional
             Maximum number of iterations allowed.
             Defaults to 1e9.
+        times_of_interest: array-like of float
+            List of times ti where the simulation may
+            modify current timestep to get t=ti.
+            Mainly used by HDF_Writers for precise 
+            time dependent dumping.
+            tstart < ti <= tend
+            Defaults to empty set.
 
         Attributes
         ----------
@@ -82,6 +90,8 @@ class Simulation(object):
             The scalar parameter that represents time.
         time_step: double
             Value of the dt parameter.
+        time_of_interest: float
+            Current simulation time target.
 
         Notes
         -----
@@ -96,17 +106,12 @@ class Simulation(object):
           with self.tkp1 == self.time
 
         """
-        # Simulation final time
         self.end = end
-        # Starting time
         self.start = start
-        # Simulation current time
         self.time = start
-        # Is simulation is terminated
         self.is_over = False
-        # Iteration counter
         self.current_iteration = -1
-        # Number of iterations
+
         if (nb_iter is not None):
             self.nb_iter = nb_iter
             msg = '------------------------------------------------\n'
@@ -150,19 +155,29 @@ class Simulation(object):
             assert isinstance(t, ScalarParameter), type(t)
             assert not t.const, 't cannot be a constant parameter.'
             t.value = start
-        # t + dt
+        
+        # tk+1 = t + dt
         self.tkp1 = start + self.time_step
-        assert self.end > self.start, \
+
+        assert (self.end > self.start), \
             'Final time must be greater than initial time'
         assert (self.start + self.time_step) <= self.end,\
             'start + step is bigger than end.'
 
+        # times of interest
+        times_of_interest = to_set(first_not_None(times_of_interest, []))
+        times_of_interest = tuple(sorted(times_of_interest))
+        for toi in times_of_interest:
+            assert self.start <= toi < self.end, toi
+        self.times_of_interest = times_of_interest
+
         # Internal tolerance for timer
         self.tol = eps
         # True if initialize has been called.
-        self._is_ready = False
-        self._next_is_last = False
-        self._parameters_to_write = []
+        self._is_ready                 = False
+        self._next_is_last             = False
+        self._next_is_time_of_interest = False
+        self._parameters_to_write      = []
 
     def _get_time_step(self):
         """Get current timestep."""
@@ -201,26 +216,57 @@ class Simulation(object):
             return
 
         self.t.set_value(self.tkp1)
+        
+        self.is_time_of_interest = False
+        if (self.target_time_of_interest is not None):
+            if (abs(self.tkp1 - self.target_time_of_interest) <= self.tol):
+                self.next_time_of_interest()
+                self.is_time_of_interest = True
+        
         self.tkp1 = self.t() + self.time_step
         if abs(self.tkp1 - self.end) <= self.tol:
             self._next_is_last = True
-        elif self.tkp1 > self.end:
-            # resize ...
-            dprint('> Next iteration is last iteration, clamping dt.')
+        elif (self.tkp1 > self.end):
+            msg='** Next iteration is last iteration, clamping dt to achieve t={}. **'
+            msg=msg.format(self.end)
+            vprint()
+            self._print_banner(msg)
             self._next_is_last = True
             self.tkp1 = self.end
             self.update_time_step(self.end - self.t())
+        elif (self.target_time_of_interest is not None) and \
+                (self.tkp1 > self.target_time_of_interest):
+            msg='** Next iteration is a time of interest, clamping dt to achieve t={}. **'
+            msg=msg.format(self.target_time_of_interest)
+            vprint()
+            self._print_banner(msg)
+            self.tkp1 = self.target_time_of_interest
+            self.update_time_step(self.target_time_of_interest - self.t())
+
         self.current_iteration += 1
         self.time = self.tkp1
 
-        if self.current_iteration + 2 > self.max_iter:
+        if (self.current_iteration + 2 > self.max_iter):
             msg = '** Next iteration will be the last because max_iter={} will be achieved. **'
             msg=msg.format(self.max_iter)
             vprint()
-            vprint('*'*len(msg))
-            vprint(msg)
-            vprint('*'*len(msg))
+            self._print_banner(msg)
             self._next_is_last = True
+            self.is_time_of_interest = True
+
+    def _print_banner(self, msg):
+        vprint('*'*len(msg))
+        vprint(msg)
+        vprint('*'*len(msg))
+
+    def next_time_of_interest(self):
+        toi_counter       = self.toi_counter
+        times_of_interest = self.times_of_interest
+        if (toi_counter<len(times_of_interest)):
+            self.target_time_of_interest = times_of_interest[toi_counter]
+            self.toi_counter += 1
+        else:
+            self.target_time_of_interest = None
 
     def update_time_step(self, dt):
         """Update time step for the next iteration
@@ -243,11 +289,27 @@ class Simulation(object):
         """(Re)set simulation to initial values
         --> back to iteration 0 and ready to run.
         """
-        self.t.set_value(self.start)
-        self.update_time_step(self._dt0)
-        self.tkp1 = self.start + self.time_step
-        assert self.tkp1 <= self.end
-        if abs(self._dt0 - self.end) <= self.tol:
+        tstart, tend = self.start, self.end
+        times_of_interest = self.times_of_interest
+
+        self.toi_counter = 0
+        self.next_time_of_interest()
+        self.is_time_of_interest = False
+
+        assert ((tend - tstart) >= self.tol)
+        assert (tend >= self.target_time_of_interest >= tstart)
+        if abs(self.target_time_of_interest - tstart) <= self.tol:
+            self.next_time_of_interest()
+            self.is_time_of_interest = True
+        assert (tend >= self.target_time_of_interest > tstart)
+
+        dt0 = min(self._dt0, self.target_time_of_interest-tstart)
+        self.t.set_value(tstart)
+        self.update_time_step(dt0)
+        self.tkp1 = tstart + self.time_step
+        assert self.tkp1 < self.target_time_of_interest <= tend
+
+        if abs(self.tkp1 - self.end) <= self.tol:
             self._next_is_last = True
         else:
             self._next_is_last = False
-- 
GitLab