From 2aa26b8f9bcac9dee243524ec111ef1ab5c4c6df Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr> Date: Mon, 27 Jul 2020 01:16:59 +0200 Subject: [PATCH] refactored checkpoints --- hysop/core/checkpoints.py | 780 ++++++++++++++++++ hysop/core/tests/test_checkpoint.sh | 6 +- hysop/problem.py | 728 +--------------- hysop/simulation.py | 6 +- hysop_examples/example_utils.py | 60 +- hysop_examples/examples/analytic/analytic.py | 14 +- .../examples/bubble/periodic_bubble.py | 17 +- .../bubble/periodic_bubble_levelset.py | 17 +- .../periodic_bubble_levelset_penalization.py | 17 +- .../examples/bubble/periodic_jet_levelset.py | 17 +- .../examples/cylinder/oscillating_cylinder.py | 17 +- .../examples/fixed_point/heat_equation.py | 3 +- .../flow_around_sphere/flow_around_sphere.py | 15 +- .../multiresolution/scalar_advection.py | 19 +- .../particles_above_salt_bc.py | 18 +- .../particles_above_salt_bc_3d.py | 18 +- .../particles_above_salt_periodic.py | 18 +- .../particles_above_salt_symmetrized.py | 18 +- .../examples/scalar_advection/levelset.py | 20 +- .../scalar_advection/scalar_advection.py | 21 +- .../scalar_diffusion/scalar_diffusion.py | 4 +- .../sediment_deposit/sediment_deposit.py | 18 +- .../sediment_deposit_levelset.py | 18 +- .../examples/shear_layer/shear_layer.py | 6 +- .../examples/taylor_green/taylor_green.py | 4 +- .../taylor_green/taylor_green_cpuFortran.py | 5 +- 26 files changed, 1103 insertions(+), 781 deletions(-) create mode 100644 hysop/core/checkpoints.py diff --git a/hysop/core/checkpoints.py b/hysop/core/checkpoints.py new file mode 100644 index 000000000..518365599 --- /dev/null +++ b/hysop/core/checkpoints.py @@ -0,0 +1,780 @@ +import functools, shutil, operator, os, warnings, shutil, tarfile, uuid +import numpy as np +from hysop.tools.types import check_instance, first_not_None, to_tuple, to_list +from hysop.tools.units import bytes2str, time2str +from hysop.tools.io_utils import IOParams +from hysop.tools.numerics import default_invalid_value +from hysop.tools.string_utils import vprint_banner, vprint +from hysop.core.mpi import Wtime +from hysop.domain.box import Box +from hysop.parameters import ScalarParameter, TensorParameter, BufferParameter +from hysop.fields.cartesian_discrete_field import CartesianDiscreteScalarField + +class CheckpointHandler(object): + def __init__(self, load_checkpoint_path, save_checkpoint_path, + io_params, relax_constraints): + check_instance(load_checkpoint_path, str, allow_none=True) + check_instance(save_checkpoint_path, str, allow_none=True) + check_instance(io_params, IOParams, allow_none=True) + check_instance(relax_constraints, bool) + + self._load_checkpoint_path = load_checkpoint_path + self._save_checkpoint_path = save_checkpoint_path + self._io_params = io_params + self._relax_constraints = relax_constraints + + self._checkpoint_template = None + self._checkpoint_compressor = None + + @property + def load_checkpoint_path(self): + return self._load_checkpoint_path + @property + def save_checkpoint_path(self): + return self._save_checkpoint_path + @property + def io_params(self): + return self._io_params + @property + def relax_constraints(self): + return self._relax_constraints + + def get_mpio_parameters(self, mpi_params): + io_params = self.io_params + comm = mpi_params.comm + io_leader = io_params.io_leader + is_io_leader = (io_leader == mpi_params.rank) + return (io_params, mpi_params, comm, io_leader, is_io_leader) + + def is_io_leader(self, mpi_params): + return (self.io_params.io_leader == 
+                mpi_params.rank)
+
+    def finalize(self, mpi_params):
+        if ((self._checkpoint_template is not None)
+              and os.path.exists(self._checkpoint_template)
+              and self.is_io_leader(mpi_params)):
+            try:
+                shutil.rmtree(self._checkpoint_template)
+            except OSError:
+                pass
+        self._checkpoint_template = None
+        self._checkpoint_compressor = None
+
+    def load_checkpoint(self, problem, simulation):
+        from hysop.problem import Problem
+        from hysop.simulation import Simulation
+        check_instance(problem, Problem)
+        check_instance(simulation, Simulation)
+
+        load_checkpoint_path = self.load_checkpoint_path
+        if (load_checkpoint_path is None):
+            return
+
+        vprint('\n>Loading {}problem checkpoint from \'{}\'...'.format(
+            'relaxed ' if self.relax_constraints else '', load_checkpoint_path))
+        if not os.path.exists(load_checkpoint_path):
+            msg='Failed to load checkpoint \'{}\' because the file does not exist.'
+            raise RuntimeError(msg.format(load_checkpoint_path))
+        if (self.io_params is None):
+            msg='Load checkpoint has been set to \'{}\' but checkpoint io_params has not been specified.'
+            raise RuntimeError(msg.format(load_checkpoint_path))
+
+        (io_params, mpi_params, comm, io_leader, is_io_leader) = self.get_mpio_parameters(problem.mpi_params)
+        start = Wtime()
+
+        # extract checkpoint to directory if required
+        if os.path.isfile(load_checkpoint_path):
+            if load_checkpoint_path.endswith('.tar'):
+                if is_io_leader:
+                    load_checkpoint_dir = os.path.join(os.path.dirname(load_checkpoint_path),
+                                                       os.path.basename(load_checkpoint_path).replace('.tar', ''))
+                    while os.path.exists(load_checkpoint_dir):
+                        # ok, use another directory name to avoid data loss...
+                        load_checkpoint_dir = os.path.join(os.path.dirname(load_checkpoint_path),
+                                                           '{}'.format(uuid.uuid4().hex))
+                    tf = tarfile.open(load_checkpoint_path, mode='r')
+                    tf.extractall(path=load_checkpoint_dir)
+                else:
+                    load_checkpoint_dir = None
+                load_checkpoint_dir = comm.bcast(load_checkpoint_dir, root=io_leader)
+                should_remove_dir = True
+            else:
+                msg='Can only load checkpoint with tar extension, got {}.'
+                raise NotImplementedError(msg.format(load_checkpoint_path))
+        elif os.path.isdir(load_checkpoint_path):
+            load_checkpoint_dir = load_checkpoint_path
+            should_remove_dir = False
+        else:
+            msg='Cannot load checkpoint \'{}\' because it is neither a file nor a directory.'
+            raise RuntimeError(msg.format(load_checkpoint_path))
+
+        # import checkpoint data
+        self._import_checkpoint(problem, simulation, load_checkpoint_dir)
+
+        if (is_io_leader and should_remove_dir):
+            shutil.rmtree(load_checkpoint_dir)
+
+        elapsed = Wtime() - start
+        msg=' > Successfully imported checkpoint in {}.'
+        vprint(msg.format(time2str(elapsed)))
+
+
+    # Checkpoint is first exported as a directory containing a hierarchy of arrays (field and parameter data + metadata).
+    # This folder is then tarred (without any form of compression) so that a checkpoint consists of a single movable file.
+    # Data is already compressed during data export by the zarr module, using the blosc compressor (snappy, clevel=3).
+    def save_checkpoint(self, problem, simulation):
+        save_checkpoint_path = self.save_checkpoint_path
+        if (save_checkpoint_path is None):
+            return
+
+        if (self.io_params is None):
+            msg='Save checkpoint has been set to \'{}\' but checkpoint io_params has not been specified.'
+            raise RuntimeError(msg.format(save_checkpoint_path))
+
+        if not self.io_params.should_dump(simulation):
+            return
+
+        vprint('>Exporting problem checkpoint to \'{}\':'.format(save_checkpoint_path))
+        if not save_checkpoint_path.endswith('.tar'):
+            msg='Can only export checkpoint with tar extension, got {}.'
+            raise NotImplementedError(msg.format(save_checkpoint_path))
+        save_checkpoint_tar = save_checkpoint_path
+
+        (io_params, mpi_params, comm, io_leader, is_io_leader) = self.get_mpio_parameters(problem.mpi_params)
+        start = Wtime()
+
+        # create a backup of the last checkpoint just in case things go wrong
+        if is_io_leader and os.path.exists(save_checkpoint_tar):
+            backup_checkpoint_tar = save_checkpoint_tar + '.bak'
+            if os.path.exists(backup_checkpoint_tar):
+                os.remove(backup_checkpoint_tar)
+            os.rename(save_checkpoint_tar, backup_checkpoint_tar)
+        else:
+            backup_checkpoint_tar = None
+
+        # determine checkpoint dump directory
+        if is_io_leader:
+            save_checkpoint_dir = os.path.join(os.path.dirname(save_checkpoint_tar),
+                                               os.path.basename(save_checkpoint_tar).replace('.tar', ''))
+            while os.path.exists(save_checkpoint_dir):
+                # ok, use another directory name to avoid data loss...
+                save_checkpoint_dir = os.path.join(os.path.dirname(save_checkpoint_tar),
+                                                   '{}'.format(uuid.uuid4().hex))
+        else:
+            save_checkpoint_dir = None
+        save_checkpoint_dir = mpi_params.comm.bcast(save_checkpoint_dir, root=io_leader)
+
+        # try to create the checkpoint directory, this is a collective MPI operation
+        try:
+            success, reason, nbytes = self._export_checkpoint(problem, simulation, save_checkpoint_dir)
+        except Exception as e:
+            success = False
+            reason = str(e)
+        success = comm.allreduce(int(success)) == comm.size
+
+        # Compress checkpoint directory to tar (easier to copy/move between clusters)
+        # Note that there is no effective compression here, zarr already compressed field/param data
+        if success and is_io_leader and os.path.isdir(save_checkpoint_dir):
+            try:
+                with tarfile.open(save_checkpoint_tar, 'w') as tf:
+                    for (root, dirs, files) in os.walk(save_checkpoint_dir):
+                        for f in files:
+                            fpath = os.path.join(root, f)
+                            tf.add(fpath, arcname=fpath.replace(save_checkpoint_dir+os.path.sep,''))
+
+                if os.path.isfile(save_checkpoint_tar):
+                    shutil.rmtree(save_checkpoint_dir)
+                else:
+                    raise RuntimeError('Could not tar checkpoint datadir.')
+
+                elapsed = Wtime() - start
+                effective_nbytes = os.path.getsize(save_checkpoint_tar)
+                compression_ratio = max(1.0, float(nbytes)/effective_nbytes)
+
+                msg=' > Successfully exported checkpoint in {} with a compression ratio of {:.1f} ({}).'
+                vprint(msg.format(time2str(elapsed), compression_ratio, bytes2str(effective_nbytes)))
+            except Exception as e:
+                success = False
+                reason = str(e)
+        success = comm.allreduce(int(success)) == comm.size
+
+        if success:
+            if (backup_checkpoint_tar is not None) and os.path.isfile(backup_checkpoint_tar) and is_io_leader:
+                os.remove(backup_checkpoint_tar)
+            return
+
+        from hysop.tools.warning import HysopDumpWarning
+        msg='Failed to export checkpoint because: {}.'.format(reason)
+        warnings.warn(msg, HysopDumpWarning)
+
+        # Something went wrong (I/O error or other) so we roll back to the previous checkpoint (if there is one)
+        vprint(' | An error occurred during checkpoint creation, rolling back to previous checkpoint...')
+        if is_io_leader:
+            if os.path.exists(save_checkpoint_dir):
+                shutil.rmtree(save_checkpoint_dir)
+            if os.path.exists(save_checkpoint_tar):
+                os.remove(save_checkpoint_tar)
+            if (backup_checkpoint_tar is not None) and os.path.exists(backup_checkpoint_tar):
+                os.rename(backup_checkpoint_tar, save_checkpoint_tar)
+
+
+    def create_checkpoint_template(self, problem, simulation):
+        # Create groups of arrays on disk (only hierarchy and array metadata is stored in the template)
+        # /!\ ZipStores are not safe from multiple processes so we use a DirectoryStore
+        # that can then be tarred manually by io_leader.
+
+        save_checkpoint_path = self.save_checkpoint_path
+        if (save_checkpoint_path is None):
+            return
+
+        if not save_checkpoint_path.endswith('.tar'):
+            msg='Can only export checkpoint with tar extension, got {}.'
+            raise NotImplementedError(msg.format(save_checkpoint_path))
+
+        (io_params, mpi_params, comm, io_leader, is_io_leader) = self.get_mpio_parameters(problem.mpi_params)
+
+        # determine an empty directory for the template
+        if is_io_leader:
+            checkpoint_template = os.path.join(os.path.dirname(save_checkpoint_path),
+                                               os.path.basename(save_checkpoint_path).replace('.tar', '.template'))
+            while os.path.exists(checkpoint_template):
+                # ok, use another directory name to avoid data loss...
+                checkpoint_template = os.path.join(os.path.dirname(save_checkpoint_path),
+                                                   '{}'.format(uuid.uuid4().hex))
+        else:
+            checkpoint_template = None
+        checkpoint_template = comm.bcast(checkpoint_template, root=io_leader)
+        self._checkpoint_template = checkpoint_template
+
+        vprint('\n>Creating checkpoint template as \'{}\'...'.format(checkpoint_template))
+        import zarr
+        from numcodecs import blosc, Blosc
+
+        # array data compressor
+        blosc.use_threads = (mpi_params.size == 1)  # disable threads for multiple processes (can deadlock)
+        compressor = Blosc(cname='snappy', clevel=3, shuffle=Blosc.BITSHUFFLE)
+        self._checkpoint_compressor = compressor
+
+        # io_leader creates a directory layout on a (hopefully) shared filesystem
+        if is_io_leader:
+            if os.path.exists(checkpoint_template):
+                shutil.rmtree(checkpoint_template)
+            store = zarr.DirectoryStore(path=checkpoint_template)
+            root = zarr.open_group(store=store, mode='w', path='data')
+            params_group = root.create_group('params')
+            fields_group = root.create_group('fields')
+            simu_group = root.create_group('simulation')
+        else:
+            store = None
+            root = None
+            params_group = None
+            fields_group = None
+            simu_group = None
+
+        # count the total number of data bytes without compression
+        nbytes = 0
+        fmt_key = self._format_zarr_key
+
+        # Generate parameter arrays
+        # Here we expect that each process stores parameters that are in sync
+        # For each parameter we assume that the same values are broadcast to all processes
+        # even if this is not enforced by the library (should cover most current use cases...)
+        for param in sorted(problem.parameters, key=operator.attrgetter('name')):
+            if not is_io_leader:
+                continue
+            if isinstance(param, (ScalarParameter, TensorParameter, BufferParameter)):
+                # all those parameters store their data in a numpy ndarray so we're good
+                assert isinstance(param._value, np.ndarray), type(param._value)
+                value = param._value
+                array = params_group.create_dataset(name=fmt_key(param.name),
+                            overwrite=False, data=None, synchronizer=None,
+                            compressor=compressor, shape=value.shape, chunks=None,
+                            dtype=value.dtype, fill_value=default_invalid_value(value.dtype))
+                array.attrs['kind'] = param.__class__.__name__
+                nbytes += value.nbytes
+            else:
+                msg = 'Cannot export parameter of type {}.'.format(param.__class__.__name__)
+                raise NotImplementedError(msg)
+
+        # Generate discrete field arrays
+        # Here we assume that each process has a non-empty chunk of data
+        for field in sorted(problem.fields, key=operator.attrgetter('name')):
+
+            # we do not care about fields discretized only on temporary fields
+            if all(df.is_tmp for df in field.discrete_fields.values()):
+                continue
+
+            if is_io_leader:
+                field_group = fields_group.create_group(fmt_key(field.name))
+            else:
+                field_group = None
+
+            dim = field.dim
+            domain = field.domain._domain
+
+            if isinstance(domain, Box):
+                if (field_group is not None):
+                    field_group.attrs['domain'] = 'Box'
+                    field_group.attrs['dim'] = domain.dim
+                    field_group.attrs['origin'] = to_tuple(domain.origin)
+                    field_group.attrs['end'] = to_tuple(domain.end)
+                    field_group.attrs['length'] = to_tuple(domain.length)
+            else:
+                # for now we just handle Boxed domains
+                raise NotImplementedError
+
+            for (k, topo) in enumerate(sorted(field.discrete_fields, key=operator.attrgetter('full_tag'))):
+                dfield = field.discrete_fields[topo]
+                mesh = topo.mesh._mesh
+
+                # we do not care about temporary fields
+                if dfield.is_tmp:
+                    continue
+
+                if not isinstance(dfield, CartesianDiscreteScalarField):
+                    # for now we just handle CartesianDiscreteScalarFields
+                    raise NotImplementedError
+
+                global_resolution = topo.global_resolution  # logical grid size
+                grid_resolution = topo.grid_resolution      # effective grid size
+                ghosts = topo.ghosts
+
+                # get local resolutions excluding ghosts
+                compute_resolutions = comm.gather(to_tuple(mesh.compute_resolution), root=io_leader)
+
+                # is the current process handling a right boundary data block on a distributed axis?
+                is_at_right_boundary = (mesh.is_at_right_boundary*(mesh.proc_shape>1)).any()
+                is_at_right_boundary = np.asarray(comm.gather(is_at_right_boundary, root=io_leader))
+
+                if not is_io_leader:
+                    continue
+
+                # io_leader can now determine whether the cartesian discretization is uniformly distributed
+                # between processes or not
+                inner_compute_resolutions = tuple(compute_resolutions[i] for i in range(len(compute_resolutions))
+                                                                         if not is_at_right_boundary[i])
+                grid_is_uniformly_distributed = all(res == inner_compute_resolutions[0]
+                                                        for res in inner_compute_resolutions)
+                grid_is_uniformly_distributed |= (topo.mpi_params.size == 1)
+
+                if grid_is_uniformly_distributed:
+                    # We divide the array in 'compute_resolution' chunks, no synchronization is required.
+                    # Here there is no need to use the process locker to write this array data.
+                    # Each process writes its own independent block of data of size 'compute_resolution'.
+                    should_sync = False
+                    chunks = inner_compute_resolutions[0]
+                else:
+                    # We divide the array in >=1MB chunks (chunks are given in terms of elements)
+                    # Array chunks may overlap different processes so we need interprocess synchronization (slow)
+                    should_sync = True
+                    if dim == 1:
+                        chunks = 1024*1024    # at least 1MB / chunk
+                    elif dim == 2:
+                        chunks = (1024,1024)  # at least 1MB / chunk
+                    elif dim == 3:
+                        chunks = (64,128,128) # at least 1MB / chunk
+                    else:
+                        raise NotImplementedError(dim)
+
+                # Create array (no memory is allocated here, even on disk, because data blocks are empty)
+                dtype = dfield.dtype
+                shape = grid_resolution
+
+                # We scale the keys up to 100 topologies, which seems to be a pretty decent upper limit
+                # on a per field basis.
+                array = field_group.create_dataset(name='topo_{:02d}'.format(k),
+                            overwrite=False, data=None, synchronizer=None,
+                            compressor=compressor, shape=shape, chunks=chunks,
+                            dtype=dtype, fill_value=default_invalid_value(dtype))
+                array.attrs['should_sync'] = should_sync
+
+                # We cannot rely on the discrete mesh name because of topology names,
+                # so we save some field metadata to be able to differentiate between
+                # discrete fields with the exact same grid resolution.
+                # proc_shape and name are used as a last resort to differentiate discrete fields.
+                array.attrs['lboundaries'] = to_tuple(map(str, mesh.global_lboundaries))
+                array.attrs['rboundaries'] = to_tuple(map(str, mesh.global_rboundaries))
+                array.attrs['ghosts'] = to_tuple(mesh.ghosts)
+                array.attrs['proc_shape'] = to_tuple(mesh.proc_shape)
+                array.attrs['name'] = dfield.name
+
+                nbytes += np.prod(shape, dtype=np.int64) * dtype.itemsize
+
+        if (root is not None):
+            root.attrs['nbytes'] = nbytes
+            msg=' => Maximum checkpoint size will be {}, without compression and metadata.'
+            vprint(root.tree())
+            vprint(msg.format(bytes2str(nbytes)))
+
+        # some zarr store formats require a final close to flush data
+        try:
+            if (root is not None):
+                root.close()
+        except AttributeError:
+            pass
+
+
+    def _export_checkpoint(self, problem, simulation, save_checkpoint_dir):
+        # Given a template, fill field and parameter data from all processes.
+        # Returns (success, reason, nbytes), where success is True on success.
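+        # For orientation, the on-disk layout being filled here looks like this
+        # (sketch only; the groups and 'topo_XX' datasets are the ones created
+        #  by create_checkpoint_template above):
+        #   data/
+        #     simulation/                 simulation state, written by io_leader
+        #     params/<param_name>         one array per parameter, written by io_leader
+        #     fields/<field_name>/topo_XX one array per discrete field and topology,
+        #                                 written collectively by all processes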
+        (io_params, mpi_params, comm, io_leader, is_io_leader) = self.get_mpio_parameters(problem.mpi_params)
+
+        # checkpoint template may have been deleted by user during simulation
+        if (self._checkpoint_template is None) or (not os.path.isdir(self._checkpoint_template)):
+            self.create_checkpoint_template(problem, simulation)
+        checkpoint_template = self._checkpoint_template
+        checkpoint_compressor = self._checkpoint_compressor
+
+        if is_io_leader:
+            if os.path.exists(save_checkpoint_dir):
+                shutil.rmtree(save_checkpoint_dir)
+            shutil.copytree(checkpoint_template, save_checkpoint_dir)
+        comm.Barrier()
+
+        if not os.path.isdir(save_checkpoint_dir):
+            msg='Could not find checkpoint directory \'{}\'. Are you using a network file system?'.format(save_checkpoint_dir)
+            raise RuntimeError(msg)
+
+        # Every process now loads the same dataset template
+        import zarr
+        try:
+            store = zarr.DirectoryStore(save_checkpoint_dir)
+            root = zarr.open_group(store=store, mode='r+', synchronizer=None, path='data')
+            fields_group = root['fields']
+            params_group = root['params']
+            simu_group = root['simulation']
+            nbytes = root.attrs['nbytes']
+        except:
+            msg='A fatal error occurred during checkpoint export, the checkpoint template may be ill-formed.'
+            vprint(msg)
+            vprint()
+            raise
+
+        fmt_key = self._format_zarr_key
+
+        # Export simulation data
+        if is_io_leader:
+            simulation.save_checkpoint(simu_group, mpi_params, io_params, checkpoint_compressor)
+
+        # Currently there are no distributed parameter capabilities so io_leader has to dump all parameters
+        if is_io_leader:
+            msg = ' | dumping parameters...'
+            vprint(msg)
+            for param in sorted(problem.parameters, key=operator.attrgetter('name')):
+                if isinstance(param, (ScalarParameter, TensorParameter, BufferParameter)):
+                    array = params_group[fmt_key(param.name)]
+                    assert array.attrs['kind'] == param.__class__.__name__
+                    assert array.dtype == param._value.dtype
+                    assert array.shape == param._value.shape
+                    array[...] = param._value
+                else:
+                    msg = 'Cannot dump parameter of type {}.'.format(param.__class__.__name__)
+                    raise NotImplementedError(msg)
+
+        # Unlike parameters, all processes participate for fields
+        for field in sorted(problem.fields, key=operator.attrgetter('name')):
+
+            # we do not care about fields discretized only on temporary fields
+            if all(df.is_tmp for df in field.discrete_fields.values()):
+                continue
+
+            msg = ' | dumping field {}...'.format(field.pretty_name)
+            vprint(msg)
+
+            field_group = fields_group[fmt_key(field.name)]
+            for (k, topo) in enumerate(sorted(field.discrete_fields, key=operator.attrgetter('full_tag'))):
+                dfield = field.discrete_fields[topo]
+                mesh = topo.mesh._mesh
+
+                # we do not care about temporary fields
+                if dfield.is_tmp:
+                    continue
+
+                dataset = 'topo_{:02d}'.format(k)  # key has to match template
+                array = field_group[dataset]
+                should_sync = array.attrs['should_sync']
+
+                assert dfield.nb_components == 1
+                assert (array.shape == mesh.grid_resolution).all(), (array.shape, mesh.grid_resolution)
+                assert array.dtype == dfield.dtype, (array.dtype, dfield.dtype)
+
+                if should_sync:
+                    # Should not be required until we allow non-uniform discretizations
+                    global_start = mesh.global_start
+                    global_stop = mesh.global_stop
+                    raise NotImplementedError('Synchronized multiprocess write has not been implemented yet.')
+                else:
+                    assert ((mesh.compute_resolution == array.chunks).all()
+                            or (mesh.is_at_right_boundary*(mesh.proc_shape>1)).any())
+                    local_data = dfield.compute_data[0].get()
+                    global_slices = mesh.global_compute_slices
+                    array[global_slices] = local_data  # ok, every process writes to an independent data block
+
+        # Some zarr store formats require a final close to flush data
+        try:
+            root.close()
+        except AttributeError:
+            pass
+
+        return True, None, nbytes
+
+
+    # On data import, there is no need to synchronize read-only arrays
+    # so we are good with multiple processes reading overlapping data blocks
+    def _import_checkpoint(self, problem, simulation, load_checkpoint_dir):
+
+        (io_params, mpi_params, comm, io_leader, is_io_leader) = self.get_mpio_parameters(problem.mpi_params)
+        mpi_params.comm.Barrier()
+
+        if not os.path.isdir(load_checkpoint_dir):
+            msg='Could not find checkpoint directory \'{}\'. Are you using a network file system?'.format(load_checkpoint_dir)
+            raise RuntimeError(msg)
+
+        import zarr
+        store = zarr.DirectoryStore(load_checkpoint_dir)
+        try:
+            root = zarr.open_group(store=store, mode='r', synchronizer=None, path='data')
+            params_group = root['params']
+            fields_group = root['fields']
+            simu_group = root['simulation']
+        except:
+            msg='A fatal error occurred during checkpoint import, the checkpoint data may be ill-formed.'
+            vprint(msg)
+            vprint()
+            raise
+
+        # Define helper functions
+        relax_constraints = self.relax_constraints
+        raise_error = self._raise_error
+        if relax_constraints:
+            raise_warning = self._raise_warning
+        else:
+            raise_warning = self._raise_error
+        load_array_data = functools.partial(self._load_array_data, on_mismatch=raise_warning)
+        fmt_key = self._format_zarr_key
+
+        # Import simulation data (parameters are imported right after)
+        msg = ' | importing simulation...'
+        vprint(msg)
+        simulation.load_checkpoint(simu_group, mpi_params, io_params, relax_constraints)
+
+        # Import parameters, hopefully parameter names match the ones in the checkpoint
+        msg = ' | importing parameters...'
+ vprint(msg) + for param in sorted(problem.parameters, key=operator.attrgetter('name')): + key = fmt_key(param.name) + + if (key not in params_group): + msg='Checkpoint directory \'{}\' does not contain any data regarding to parameter {}' + msg=msg.format(load_checkpoint_dir, param.name) + raise_error(msg) + + array = params_group[key] + + if array.attrs['kind'] != param.__class__.__name__: + msg='Parameter kind do not match with checkpointed parameter {}, loaded kind {} but expected {}.' + msg=msg.format(param.name, array.attrs['kind'], param.__class__.__name__) + raise_error(msg) + + if isinstance(param, (ScalarParameter, TensorParameter, BufferParameter)): + value = param._value + + if (array.shape != value.shape): + msg='Parameter shape does not match with checkpointed parameter {}, loaded shape {} but expected {}.' + msg=msg.format(param.name, array.shape, value.shape) + raise_error(msg) + + if (array.dtype != value.dtype): + msg='Parameter datatype does not match with checkpointed parameter {}, loaded dtype {} but expected {}.' + msg=msg.format(param.name, array.dtype, value.dtype) + raise_warning(msg) + + value[...] = array[...] + else: + msg = 'Cannot import parameter of type {}.'.format(param.__class__.__name__) + raise NotImplementedError(msg) + + # Import discrete fields, this is a bit more tricky because topologies or simply topology + # names can change. Moreover there is currently no waranty that the same operator graph is + # generated for the exact same problem configuration each time. We just emit user warnings + # if we find a way to match topologies that do not match exactly checkpointed ones. + for field in sorted(problem.fields, key=operator.attrgetter('name')): + domain = field.domain._domain + + # we do not care about fields discretized only on temporary fields + if all(df.is_tmp for df in field.discrete_fields.values()): + continue + + msg = ' | importing field {}...'.format(field.pretty_name) + vprint(msg) + + field_key = fmt_key(field.name) + if (field_key not in fields_group): + msg='Checkpoint directory \'{}\' does not contain any data regarding to field {}' + msg=msg.format(load_checkpoint_dir, field.name) + raise_error(msg) + + field_group = fields_group[field_key] + + # check that domain matches + if field_group.attrs['domain'] != domain.__class__.__name__: + msg='Domain kind does not match with checkpointed field {}, loaded kind {} but expected {}.' + msg=msg.format(field.name, field_group.attrs['domain'], domain.__class__.__name__) + raise_error(msg) + if field_group.attrs['dim'] != domain.dim: + msg='Domain dim does not match with checkpointed field {}, loaded dim {} but expected {}.' + msg=msg.format(field.name, field_group.attrs['dim'], domain.dim) + raise_error(msg) + if field_group.attrs['origin'] != to_list(domain.origin): + msg='Domain origin does not match with checkpointed field {}, loaded origin {} but expected {}.' + msg=msg.format(field.name, field_group.attrs['origin'], domain.origin) + raise_error(msg) + if field_group.attrs['end'] != to_list(domain.end): + msg='Domain end does not match with checkpointed field {}, loaded end {} but expected {}.' + msg=msg.format(field.name, field_group.attrs['end'], domain.end) + raise_error(msg) + if field_group.attrs['length'] != to_list(domain.length): + msg='Domain length does not match with checkpointed field {}, loaded length {} but expected {}.' 
+ msg=msg.format(field.name, field_group.attrs['length'], domain.length) + raise_error(msg) + + for (k, topo) in enumerate(sorted(field.discrete_fields, key=operator.attrgetter('full_tag'))): + dfield = field.discrete_fields[topo] + mesh = topo.mesh._mesh + + # we do not care about temporary fields + if dfield.is_tmp: + continue + + # for now we just handle CartesianDiscreteScalarFields. + if not isinstance(dfield, CartesianDiscreteScalarField): + raise NotImplementedError + + # first we need to exactly match global grid resolution + candidates = tuple(filter(lambda d: np.equal(d.shape, mesh.grid_resolution).all(), field_group.values())) + if len(candidates)==0: + msg='Could not find any topology with shape {} for field {}, available discretizations are: {}.' + msg=msg.format(to_tuple(mesh.grid_resolution), field.name, + ', '.join(set(str(d.shape) for d in field_group.values()))) + raise_error(msg) + elif len(candidates)==1: + load_array_data(candidates[0], dfield) + continue + + # Here multiple topologies have the extact same grid resolution so we try to match boundary conditions + old_candidates = candidates + candidates = tuple(filter(lambda d: d.attrs['lboundaries'] == to_tuple(map(str, mesh.global_lboundaries)), candidates)) + candidates = tuple(filter(lambda d: d.attrs['rboundaries'] == to_tuple(map(str, mesh.global_rboundaries)), candidates)) + if len(candidates)==0: + # ok, the user changed the boundary conditions, we ignore boundary condition information + candidates = old_candidates + elif len(candidates)==1: + load_array_data(candidates[0], dfield) + continue + + # From now on multiple topologies have the same grid resolution and boundary conditions + # We try to match exact ghost count, user did likely not change the order of the methods. + old_candidates = candidates + candidates = tuple(filter(lambda d: d.attrs['ghosts'] == to_tuple(mesh.ghosts), candidates)) + if len(candidates)==0: + # ok, the user made a change that affected ghosts, we ignore the ghost condition + candidates = old_candidates + elif len(candidates)==1: + load_array_data(candidates[0], dfield) + continue + + # Now we try to differentiate by using zero ghost info (ghosts may change with method order, but zero-ghost is very specific) + # Topology containing zero ghost layer usually target Fortran topologies for FFT operators or method that do not require any ghosts. + old_candidates = candidates + candidates = tuple(filter(lambda d: (np.equal(d.attrs['ghosts'],0) == (mesh.ghosts==0)).all(), candidates)) + if len(candidates)==0: + # ok, we ignore the zero-ghost condition + candidates = old_candidates + elif len(candidates)==1: + load_array_data(candidates[0], dfield) + continue + + # Now we try to match exact topology shape (the MPICart grid of processes) + # We try this late because use may run the simulation again with a different number of processes. + old_candidates = candidates + candidates = tuple(filter(lambda d: d.attrs['proc_shape'] == to_tuple(mesh.proc_shape), candidates)) + if len(candidates)==0: + # ok, we ignore the proc shape + candidates = old_candidates + elif len(candidates)==1: + load_array_data(candidates[0], dfield) + continue + + # Now we try to differentiate by using topo splitting info (axes on which data is distributed) + # This again is very specific and can differentiate topologies used for spectral transforms. 
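+                # (for example, on 4 processes a topology split as proc_shape=(4,1,1)
+                #  and one split as (1,1,4) can share grid resolution, boundaries and
+                #  ghosts, yet distribute data along different axes; hypothetical
+                #  values, actual proc_shape depends on the run)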
+ old_candidates = candidates + candidates = tuple(filter(lambda d: (np.greater(d.attrs['proc_shape'],1) == (mesh.proc_shape>1)).all(), candidates)) + if len(candidates)==0: + # ok, we ignore the MPI data splitting condition + candidates = old_candidates + elif len(candidates)==1: + load_array_data(candidates[0], dfield) + continue + + # Ok now, our last hope is to match the discrete field name + old_candidates = candidates + candidates = tuple(filter(lambda d: d.attrs['name'] == dfield.name, candidates)) + if len(candidates)==0: + # ok, we ignore the name + candidates = old_candidates + elif len(candidates)==1: + load_array_data(candidates[0], dfield) + continue + + assert len(candidates) > 1, 'Something went wrong.' + + msg='Could not discriminate checkpointed topologies for field {}, got {} candidates remaining.' + msg=msg.format(field.name, len(candidates)) + raise_error(msg) + + + @staticmethod + def _load_array_data(array, dfield, on_mismatch): + mesh = dfield.mesh._mesh + assert np.equal(array.shape, mesh.grid_resolution).all() + + # compare attributes but ignore name because this can be annoying + attr_names = ('left boundaries', 'right boundaries', 'ghost layers', 'process shape', 'datatype') + array_attributes = (array.attrs['lboundaries'], array.attrs['rboundaries'], array.attrs['ghosts'], + array.attrs['proc_shape'], array.dtype) + dfield_attributes = (list(map(str, mesh.global_lboundaries)), list(map(str, mesh.global_rboundaries)), + list(mesh.ghosts), list(mesh.proc_shape)) + + for (name,lhs,rhs) in zip(attr_names, array_attributes, dfield_attributes): + if lhs==rhs: + continue + msg='{} do not match with checkpointed field {}, loaded {} {} but expected {}.' + msg=msg.format(name, dfield.field.name, name, lhs, rhs) + on_mismatch(msg) + + global_slices = mesh.global_compute_slices + data = np.asarray(array[global_slices], dtype=dfield.dtype) + dfield.compute_data[0][...] = data + dfield.exchange_ghosts() + + @staticmethod + def _raise_error(msg): + vprint(' | error: {}\n'.format(msg)) + vprint() + err = 'FATAL ERROR: Failed to import checkpoint, because the following error occured: {}.' + raise RuntimeError(err.format(msg)) + + @staticmethod + def _raise_warning(msg): + msg = ' | warning: {}'.format(msg) + vprint(msg) + + @staticmethod + def _format_zarr_key(k): + # note keys that contains the special characters '/' and '\' do not work well with zarr + # so we need to replace it by another character such as '_'. + # We cannot use utf8 characters such as u+2215 (division slash). + if (k is None): + return None + return k.replace('/', '_').replace('\\', '_') + diff --git a/hysop/core/tests/test_checkpoint.sh b/hysop/core/tests/test_checkpoint.sh index 059d2c1f4..b53952996 100755 --- a/hysop/core/tests/test_checkpoint.sh +++ b/hysop/core/tests/test_checkpoint.sh @@ -48,7 +48,7 @@ echo ' Running simulations...' echo ' Comparing solutions...' echo " >debug dumps match" compare_files "${TEST_DIR}/run0/dump/run.txt" "${TEST_DIR}/run1/dump/run.txt" -for f0 in $(find "${TEST_DIR}/run0" -name '*.h5'); do +for f0 in $(find "${TEST_DIR}/run0" -name '*.h5' | sort -n); do f1=$(echo "${f0}" | sed 's/run0/run1/') compare_files "${f0}" "${f1}" echo " >$(basename ${f0}) match" @@ -95,7 +95,7 @@ mpirun -np 4 "${PYTHON_EXECUTABLE}" "${EXAMPLE_FILE}" ${COMMON_OPTIONS} -S "${TE echo ' Comparing solutions...' 
echo " >debug dumps match" compare_files "${TEST_DIR}/run0/dump/run.txt" "${TEST_DIR}/run1/dump/run.txt" -for f0 in $(find "${TEST_DIR}/run0" -name '*.h5'); do +for f0 in $(find "${TEST_DIR}/run0" -name '*.h5' | sort -n); do f1=$(echo "${f0}" | sed 's/run0/run1/') compare_files "${f0}" "${f1}" echo " >$(basename ${f0}) match" @@ -152,6 +152,7 @@ for f0 in $(find "${TEST_DIR}/run0" -name '*.h5' | sort -n); do done echo ' Running simulations from checkpoints using different MPI topologies...' +COMMON_OPTIONS="-NC -d24 --tend 0.3 --dump-tstart 0.15 --dump-freq 1 --hdf5-disable-slicing --hdf5-disable-compression --checkpoint-relax-constraints" mpirun -np 3 "${PYTHON_EXECUTABLE}" "${EXAMPLE_FILE}" ${COMMON_OPTIONS} -impl fortran -cp fp64 -L "${TEST_DIR}/checkpoint0.tar" --dump-dir "${TEST_DIR}/run3" mpirun -np 2 "${PYTHON_EXECUTABLE}" "${EXAMPLE_FILE}" ${COMMON_OPTIONS} -impl fortran -cp fp64 -L "${TEST_DIR}/checkpoint1.tar" --dump-dir "${TEST_DIR}/run4" mpirun -np 1 "${PYTHON_EXECUTABLE}" "${EXAMPLE_FILE}" ${COMMON_OPTIONS} -impl fortran -cp fp64 -L "${TEST_DIR}/checkpoint2.tar" --dump-dir "${TEST_DIR}/run5" @@ -176,3 +177,4 @@ for f0 in $(find "${TEST_DIR}/run0" -name '*.h5' | sort -n); do h5diff -d '5e-5' "${f6}" "${f7}" echo " >$(basename ${f0}) match" done + diff --git a/hysop/problem.py b/hysop/problem.py index e11679584..65faf8aa4 100644 --- a/hysop/problem.py +++ b/hysop/problem.py @@ -1,21 +1,15 @@ from __future__ import absolute_import -import operator, os, sys, datetime, warnings, shutil, tarfile, uuid -import numpy as np +import sys, datetime from hysop.constants import Backend, MemoryOrdering -from hysop.tools.types import first_not_None, to_tuple, to_list +from hysop.tools.types import check_instance, first_not_None, to_tuple, to_list from hysop.tools.string_utils import vprint_banner from hysop.tools.contexts import Timer from hysop.tools.decorators import debug from hysop.tools.parameters import MPIParams -from hysop.tools.numerics import default_invalid_value -from hysop.tools.units import bytes2str, time2str +from hysop.core.checkpoints import CheckpointHandler from hysop.core.graph.computational_graph import ComputationalGraph from hysop.tools.string_utils import vprint_banner, vprint -from hysop.core.mpi import main_rank, main_size, main_comm, Wtime -from hysop.fields.cartesian_discrete_field import CartesianDiscreteScalarField -from hysop.parameters import ScalarParameter, TensorParameter, BufferParameter -from hysop.domain.box import Box class Problem(ComputationalGraph): @@ -25,8 +19,6 @@ class Problem(ComputationalGraph): mpi_params = first_not_None(mpi_params, MPIParams()) # enforce mpi params for problems super(Problem, self).__init__(name=name, method=method, mpi_params=mpi_params, **kwds) self._do_check_unique_clenv = check_unique_clenv - self._checkpoint_template = None - self._checkpoint_compressor = None @debug def insert(self, *ops): @@ -43,10 +35,11 @@ class Problem(ComputationalGraph): msg = ' Problem {} achieved, exiting ! '.format(msg) vprint_banner(msg, at_border=2) sys.exit(0) - avg_time = main_comm.allreduce(tm.interval) / main_size + size = self.mpi_params.size + avg_time = self.mpi_params.comm.allreduce(tm.interval) / size msg = ' Problem building took {} ({}s)' - if main_size > 1: - msg += ', averaged over {} ranks. '.format(main_size) + if size > 1: + msg += ', averaged over {} ranks. 
'.format(size) msg = msg.format(datetime.timedelta(seconds=round(avg_time)), avg_time) vprint_banner(msg, spacing=True, at_border=2) @@ -135,8 +128,8 @@ class Problem(ComputationalGraph): @debug def solve(self, simu, dry_run=False, dbg=None, - report_freq=10, plot_freq=10, checkpoint_io_params=None, - load_checkpoint=None, save_checkpoint=None, **kwds): + report_freq=10, plot_freq=10, + checkpoint_handler=None, **kwds): if dry_run: vprint() @@ -144,9 +137,10 @@ class Problem(ComputationalGraph): return simu.initialize() - - self.create_checkpoint_template(save_checkpoint, checkpoint_io_params) - self.load_checkpoint(load_checkpoint, checkpoint_io_params, simu) + + check_instance(checkpoint_handler, CheckpointHandler, allow_none=True) + checkpoint_handler.create_checkpoint_template(self, simu) + checkpoint_handler.load_checkpoint(self, simu) vprint('\nSolving problem...') with Timer() as tm: @@ -155,14 +149,15 @@ class Problem(ComputationalGraph): simu.print_state() self.apply(simulation=simu, dbg=dbg, **kwds) simu.advance(dbg=dbg, plot_freq=plot_freq) - self.save_checkpoint(save_checkpoint, checkpoint_io_params, simu) + checkpoint_handler.save_checkpoint(self, simu) if report_freq and (simu.current_iteration % report_freq) == 0: self.profiler_report() - - avg_time = main_comm.allreduce(tm.interval) / main_size + + size = self.mpi_params.size + avg_time = self.mpi_params.comm.allreduce(tm.interval) / size msg = ' Simulation took {} ({}s)' - if main_size > 1: - msg += ', averaged over {} ranks. '.format(main_size) + if size > 1: + msg += ', averaged over {} ranks. '.format(size) msg += '\n for {} iterations ({}s per iteration) ' msg = msg.format(datetime.timedelta(seconds=round(avg_time)), avg_time, max(simu.current_iteration+1, 1), @@ -170,6 +165,7 @@ class Problem(ComputationalGraph): vprint_banner(msg, spacing=True, at_border=2) simu.finalize() + checkpoint_handler.finalize(self.mpi_params) self.final_report() if (dbg is not None): @@ -182,688 +178,4 @@ class Problem(ComputationalGraph): def finalize(self): vprint('Finalizing problem...') super(Problem, self).finalize() - if ((self._checkpoint_template is not None) - and os.path.exists(self._checkpoint_template) - and self.mpi_params.rank==0): - try: - shutil.rmtree(self._checkpoint_template) - except OSError: - pass - - - @debug - def load_checkpoint(self, load_checkpoint, checkpoint_io_params, simu): - if (load_checkpoint is None): - return - - vprint('\n>Loading problem checkpoint from \'{}\'...'.format(load_checkpoint)) - if not os.path.exists(load_checkpoint): - msg='Failed to load checkpoint \'{}\' because the file does not exist.' - raise RuntimeError(msg.format(load_checkpoint)) - - mpi_params = self.mpi_params - comm = mpi_params.comm - io_leader = checkpoint_io_params.io_leader - is_io_leader = (io_leader == self.mpi_params.rank) - start = Wtime() - - if os.path.isfile(load_checkpoint): - if load_checkpoint.endswith('.tar'): - if is_io_leader: - load_checkpoint_dir = os.path.join(os.path.dirname(load_checkpoint), - os.path.basename(load_checkpoint).replace('.tar', '')) - while os.path.exists(load_checkpoint_dir): - # ok, use another directory name to avoid dataloss... 
- load_checkpoint_dir = os.path.join(os.path.dirname(load_checkpoint), - '{}'.format(uuid.uuid4().hex)) - tf = tarfile.open(load_checkpoint, mode='r') - tf.extractall(path=load_checkpoint_dir) - else: - load_checkpoint_dir = None - load_checkpoint_dir = comm.bcast(load_checkpoint_dir, root=io_leader) - should_remove_dir = True - else: - msg='Can only load checkpoint with tar extension.' - raise NotImplementedError(msg) - else: - load_checkpoint_dir = load_checkpoint - should_remove_dir = False - - # here we want hysop to crash on unsuccessfull import - self._import_checkpoint(load_checkpoint_dir,checkpoint_io_params, simu) - - if is_io_leader and should_remove_dir: - shutil.rmtree(load_checkpoint_dir) - - ellapsed = Wtime() - start - msg=' > Successfully imported checkpoint in {}.' - vprint(msg.format(time2str(ellapsed))) - - - @debug - def save_checkpoint(self, save_checkpoint, checkpoint_io_params, simu): - if (save_checkpoint is None): - return - - if (checkpoint_io_params is None): - msg='Save checkpoint has been set to \'{}\' but checkpoint_io_params has not been specified.' - raise RuntimeError(msg.format(save_checkpoint)) - - if not checkpoint_io_params.should_dump(simu): - return - - io_leader = checkpoint_io_params.io_leader - is_io_leader = (io_leader == self.mpi_params.rank) - start = Wtime() - - # Checkpoint is first exported as a directory containing a hierarchy of arrays (field and parameters data + metadata) - # This folder is than tarred (without any form of compression) so that a checkpoint consists in a single movable file. - # Data is already compressed during data export by the zarr module, using the blosc compressor (snappy, clevel=3). - assert save_checkpoint.endswith('.tar') - save_checkpoint_tar = save_checkpoint - if is_io_leader: - save_checkpoint_dir = os.path.join(os.path.dirname(save_checkpoint), - os.path.basename(save_checkpoint).replace('.tar', '')) - while os.path.exists(save_checkpoint_dir): - # ok, use another directory name to avoid dataloss... 
- save_checkpoint_dir = os.path.join(os.path.dirname(save_checkpoint), - '{}'.format(uuid.uuid4().hex)) - else: - save_checkpoint_dir = None - save_checkpoint_dir = self.mpi_params.comm.bcast(save_checkpoint_dir, root=io_leader) - del save_checkpoint - - vprint('>Exporting problem checkpoint to \'{}\':'.format(save_checkpoint_tar)) - - # create a backup of last checkpoint in the case things go wrong - if is_io_leader and os.path.exists(save_checkpoint_tar): - backup_checkpoint_tar = save_checkpoint_tar + '.bak' - if os.path.exists(backup_checkpoint_tar): - os.remove(backup_checkpoint_tar) - os.rename(save_checkpoint_tar, backup_checkpoint_tar) - else: - backup_checkpoint_tar = None - - # try to create the checkpoint directory, this is a collective MPI operation - try: - success, reason, nbytes = self._export_checkpoint(save_checkpoint_dir, checkpoint_io_params, simu) - except Exception as e: - raise - success = False - reason = str(e) - success = main_comm.allreduce(int(success)) == main_comm.size - - # Compress checkpoint directory to tar (easier to copy/move between clusters) - # Note that there is no effective compression here, zarr already compressed field/param data - if success and is_io_leader and os.path.isdir(save_checkpoint_dir): - try: - with tarfile.open(save_checkpoint_tar, 'w') as tf: - for (root, dirs, files) in os.walk(save_checkpoint_dir): - for f in files: - fpath = os.path.join(root, f) - tf.add(fpath, arcname=fpath.replace(save_checkpoint_dir+os.path.sep,'')) - if os.path.isfile(save_checkpoint_tar): - shutil.rmtree(save_checkpoint_dir) - else: - raise RuntimeError('Could not tar checkpoint datadir.') - ellapsed = Wtime() - start - effective_nbytes = os.path.getsize(save_checkpoint_tar) - compression_ratio = max(1.0, float(nbytes)/effective_nbytes) - msg=' > Successfully exported checkpoint in {} with a compression ratio of {:.1f} ({}).' - vprint(msg.format(time2str(ellapsed), compression_ratio, bytes2str(effective_nbytes))) - except Exception as e: - success = False - reason = str(e) - success = main_comm.allreduce(int(success)) == main_comm.size - - if success: - if (backup_checkpoint_tar is not None) and os.path.isfile(backup_checkpoint_tar) and is_io_leader: - os.remove(backup_checkpoint_tar) - return - - from hysop.tools.warning import HysopDumpWarning - msg='Failed to export checkpoint because: {}.'.format(reason) - warnings.warn(msg, HysopDumpWarning) - - # Something went wrong (I/O error or other) so we rollback to previous checkpoint (if there is one) - vprint(' | An error occured during checkpoint creation, rolling back to previous checkpoint...') - if is_io_leader: - if os.path.exists(save_checkpoint_dir): - shutil.rmtree(save_checkpoint_dir) - if os.path.exists(save_checkpoint_tar): - os.remove(save_checkpoint_tar) - if (backup_checkpoint_tar is not None) and os.path.exists(backup_checkpoint_tar): - os.rename(backup_checkpoint_tar, save_checkpoint_tar) - - def create_checkpoint_template(self, save_checkpoint, checkpoint_io_params): - # Create groups of arrays on disk (only hierarchy and array metadata is stored in the template) - # /!\ ZipStores are not safe from multiple processes so we use a DirectoryStore - # that can then be tarred manually by io_leader. 
- - if (save_checkpoint is None): - return - - mpi_params = self.mpi_params - comm = mpi_params.comm - io_leader = checkpoint_io_params.io_leader - is_io_leader = (io_leader == mpi_params.rank) - - assert save_checkpoint.endswith('.tar'), save_checkpoint - if is_io_leader: - checkpoint_template = os.path.join(os.path.dirname(save_checkpoint), - os.path.basename(save_checkpoint).replace('.tar', '.template')) - while os.path.exists(checkpoint_template): - # ok, use another directory name to avoid dataloss... - checkpoint_template = os.path.join(os.path.dirname(save_checkpoint), - '{}'.format(uuid.uuid4().hex)) - else: - checkpoint_template = None - checkpoint_template = comm.bcast(checkpoint_template, root=io_leader) - self._checkpoint_template = checkpoint_template - del save_checkpoint - - vprint('\n>Creating checkpoint template as \'{}\'...'.format(checkpoint_template)) - # Array block data compressor - from numcodecs import blosc, Blosc - blosc.use_threads = (self.mpi_params.size == 1) # disable threads for multiple processes (can deadlock) - compressor = Blosc(cname='snappy', clevel=3, shuffle=Blosc.BITSHUFFLE) - self._checkpoint_compressor = compressor - - # Create a directory layout as a file on shared filesystem - import zarr - - nbytes = 0 # count number of total data bytes without compression - - if is_io_leader: - if os.path.exists(checkpoint_template): - shutil.rmtree(checkpoint_template) - store = zarr.DirectoryStore(path=checkpoint_template) - root = zarr.open_group(store=store, mode='w', path='data') - params = root.create_group('params') - fields = root.create_group('fields') - simulation = root.create_group('simulation') - else: - store = None - root = None - params = None - fields = None - simulation = None - - # Generate parameter arrays - # Here we expect that each process store parameters that are in sync - # For each parameter we assume that the same values are broadcast to all processes - # even if is not enforced by the library (should cover most current use cases...) 
- for param in sorted(self.parameters, key=operator.attrgetter('name')): - if not is_io_leader: - continue - if isinstance(param, (ScalarParameter, TensorParameter, BufferParameter)): - # all those parameters store their data in a numpy ndarray so we're good - assert isinstance(param._value, np.ndarray), type(param._value) - value = param._value - array = params.create_dataset(name=self.format_zarr_key(param.name), - overwrite=False, data=None, synchronizer=None, - compressor=compressor, shape=value.shape, chunks=None, - dtype=value.dtype, fill_value=default_invalid_value(value.dtype)) - array.attrs['kind'] = param.__class__.__name__ - nbytes += value.nbytes - else: - msg = 'Cannot export parameter of type {}.'.format(param.__class__.__name__) - raise NotImplementedError(msg) - - # Generate discrete field arrays - # Here we assume that each process has a non-empty chunk of data - for field in sorted(self.fields, key=operator.attrgetter('name')): - # we do not care about fields discretized only on temporary fields - if all(df.is_tmp for df in field.discrete_fields.values()): - continue - - if is_io_leader: - discrete_fields = fields.create_group(self.format_zarr_key(field.name)) - else: - discrete_fields = None - - dim = field.dim - domain = field.domain._domain - - if isinstance(domain, Box): - if (discrete_fields is not None): - discrete_fields.attrs['domain'] = 'Box' - discrete_fields.attrs['dim'] = domain.dim - discrete_fields.attrs['origin'] = to_tuple(domain.origin) - discrete_fields.attrs['end'] = to_tuple(domain.end) - discrete_fields.attrs['length'] = to_tuple(domain.length) - else: - # for now we just handle Boxed domains - raise NotImplementedError - - for (k, topo) in enumerate(sorted(field.discrete_fields, key=operator.attrgetter('full_tag'))): - dfield = field.discrete_fields[topo] - mesh = topo.mesh._mesh - - # we do not care about temporary fields - if dfield.is_tmp: - continue - - if not isinstance(dfield, CartesianDiscreteScalarField): - # for now we just handle CartesianDiscreteScalarFields. - raise NotImplementedError - - global_resolution = topo.global_resolution # logical grid size - grid_resolution = topo.grid_resolution # effective grid size - ghosts = topo.ghosts - - # get local resolutions exluding ghosts - compute_resolutions = comm.gather(to_tuple(mesh.compute_resolution), root=io_leader) - - # is the current process handling a right boundary data block on a distributed axe ? - is_at_right_boundary = (mesh.is_at_right_boundary*(mesh.proc_shape>1)).any() - is_at_right_boundary = np.asarray(comm.gather(is_at_right_boundary, root=io_leader)) - - if not is_io_leader: - continue - - # io_leader can now determine wether the cartesian discretization is uniformly distributed - # between processes or not - inner_compute_resolutions = tuple(compute_resolutions[i] for i in range(len(compute_resolutions)) - if not is_at_right_boundary[i]) - grid_is_uniformly_distributed = all(res == inner_compute_resolutions[0] - for res in inner_compute_resolutions) - grid_is_uniformly_distributed |= (topo.mpi_params.size == 1) - - if grid_is_uniformly_distributed: - # We divide the array in 'compute_resolution' chunks, no sychronization is required. - # Here there is no need to use the process locker to write this array data. - # Each process writes its own independent block of data of size 'compute_resolution'. 
- should_sync = False - chunks = inner_compute_resolutions[0] - else: - # We divide the array in >=1MB chunks (chunks are given in terms of elements) - # Array chunks may overlap different processes so we need interprocess sychronization (slow) - should_sync = True - if dim == 1: - chunks = 1024*1024 # at least 1MB / chunk - elif dim == 2: - chunks = (1024,1024) # at least 1MB / chunk - elif dim == 3: - chunks = (64,128,128) # at least 1MB / chunk - else: - raise NotImplementedError(dim) - - # Create array (no memory is allocated here, even on disk because data blocks are empty) - dtype = dfield.dtype - shape = grid_resolution - - # We scale the keys up to 100 topologies, which seams to be a pretty decent upper limit - # on a per field basis. - array = discrete_fields.create_dataset(name='topo_{:02d}'.format(k), - overwrite=False, data=None, synchronizer=None, - compressor=compressor, shape=shape, chunks=chunks, - dtype=dtype, fill_value=default_invalid_value(dtype)) - array.attrs['should_sync'] = should_sync - - # We cannot rely on discrete mesh name because of topology names - # so we save some field metadata to be able to differentiate between - # discrete fields with the exact same grid resolution. - # proc_shape and name are used in last resort to differentiate discrete fields. - array.attrs['lboundaries'] = to_tuple(map(str, mesh.global_lboundaries)) - array.attrs['rboundaries'] = to_tuple(map(str, mesh.global_rboundaries)) - array.attrs['ghosts'] = to_tuple(mesh.ghosts) - array.attrs['proc_shape'] = to_tuple(mesh.proc_shape) - array.attrs['name'] = dfield.name - - nbytes += np.prod(shape, dtype=np.int64) * dtype.itemsize - - if (root is not None): - root.attrs['nbytes'] = nbytes - msg=' => Maximum checkpoint size will be {}, without compression and metadata.' - vprint(root.tree()) - vprint(msg.format(bytes2str(nbytes))) - - # some zarr store formats require a final close to flush data - try: - if (root is not None): - root.close() - except AttributeError: - pass - - - def _export_checkpoint(self, save_checkpoint_dir, checkpoint_io_params, simu): - # Given a template, fill field and parameters data from all processes. - # returns (bool, msg) where bool is True on success - - mpi_params = self.mpi_params - comm = mpi_params.comm - io_leader = checkpoint_io_params.io_leader - is_io_leader = (io_leader == mpi_params.rank) - - if not os.path.exists(self._checkpoint_template): - # checkpoint template may have been deleted by user during simulation - self.create_checkpoint_template(save_checkpoint, checkpoint_io_params) - compressor = self._checkpoint_compressor - - if is_io_leader: - if os.path.exists(save_checkpoint_dir): - shutil.rmtree(save_checkpoint_dir) - shutil.copytree(self._checkpoint_template, save_checkpoint_dir) - comm.Barrier() - - #Every process now loads the same dataset template - import zarr - store = zarr.DirectoryStore(save_checkpoint_dir) - root = zarr.open_group(store=store, mode='r+', synchronizer=None, path='data') - - nbytes = root.attrs['nbytes'] - fields_group = root['fields'] - params_group = root['params'] - simu_group = root['simulation'] - - # Export simulation data - if is_io_leader: - simu.save_checkpoint(simu_group, mpi_params, checkpoint_io_params, compressor=compressor) - - # Currently there is no distributed parameter capabilities so io_leader has to dump all parameters - if is_io_leader: - msg = ' | dumping parameters...' 
- vprint(msg) - for param in sorted(self.parameters, key=operator.attrgetter('name')): - if isinstance(param, (ScalarParameter, TensorParameter, BufferParameter)): - array = params_group[self.format_zarr_key(param.name)] - assert array.attrs['kind'] == param.__class__.__name__ - assert array.dtype == param._value.dtype - assert array.shape == param._value.shape - array[...] = param._value - else: - msg = 'Cannot dump parameter of type {}.'.format(param.__class__.__name__) - raise NotImplementedError(msg) - - # Unlike parameter all processes participate for fields - for field in sorted(self.fields, key=operator.attrgetter('name')): - # we do not care about fields discretized only on temporary fields - if all(df.is_tmp for df in field.discrete_fields.values()): - continue - - msg = ' | dumping field {}...'.format(field.pretty_name) - vprint(msg) - - field_group = fields_group[self.format_zarr_key(field.name)] - for (k, topo) in enumerate(sorted(field.discrete_fields, key=operator.attrgetter('full_tag'))): - dfield = field.discrete_fields[topo] - mesh = topo.mesh._mesh - - if dfield.is_tmp: - # we do not care about temporary fields - continue - - dataset = 'topo_{:02d}'.format(k) # key has to match template - array = field_group[dataset] - should_sync = array.attrs['should_sync'] - - assert dfield.nb_components == 1 - assert (array.shape == mesh.grid_resolution).all(), (array.shape, mesh.grid_resolution) - assert array.dtype == dfield.dtype, (array.dtype, dfield.dtype) - - if should_sync: - # Should not be required untill we allow non-uniform discretizations - global_start = mesh.global_start - global_stop = mesh.global_stop - raise NotImplementedError('Synchronized multiprocess write has not been implemented yet.') - else: - assert ((mesh.compute_resolution == array.chunks).all() - or (mesh.is_at_right_boundary*(mesh.proc_shape>1)).any()) - local_data = dfield.compute_data[0].get() - global_slices = mesh.global_compute_slices - array[global_slices] = local_data # ok, every process writes to an independent data blocks - - # Some zarr store formats require a final close to flush data - try: - root.close() - except AttributeError: - pass - - return True, None, nbytes - - - def _import_checkpoint(self, load_checkpoint_dir, checkpoint_io_params, simu, strict_check=False): - if not os.path.isdir(load_checkpoint_dir): - msg='Cannot find directory \'{}\'.'.format(load_checkpoint_dir) - raise RuntimeError(msg) - - # On data import, there is no need to synchronize read-only arrays - # so we are good with multiple processes reading overlapping data blocks - - import zarr - store = zarr.DirectoryStore(load_checkpoint_dir) - self.mpi_params.comm.Barrier() - - try: - root = zarr.open_group(store=store, mode='r', synchronizer=None, path='data') - params = root['params'] - fields = root['fields'] - simulation = root['simulation'] - except: - msg='An error occured during checkpoint import.' 
-            vprint(msg)
-            vprint()
-            raise
-
-        def raise_error(msg):
-            msg = ' | error: {}\n'.format(msg)
-            vprint(msg)
-            msg = 'FATAL ERROR: Failed to import checkpoint, check logs for more information.'
-            raise RuntimeError(msg)
-        def raise_warning(msg):
-            msg = ' | warning: {}'.format(msg)
-            vprint(msg)
-
-        if strict_check:
-            raise_warning = raise_error
-
-        def load_array_data(array, dfield):
-            mesh = dfield.mesh._mesh
-            assert np.equal(array.shape, mesh.grid_resolution).all()
-
-            # compare attributes, but ignore the name because this can be annoying
-            attr_names = ('left boundaries', 'right boundaries', 'ghost layers', 'process shape', 'datatype')
-            array_attributes = (array.attrs['lboundaries'], array.attrs['rboundaries'], array.attrs['ghosts'],
-                                array.attrs['proc_shape'], array.dtype)
-            dfield_attributes = (list(map(str, mesh.global_lboundaries)), list(map(str, mesh.global_rboundaries)),
-                                list(mesh.ghosts), list(mesh.proc_shape))
-            for (name, lhs, rhs) in zip(attr_names, array_attributes, dfield_attributes):
-                if lhs==rhs:
-                    continue
-                msg='{} do not match with checkpointed field {}, loaded {} {} but expected {}.'
-                msg=msg.format(name, dfield.field.name, name, lhs, rhs)
-                raise_warning(msg)
-
-            global_slices = mesh.global_compute_slices
-            data = np.asarray(array[global_slices], dtype=dfield.dtype)
-            dfield.compute_data[0][...] = data
-            dfield.exchange_ghosts()
-
-        # Import parameters, hoping that parameter names match the ones in the checkpoint
-        msg = ' | importing parameters...'
-        vprint(msg)
-        for param in sorted(self.parameters, key=operator.attrgetter('name')):
-            key = self.format_zarr_key(param.name)
-
-            if (key not in params):
-                msg='Checkpoint directory \'{}\' does not contain any data regarding parameter {}.'
-                msg=msg.format(load_checkpoint_dir, param.name)
-                raise_error(msg)
-
-            array = params[key]
-
-            if array.attrs['kind'] != param.__class__.__name__:
-                msg='Parameter kind does not match with checkpointed parameter {}, loaded kind {} but expected {}.'
-                msg=msg.format(param.name, array.attrs['kind'], param.__class__.__name__)
-                raise_error(msg)
-
-            if isinstance(param, (ScalarParameter, TensorParameter, BufferParameter)):
-                value = param._value
-
-                if (array.shape != value.shape):
-                    msg='Parameter shape does not match with checkpointed parameter {}, loaded shape {} but expected {}.'
-                    msg=msg.format(param.name, array.shape, value.shape)
-                    raise_error(msg)
-
-                if (array.dtype != value.dtype):
-                    msg='Parameter datatype does not match with checkpointed parameter {}, loaded dtype {} but expected {}.'
-                    msg=msg.format(param.name, array.dtype, value.dtype)
-                    raise_warning(msg)
-
-                value[...] = array[...]
-            else:
-                msg = 'Cannot import parameter of type {}.'.format(param.__class__.__name__)
-                raise NotImplementedError(msg)
-
-        # Import simulation data after parameters are up to date
-        simu.load_checkpoint(simulation, self.mpi_params, checkpoint_io_params, strict_check)
-
-        # Importing discrete fields is a bit more tricky because topologies, or simply topology
-        # names, can change. Moreover, there is currently no warranty that the same operator graph is
-        # generated for the exact same problem configuration each time. We just emit user warnings
-        # if we find a way to match topologies that do not match the checkpointed ones exactly.
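The raise_warning = raise_error rebinding above is what turns relaxed checks into strict ones; a compact standalone sketch of that switch (the attribute values below are invented):

def make_checkers(strict):
    # In strict mode every mismatch is fatal; otherwise most are downgraded to warnings.
    def raise_error(msg):
        raise RuntimeError('FATAL ERROR: {}'.format(msg))
    def raise_warning(msg):
        print(' | warning: {}'.format(msg))
    return (raise_error, raise_error if strict else raise_warning)

raise_error, raise_warning = make_checkers(strict=False)
expected = {'ghosts': [2, 2, 2], 'proc_shape': [1, 2, 2]}
loaded   = {'ghosts': [2, 2, 2], 'proc_shape': [4, 1, 1]}   # e.g. restarted on 4 ranks
for name in expected:
    if loaded[name] != expected[name]:
        raise_warning('{} do not match, loaded {} but expected {}.'.format(
            name, loaded[name], expected[name]))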
-        for field in sorted(self.fields, key=operator.attrgetter('name')):
-            domain = field.domain._domain
-
-            # we do not care about fields discretized only on temporary fields
-            if all(df.is_tmp for df in field.discrete_fields.values()):
-                continue
-
-            msg = ' | importing field {}...'.format(field.pretty_name)
-            vprint(msg)
-
-            field_key = self.format_zarr_key(field.name)
-            if (field_key not in fields):
-                msg='Checkpoint directory \'{}\' does not contain any data regarding field {}.'
-                msg=msg.format(load_checkpoint_dir, field.name)
-                raise_error(msg)
-
-            dfields = fields[field_key]
-
-            # check that the domain matches
-            if dfields.attrs['domain'] != domain.__class__.__name__:
-                msg='Domain kind does not match with checkpointed field {}, loaded kind {} but expected {}.'
-                msg=msg.format(field.name, dfields.attrs['domain'], domain.__class__.__name__)
-                raise_error(msg)
-            if dfields.attrs['dim'] != domain.dim:
-                msg='Domain dim does not match with checkpointed field {}, loaded dim {} but expected {}.'
-                msg=msg.format(field.name, dfields.attrs['dim'], domain.dim)
-                raise_error(msg)
-            if dfields.attrs['origin'] != to_list(domain.origin):
-                msg='Domain origin does not match with checkpointed field {}, loaded origin {} but expected {}.'
-                msg=msg.format(field.name, dfields.attrs['origin'], domain.origin)
-                raise_error(msg)
-            if dfields.attrs['end'] != to_list(domain.end):
-                msg='Domain end does not match with checkpointed field {}, loaded end {} but expected {}.'
-                msg=msg.format(field.name, dfields.attrs['end'], domain.end)
-                raise_error(msg)
-            if dfields.attrs['length'] != to_list(domain.length):
-                msg='Domain length does not match with checkpointed field {}, loaded length {} but expected {}.'
-                msg=msg.format(field.name, dfields.attrs['length'], domain.length)
-                raise_error(msg)
-
-            for (k, topo) in enumerate(sorted(field.discrete_fields, key=operator.attrgetter('full_tag'))):
-                dfield = field.discrete_fields[topo]
-                mesh = topo.mesh._mesh
-
-                # we do not care about temporary fields
-                if dfield.is_tmp:
-                    continue
-
-                # for now we just handle CartesianDiscreteScalarFields
-                if not isinstance(dfield, CartesianDiscreteScalarField):
-                    raise NotImplementedError
-
-                # first we need to exactly match the global grid resolution
-                candidates = tuple(filter(lambda d: np.equal(d.shape, mesh.grid_resolution).all(), dfields.values()))
-                if len(candidates)==0:
-                    msg='Could not find any topology with shape {} for field {}, available discretizations are: {}.'
-                    msg=msg.format(to_tuple(mesh.grid_resolution), field.name,
-                                   ', '.join(set(str(d.shape) for d in dfields.values())))
-                    raise_error(msg)
-                elif len(candidates)==1:
-                    load_array_data(candidates[0], dfield)
-                    continue
-
-                # Here multiple topologies have the exact same grid resolution, so we try to match boundary conditions
-                old_candidates = candidates
-                candidates = tuple(filter(lambda d: d.attrs['lboundaries'] == to_tuple(map(str, mesh.global_lboundaries)), candidates))
-                candidates = tuple(filter(lambda d: d.attrs['rboundaries'] == to_tuple(map(str, mesh.global_rboundaries)), candidates))
-                if len(candidates)==0:
-                    # ok, the user changed the boundary conditions, we ignore boundary condition information
-                    candidates = old_candidates
-                elif len(candidates)==1:
-                    load_array_data(candidates[0], dfield)
-                    continue
-
-                # From now on multiple topologies have the same grid resolution and boundary conditions.
-                # We try to match the exact ghost count; the user likely did not change the order of the methods.
-                old_candidates = candidates
-                candidates = tuple(filter(lambda d: d.attrs['ghosts'] == to_tuple(mesh.ghosts), candidates))
-                if len(candidates)==0:
-                    # ok, the user made a change that affected ghosts, we ignore the ghost condition
-                    candidates = old_candidates
-                elif len(candidates)==1:
-                    load_array_data(candidates[0], dfield)
-                    continue
-
-                # Now we try to differentiate by using zero-ghost info (ghosts may change with method order, but zero-ghost is very specific).
-                # Topologies containing zero ghost layers usually target Fortran topologies for FFT operators, or methods that do not require any ghosts.
-                old_candidates = candidates
-                candidates = tuple(filter(lambda d: (np.equal(d.attrs['ghosts'],0) == (mesh.ghosts==0)).all(), candidates))
-                if len(candidates)==0:
-                    # ok, we ignore the zero-ghost condition
-                    candidates = old_candidates
-                elif len(candidates)==1:
-                    load_array_data(candidates[0], dfield)
-                    continue
-
-                # Now we try to match the exact topology shape (the MPICart grid of processes).
-                # We try this late because the user may run the simulation again with a different number of processes.
-                old_candidates = candidates
-                candidates = tuple(filter(lambda d: d.attrs['proc_shape'] == to_tuple(mesh.proc_shape), candidates))
-                if len(candidates)==0:
-                    # ok, we ignore the proc shape
-                    candidates = old_candidates
-                elif len(candidates)==1:
-                    load_array_data(candidates[0], dfield)
-                    continue
-
-                # Now we try to differentiate by using topology splitting info (axes on which data is distributed).
-                # This again is very specific and can differentiate topologies used for spectral transforms.
-                old_candidates = candidates
-                candidates = tuple(filter(lambda d: (np.greater(d.attrs['proc_shape'],1) == (mesh.proc_shape>1)).all(), candidates))
-                if len(candidates)==0:
-                    # ok, we ignore the MPI data splitting condition
-                    candidates = old_candidates
-                elif len(candidates)==1:
-                    load_array_data(candidates[0], dfield)
-                    continue
-
-                # Ok, now our last hope is to match the discrete field name
-                old_candidates = candidates
-                candidates = tuple(filter(lambda d: d.attrs['name'] == dfield.name, candidates))
-                if len(candidates)==0:
-                    # ok, we ignore the name
-                    candidates = old_candidates
-                elif len(candidates)==1:
-                    load_array_data(candidates[0], dfield)
-                    continue
-
-                assert len(candidates) > 1, 'Something went wrong.'
-
-                msg='Could not discriminate checkpointed topologies for field {}, got {} candidates remaining.'
-                msg=msg.format(field.name, len(candidates))
-                raise_error(msg)
-
-
-    @staticmethod
-    def format_zarr_key(k):
-        # Note: keys that contain the special characters '/' and '\' do not work well with zarr,
-        # so we need to replace them with another character such as '_'.
-        # We cannot use utf8 characters such as u+2215 (division slash).
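Each matching stage above follows the same refine-or-revert shape; a generic standalone sketch of that cascade (the dict-based candidates and predicates are invented for illustration):

def discriminate(candidates, predicates):
    # Apply each predicate in turn: keep a refinement only if it leaves at
    # least one candidate, and stop as soon as exactly one survivor remains.
    for pred in predicates:
        filtered = tuple(c for c in candidates if pred(c))
        if len(filtered) == 1:
            return filtered[0]
        if filtered:
            candidates = filtered
    raise ValueError('Could not discriminate, {} candidates remaining.'.format(len(candidates)))

candidates = ({'ghosts': (2, 2), 'name': 'u_fft'}, {'ghosts': (0, 0), 'name': 'u_adv'})
match = discriminate(candidates, (
    lambda d: d['ghosts'] == (4, 4),   # eliminates everyone -> reverted
    lambda d: d['name'] == 'u_adv',    # unique survivor -> matched
))
assert match['name'] == 'u_adv'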
-        if (k is None):
-            return None
-        return k.replace('/', '_').replace('\\', '_')
-
diff --git a/hysop/simulation.py b/hysop/simulation.py
index 917d3f2e4..07397e356 100644
--- a/hysop/simulation.py
+++ b/hysop/simulation.py
@@ -426,10 +426,10 @@ class Simulation(object):
                 params = _params
             self._parameters_to_write.append((io_params, params, kwds))
 
-    def save_checkpoint(self, datagroup, checkpoint_mpi_params, checkpoint_io_params, compressor):
+    def save_checkpoint(self, datagroup, mpi_params, io_params, compressor):
         import zarr
         check_instance(datagroup, zarr.hierarchy.Group)
-        is_io_leader = (checkpoint_mpi_params.rank == checkpoint_io_params.io_leader)
+        is_io_leader = (mpi_params.rank == io_params.io_leader)
         if is_io_leader:
             # we need to export simulation parameter values because they
             # may not be part of global problem parameters
@@ -443,7 +443,7 @@ class Simulation(object):
                 pass
             datagroup.attrs[attrname] = data
 
-    def load_checkpoint(self, datagroup, checkpoint_mpi_params, checkpoint_io_params, strict_check):
+    def load_checkpoint(self, datagroup, mpi_params, io_params, relax_constraints):
         import zarr
         check_instance(datagroup, zarr.hierarchy.Group)
         self.times_of_interest = tuple(sorted(filter(lambda t: t>=datagroup.attrs['time'], self.times_of_interest)))
diff --git a/hysop_examples/example_utils.py b/hysop_examples/example_utils.py
index 085302f76..c3501d82b 100644
--- a/hysop_examples/example_utils.py
+++ b/hysop_examples/example_utils.py
@@ -1011,13 +1011,13 @@ class HysopArgParser(argparse.ArgumentParser):
             description = ('Configure problem checkpoint I/O parameters; dumped checkpoints represent simulation states '
                            'that can be loaded back to continue the simulation later on.')
             pargs = self.add_argument_group('{} I/O'.format(pname.upper()), description=description)
-            pargs.add_argument('-L', '--load-checkpoint', default=None, const='checkpoint.tar', nargs='?', type=str, dest='load_checkpoint',
+            pargs.add_argument('-L', '--load-checkpoint', default=None, const='checkpoint.tar', nargs='?', type=str, dest='load_checkpoint_path',
                 help=('Begin the simulation from this checkpoint. Can be given as a full path or as a filename relative to --checkpoint-dump-dir. '
                       'The given checkpoint has to be compatible with the problem it will be loaded into. '
                       'This will only work if parameter names, variable names, operator names, discretization and global topology information remain unchanged. '
                       'Operator ordering, boundary conditions, data ordering, data permutation and MPI layouts may however be changed. '
                       'Defaults to {checkpoint_output_dir}/checkpoint.tar if no filename is specified.'))
-            pargs.add_argument('-S', '--save-checkpoint', default=None, const='checkpoint.tar', nargs='?', type=str, dest='save_checkpoint',
+            pargs.add_argument('-S', '--save-checkpoint', default=None, const='checkpoint.tar', nargs='?', type=str, dest='save_checkpoint_path',
                 help=('Enable simulation checkpoints to be able to restart simulations from a specific point later on. '
                       'Can be given as a full path or as a filename relative to --checkpoint-dump-dir. '
                       'Frequency or times of interest for checkpoints can be configured by using global FILE I/O parameters or '
@@ -1025,6 +1025,11 @@ class HysopArgParser(argparse.ArgumentParser):
                       'Should not be too frequent for efficiency reasons. May be used in conjunction with --load-checkpoint, '
                      'in which case the starting checkpoint may be overwritten in case the same path is given. '
                      'Defaults to {checkpoint_output_dir}/checkpoint.tar if no filename is specified.'))
+            pargs.add_argument('--checkpoint-relax-constraints', action='store_true', dest='checkpoint_relax_constraints',
+                help=('Relax field/parameter checks when loading a checkpoint. This allows for a change in datatype, '
+                      'boundary conditions, ghost count and topology shape when reloading a checkpoint. '
+                      'Useful to continue a simulation with a different precision, a different compute backend, '
+                      'different boundary conditions or a different number of processes.'))
         else:
             pargs = self.add_argument_group('{} I/O'.format(pname.upper()))
@@ -1261,9 +1266,10 @@ class HysopArgParser(argparse.ArgumentParser):
         args.times_of_interest = times_of_interest
 
-        # checkpoints
-        self._check_default(args, 'load_checkpoint', str, allow_none=True)
-        self._check_default(args, 'save_checkpoint', str, allow_none=True)
+        # extra checkpoint arguments
+        self._check_default(args, 'load_checkpoint_path', str, allow_none=True)
+        self._check_default(args, 'save_checkpoint_path', str, allow_none=True)
+        self._check_default(args, 'checkpoint_relax_constraints', bool, allow_none=False)
 
     def _add_graphical_io_args(self):
         graphical_io = self.add_argument_group('Graphical I/O')
@@ -1550,6 +1556,7 @@ class HysopArgParser(argparse.ArgumentParser):
     def _setup_parameters(self, args):
         from hysop import IO, IOParams
+        from hysop.core.checkpoints import CheckpointHandler
         args.io_params = IOParams(filename=None, filepath=args.dump_dir,
                                   frequency=args.dump_freq, dump_times=args.dump_times,
@@ -1576,30 +1583,33 @@ class HysopArgParser(argparse.ArgumentParser):
                     hdf5_disable_slicing = getattr(args, '{}_hdf5_disable_slicing'.format(pname)))
             setattr(args, '{}_io_params'.format(pname), iop)
 
-        load_checkpoint = args.load_checkpoint
-        if (load_checkpoint is not None):
-            if not load_checkpoint.endswith('.tar'):
+        load_checkpoint_path = args.load_checkpoint_path
+        if (load_checkpoint_path is not None):
+            if not load_checkpoint_path.endswith('.tar'):
                 msg='Load checkpoint filename has to end with .tar, got \'{}\'.'
-                self.error(msg.format(load_checkpoint))
-            if (os.path.sep not in load_checkpoint):
-                load_checkpoint = os.path.join(args.checkpoint_dump_dir, load_checkpoint)
-            if not os.path.isfile(load_checkpoint):
+                self.error(msg.format(load_checkpoint_path))
+            if (os.path.sep not in load_checkpoint_path):
+                load_checkpoint_path = os.path.join(args.checkpoint_dump_dir, load_checkpoint_path)
+            if not os.path.isfile(load_checkpoint_path):
                 msg = 'Cannot load checkpoint \'{}\' because the file does not exist.'
-                self.error(msg.format(load_checkpoint))
-            load_checkpoint = os.path.abspath(load_checkpoint)
-            args.load_checkpoint = load_checkpoint
+                self.error(msg.format(load_checkpoint_path))
+            load_checkpoint_path = os.path.abspath(load_checkpoint_path)
+            args.load_checkpoint_path = load_checkpoint_path
 
-        save_checkpoint = args.save_checkpoint
-        if (save_checkpoint is not None):
-            if not save_checkpoint.endswith('.tar'):
+        save_checkpoint_path = args.save_checkpoint_path
+        if (save_checkpoint_path is not None):
+            if not save_checkpoint_path.endswith('.tar'):
                 msg='Save checkpoint filename has to end with .tar, got \'{}\'.'
-                self.error(msg.format(save_checkpoint))
-            if (os.path.sep not in save_checkpoint):
-                save_checkpoint = os.path.join(args.checkpoint_dump_dir, save_checkpoint)
-            save_checkpoint = os.path.abspath(save_checkpoint)
-            args.checkpoint_dump_dir = os.path.dirname(save_checkpoint)
-            args.save_checkpoint = save_checkpoint
-
+                self.error(msg.format(save_checkpoint_path))
+            if (os.path.sep not in save_checkpoint_path):
+                save_checkpoint_path = os.path.join(args.checkpoint_dump_dir, save_checkpoint_path)
+            save_checkpoint_path = os.path.abspath(save_checkpoint_path)
+            args.checkpoint_dump_dir = os.path.dirname(save_checkpoint_path)
+            args.save_checkpoint_path = save_checkpoint_path
+
+        args.checkpoint_handler = CheckpointHandler(args.load_checkpoint_path, args.save_checkpoint_path,
+                                                    args.checkpoint_io_params, args.checkpoint_relax_constraints)
+
         # debug dumps
         if (args.debug_dump_dir is None):
             args.debug_dump_dir = args.dump_dir
diff --git a/hysop_examples/examples/analytic/analytic.py b/hysop_examples/examples/analytic/analytic.py
index 548165fac..43d175cf9 100755
--- a/hysop_examples/examples/analytic/analytic.py
+++ b/hysop_examples/examples/analytic/analytic.py
@@ -110,9 +110,21 @@ def compute(args):
                       max_iter=args.max_iter,
                       times_of_interest=args.times_of_interest,
                       t=t)
+
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
 
     # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
diff --git a/hysop_examples/examples/bubble/periodic_bubble.py b/hysop_examples/examples/bubble/periodic_bubble.py
index 76c67733a..21f4dd8ee 100644
--- a/hysop_examples/examples/bubble/periodic_bubble.py
+++ b/hysop_examples/examples/bubble/periodic_bubble.py
@@ -285,6 +285,16 @@ def compute(args):
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, viscosity and density on all topologies
     Bc, Br = args.Bc, args.Br
     dx = np.max(np.divide(box.length, np.asarray(args.npts)-1))
@@ -294,8 +304,10 @@ def compute(args):
     problem.initialize_field(field=rho, formula=init_rho, rho1=args.rho1, rho2=args.rho2, Bc=Bc, Br=Br, reorder='Bc', eps=eps)
     problem.initialize_field(field=mu,  formula=init_mu,  mu1=args.mu1,   mu2=args.mu2,   Bc=Bc, Br=Br, reorder='Bc', eps=eps)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
@@ -385,6 +397,7 @@ if __name__=='__main__':
             self._check_positive(args, 'plot_freq', strict=True, allow_none=False)
 
         def _setup_parameters(self, args):
+            super(PeriodicBubbleArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
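For scripts that do not go through HysopArgParser, the handler introduced above can be wired by hand; a hedged sketch using the constructor from hysop/core/checkpoints.py (the path, frequency, and the pre-built problem/simu objects are assumptions):

from hysop import IOParams
from hysop.core.checkpoints import CheckpointHandler

ckpt_io = IOParams(filename=None, filepath='/tmp/ckpt', frequency=100)  # invented values
handler = CheckpointHandler(load_checkpoint_path=None,
                            save_checkpoint_path='/tmp/ckpt/checkpoint.tar',
                            io_params=ckpt_io,
                            relax_constraints=False)

# 'problem' and 'simu' are assumed to be built as in the examples above:
problem.solve(simu, dry_run=False, debug_dumper=None, checkpoint_handler=handler)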
diff --git a/hysop_examples/examples/bubble/periodic_bubble_levelset.py b/hysop_examples/examples/bubble/periodic_bubble_levelset.py
index df11b891a..967c0b119 100644
--- a/hysop_examples/examples/bubble/periodic_bubble_levelset.py
+++ b/hysop_examples/examples/bubble/periodic_bubble_levelset.py
@@ -284,6 +284,16 @@ def compute(args):
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, viscosity and density on all topologies
     Bc, Br = args.Bc, args.Br
     problem.initialize_field(field=velo,  formula=init_velocity)
@@ -292,8 +302,10 @@ def compute(args):
     problem.initialize_field(field=mu,  formula=init_mu)
     problem.initialize_field(field=phi, formula=init_phi, Bc=Bc, Br=Br, reorder='Bc')
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
@@ -383,6 +395,7 @@ if __name__=='__main__':
             self._check_positive(args, 'plot_freq', strict=True, allow_none=False)
 
         def _setup_parameters(self, args):
+            super(PeriodicBubbleArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
diff --git a/hysop_examples/examples/bubble/periodic_bubble_levelset_penalization.py b/hysop_examples/examples/bubble/periodic_bubble_levelset_penalization.py
index 848886963..6023b607c 100644
--- a/hysop_examples/examples/bubble/periodic_bubble_levelset_penalization.py
+++ b/hysop_examples/examples/bubble/periodic_bubble_levelset_penalization.py
@@ -325,6 +325,16 @@ def compute(args):
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, viscosity and density on all topologies
     Bc, Br = args.Bc, args.Br
     problem.initialize_field(field=velo,  formula=init_velocity)
@@ -334,8 +344,10 @@ def compute(args):
     problem.initialize_field(field=phi,     formula=init_phi, Bc=Bc, Br=Br, reorder='Bc')
     problem.initialize_field(field=_lambda, formula=init_lambda)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
@@ -425,6 +437,7 @@ if __name__=='__main__':
            self._check_positive(args, 'plot_freq', strict=True, allow_none=False)
 
         def _setup_parameters(self, args):
+            super(PeriodicBubbleArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
diff --git a/hysop_examples/examples/bubble/periodic_jet_levelset.py b/hysop_examples/examples/bubble/periodic_jet_levelset.py
index 479c6469e..bbf7223d0 100644
--- a/hysop_examples/examples/bubble/periodic_jet_levelset.py
+++ b/hysop_examples/examples/bubble/periodic_jet_levelset.py
@@ -273,14 +273,26 @@ def compute(args):
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, viscosity and density on all topologies
     problem.initialize_field(field=velo,  formula=init_velocity)
     problem.initialize_field(field=vorti, formula=init_vorticity)
     problem.initialize_field(field=rho,   formula=init_rho)
     problem.initialize_field(field=phi,   formula=init_phi, L=box.length)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
@@ -329,6 +341,7 @@ if __name__=='__main__':
             self._check_positive(args, vars_, strict=False, allow_none=False)
 
         def _setup_parameters(self, args):
+            super(PeriodicJetArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
diff --git a/hysop_examples/examples/cylinder/oscillating_cylinder.py b/hysop_examples/examples/cylinder/oscillating_cylinder.py
index a19d4f94d..287159233 100644
--- a/hysop_examples/examples/cylinder/oscillating_cylinder.py
+++ b/hysop_examples/examples/cylinder/oscillating_cylinder.py
@@ -247,13 +247,25 @@ def compute(args):
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, viscosity and density on all topologies
     problem.initialize_field(field=velo,    formula=init_velocity)
     problem.initialize_field(field=vorti,   formula=init_vorticity)
     problem.initialize_field(field=_lambda, formula=init_lambda)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
@@ -279,6 +291,7 @@ if __name__=='__main__':
                 default_dump_dir=default_dump_dir)
 
         def _setup_parameters(self, args):
+            super(OscillatingCylinderArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
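The same ten-line DebugDumper block is pasted into each example above and below; a small helper (hypothetical, not part of this patch) would remove that duplication:

from hysop.tools.debug_dumper import DebugDumper

def make_debug_dumper(args):
    # Return a DebugDumper when --debug-dump-target is set, else None.
    if not args.debug_dump_target:
        return None
    return DebugDumper(path=args.debug_dump_dir,
                       name=args.debug_dump_target,
                       force_overwrite=True, enable_on_op_apply=True)

# Each compute() could then collapse to:
#   problem.solve(simu, dry_run=args.dry_run,
#                 debug_dumper=make_debug_dumper(args),
#                 checkpoint_handler=args.checkpoint_handler)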
diff --git a/hysop_examples/examples/fixed_point/heat_equation.py b/hysop_examples/examples/fixed_point/heat_equation.py
index 7804d6231..452b69d26 100644
--- a/hysop_examples/examples/fixed_point/heat_equation.py
+++ b/hysop_examples/examples/fixed_point/heat_equation.py
@@ -192,7 +192,8 @@ def compute(args):
     simu.write_parameters(t, fixedPoint.it_num, filename='parameters.txt', precision=8)
 
     problem.initialize_field(u, formula=init_u)
-    problem.solve(simu, dry_run=args.dry_run)
+    problem.solve(simu, dry_run=args.dry_run,
+                  checkpoint_handler=args.checkpoint_handler)
     problem.finalize()
diff --git a/hysop_examples/examples/flow_around_sphere/flow_around_sphere.py b/hysop_examples/examples/flow_around_sphere/flow_around_sphere.py
index 16a239382..ce5c6ef71 100644
--- a/hysop_examples/examples/flow_around_sphere/flow_around_sphere.py
+++ b/hysop_examples/examples/flow_around_sphere/flow_around_sphere.py
@@ -314,13 +314,26 @@ def compute(args):
     simu.write_parameters(t, dt_cfl, dt_advec, dt, enstrophy, flowrate,
             min_max_U.Finf, min_max_W.Finf, adapt_dt.equivalent_CFL,
             filename='parameters.txt', precision=8)
+
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
 
     problem.initialize_field(vorti,  formula=computeVort)
     problem.initialize_field(velo,   formula=computeVel)
     problem.initialize_field(sphere, formula=computeSphere)
 
     # Finally solve the problem
-    problem.solve(simu)
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
+
     # Finalize
     problem.finalize()
diff --git a/hysop_examples/examples/multiresolution/scalar_advection.py b/hysop_examples/examples/multiresolution/scalar_advection.py
index f41cf0a00..58a9daf25 100644
--- a/hysop_examples/examples/multiresolution/scalar_advection.py
+++ b/hysop_examples/examples/multiresolution/scalar_advection.py
@@ -181,9 +181,21 @@ def compute(args):
                       times_of_interest=args.times_of_interest,
                       dt=dt, dt0=dt0)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
-
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
+
     # Finalize
     problem.finalize()
@@ -219,6 +231,7 @@ if __name__=='__main__':
             self._check_default(args, 'velocity', tuple, allow_none=False)
 
         def _setup_parameters(self, args):
+            super(MultiResolutionScalarAdvectionArgParser, self)._setup_parameters(args)
             if len(args.velocity) == 1:
                 args.velocity *= args.ndim
diff --git a/hysop_examples/examples/particles_above_salt/particles_above_salt_bc.py b/hysop_examples/examples/particles_above_salt/particles_above_salt_bc.py
index 77c7ace66..2c74ae5e8 100644
--- a/hysop_examples/examples/particles_above_salt/particles_above_salt_bc.py
+++ b/hysop_examples/examples/particles_above_salt/particles_above_salt_bc.py
@@ -308,14 +308,26 @@ def compute(args):
                           min_max_U.Finf, min_max_W.Finf,
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, S and C on all topologies
     problem.initialize_field(field=velo,  formula=init_velocity)
     problem.initialize_field(field=vorti, formula=init_vorticity)
     problem.initialize_field(field=C, formula=init_concentration, l0=l0)
     problem.initialize_field(field=S, formula=init_salinity, l0=l0)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
@@ -345,11 +357,11 @@ if __name__=='__main__':
                 default_dump_dir=default_dump_dir)
 
         def _setup_parameters(self, args):
+            super(ParticleAboveSaltArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
                 self.error(msg)
-            super(ParticleAboveSaltArgParser, self)._setup_parameters(args)
 
 
 parser = ParticleAboveSaltArgParser()
diff --git a/hysop_examples/examples/particles_above_salt/particles_above_salt_bc_3d.py b/hysop_examples/examples/particles_above_salt/particles_above_salt_bc_3d.py
index 0e4abfd2d..08adb0dc2 100644
--- a/hysop_examples/examples/particles_above_salt/particles_above_salt_bc_3d.py
+++ b/hysop_examples/examples/particles_above_salt/particles_above_salt_bc_3d.py
@@ -324,14 +324,27 @@ def compute(args):
                           min_max_U.Finf, min_max_W.Finf,
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, S and C on all topologies
     problem.initialize_field(field=velo,  formula=init_velocity)
     problem.initialize_field(field=vorti, formula=init_vorticity)
     problem.initialize_field(field=C, formula=init_concentration, l0=l0)
     problem.initialize_field(field=S, formula=init_salinity, l0=l0)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
+
     # Finalize
     problem.finalize()
@@ -382,6 +395,7 @@ if __name__=='__main__':
             self._check_positive(args, ('schmidt', 'tau', 'Vp', 'Rs'))
 
         def _setup_parameters(self, args):
+            super(ParticleAboveSaltArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
diff --git a/hysop_examples/examples/particles_above_salt/particles_above_salt_periodic.py b/hysop_examples/examples/particles_above_salt/particles_above_salt_periodic.py
index e028fa051..10de6dde5 100644
--- a/hysop_examples/examples/particles_above_salt/particles_above_salt_periodic.py
+++ b/hysop_examples/examples/particles_above_salt/particles_above_salt_periodic.py
@@ -318,6 +318,16 @@ def compute(args):
                           min_max_U.Finf, min_max_W.Finf,
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, S and C on all topologies
     problem.initialize_field(field=velo,  formula=init_velocity)
     problem.initialize_field(field=vorti, formula=init_vorticity)
@@ -325,8 +335,11 @@ def compute(args):
     problem.initialize_field(field=S, formula=init_salinity, l0=l0)
     problem.initialize_field(field=_lambda, formula=init_lambda)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
+
     # Finalize
     problem.finalize()
@@ -356,6 +369,7 @@ if __name__=='__main__':
                 default_dump_dir=default_dump_dir)
 
         def _setup_parameters(self, args):
+            super(ParticleAboveSaltArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
diff --git a/hysop_examples/examples/particles_above_salt/particles_above_salt_symmetrized.py b/hysop_examples/examples/particles_above_salt/particles_above_salt_symmetrized.py
index a261d3c4f..c77fb54bc 100644
--- a/hysop_examples/examples/particles_above_salt/particles_above_salt_symmetrized.py
+++ b/hysop_examples/examples/particles_above_salt/particles_above_salt_symmetrized.py
@@ -306,14 +306,27 @@ def compute(args):
                           min_max_U.Finf, min_max_W.Finf,
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, S and C on all topologies
     problem.initialize_field(field=velo,  formula=init_velocity)
     problem.initialize_field(field=vorti, formula=init_vorticity)
     problem.initialize_field(field=C, formula=init_concentration, l0=l0)
     problem.initialize_field(field=S, formula=init_salinity, l0=l0)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
+
     # Finalize
     problem.finalize()
@@ -343,6 +356,7 @@ if __name__=='__main__':
                 default_dump_dir=default_dump_dir)
 
         def _setup_parameters(self, args):
+            super(ParticleAboveSaltArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
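The hunks above consistently move the super() call to the top of _setup_parameters, so that base-class checkpoint arguments are normalized before any subclass check runs; a self-contained sketch of that ordering (BaseParser is a stand-in for HysopArgParser):

import argparse

class BaseParser(object):                      # stand-in for HysopArgParser
    def _setup_parameters(self, args):
        args.checkpoint_handler = None         # the real base class builds the handler here
    def error(self, msg):
        raise SystemExit(msg)

class ExampleParser(BaseParser):               # mirrors the pattern in the hunks above
    def _setup_parameters(self, args):
        super(ExampleParser, self)._setup_parameters(args)  # run base setup first
        if args.ndim not in (2, 3):
            self.error('Domain should be 2D or 3D.')

args = argparse.Namespace(ndim=3)
ExampleParser()._setup_parameters(args)
assert hasattr(args, 'checkpoint_handler')     # handler exists before subclass checks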
diff --git a/hysop_examples/examples/scalar_advection/levelset.py b/hysop_examples/examples/scalar_advection/levelset.py
index 76a26d1e5..0b153896c 100644
--- a/hysop_examples/examples/scalar_advection/levelset.py
+++ b/hysop_examples/examples/scalar_advection/levelset.py
@@ -207,10 +207,24 @@ def compute(args):
     if args.display_graph:
         problem.display(args.visu_rank)
 
-    dfields = problem.input_discrete_fields
-    dfields[scalar].initialize(formula=init_scalar)
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
+    # Initialize scalar
+    problem.initialize_field(scalar, formula=init_scalar)
+
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
 
-    problem.solve(simu, dry_run=args.dry_run)
     problem.finalize()
diff --git a/hysop_examples/examples/scalar_advection/scalar_advection.py b/hysop_examples/examples/scalar_advection/scalar_advection.py
index 008db85b2..17c7275a0 100644
--- a/hysop_examples/examples/scalar_advection/scalar_advection.py
+++ b/hysop_examples/examples/scalar_advection/scalar_advection.py
@@ -105,7 +105,8 @@ def compute(args):
         problem.insert(splitting)
 
     # Add a writer of input field at given frequency.
-    problem.dump_inputs(fields=scalar, filename='S0', frequency=args.dump_freq, **extra_op_kwds)
+    problem.dump_inputs(fields=scalar,
+            io_params=args.io_params.clone(filename='S0'), **extra_op_kwds)
     problem.build(args)
 
     # If a visu_rank was provided, and show_graph was set,
@@ -134,8 +135,21 @@ def compute(args):
                       times_of_interest=args.times_of_interest,
                       dt=dt, dt0=dt0)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
+
     # Finalize
     problem.finalize()
@@ -172,6 +186,7 @@ if __name__=='__main__':
             self._check_default(args, 'velocity', tuple, allow_none=False)
 
         def _setup_parameters(self, args):
+            super(ScalarAdvectionArgParser, self)._setup_parameters(args)
             if len(args.velocity) == 1:
                 args.velocity *= args.ndim
diff --git a/hysop_examples/examples/scalar_diffusion/scalar_diffusion.py b/hysop_examples/examples/scalar_diffusion/scalar_diffusion.py
index cddacb6a5..dcf1f88bb 100755
--- a/hysop_examples/examples/scalar_diffusion/scalar_diffusion.py
+++ b/hysop_examples/examples/scalar_diffusion/scalar_diffusion.py
@@ -124,9 +124,7 @@ def compute(args):
     # Finally solve the problem
     problem.solve(simu, dry_run=args.dry_run,
                   debug_dumper=debug_dumper,
-                  load_checkpoint=args.load_checkpoint,
-                  save_checkpoint=args.save_checkpoint,
-                  checkpoint_io_params=args.checkpoint_io_params)
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
diff --git a/hysop_examples/examples/sediment_deposit/sediment_deposit.py b/hysop_examples/examples/sediment_deposit/sediment_deposit.py
index 9b67663e7..9b2107b43 100644
--- a/hysop_examples/examples/sediment_deposit/sediment_deposit.py
+++ b/hysop_examples/examples/sediment_deposit/sediment_deposit.py
@@ -320,14 +320,27 @@ def compute(args):
                           min_max_U.Finf, min_max_W.Finf,
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, S on all topologies
     problem.initialize_field(field=velo,  formula=init_velocity)
     problem.initialize_field(field=vorti, formula=init_vorticity)
     problem.initialize_field(field=S, formula=init_sediment,
             nblobs=nblobs, rblob=rblob, without_ghosts=True)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
+
     # Finalize
     problem.finalize()
@@ -354,6 +367,7 @@ if __name__=='__main__':
                 default_dump_dir=default_dump_dir)
 
         def _setup_parameters(self, args):
+            super(ParticleAboveSaltArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
diff --git a/hysop_examples/examples/sediment_deposit/sediment_deposit_levelset.py b/hysop_examples/examples/sediment_deposit/sediment_deposit_levelset.py
index fb1848c67..476852fa0 100644
--- a/hysop_examples/examples/sediment_deposit/sediment_deposit_levelset.py
+++ b/hysop_examples/examples/sediment_deposit/sediment_deposit_levelset.py
@@ -381,14 +381,27 @@ def compute(args):
                           min_max_U.Finf, min_max_W.Finf,
                           adapt_dt.equivalent_CFL,
                           filename='parameters.txt', precision=8)
 
+    # Attach a field debug dumper if requested
+    from hysop.tools.debug_dumper import DebugDumper
+    if args.debug_dump_target:
+        debug_dumper = DebugDumper(
+            path=args.debug_dump_dir,
+            name=args.debug_dump_target,
+            force_overwrite=True, enable_on_op_apply=True)
+    else:
+        debug_dumper = None
+
     # Initialize vorticity, velocity, S on all topologies
     problem.initialize_field(field=velo,  formula=init_velocity)
     problem.initialize_field(field=vorti, formula=init_vorticity)
     problem.initialize_field(field=phi, formula=init_phi,
             nblobs=nblobs, rblob=rblob, without_ghosts=BLOB_INIT)
 
-    # Finally solve the problem
-    problem.solve(simu, dry_run=args.dry_run)
+    # Finally solve the problem
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
+
     # Finalize
     problem.finalize()
@@ -416,6 +429,7 @@ if __name__=='__main__':
                 default_dump_dir=default_dump_dir)
 
         def _setup_parameters(self, args):
+            super(ParticleAboveSaltArgParser, self)._setup_parameters(args)
             dim = args.ndim
             if (dim not in (2,3)):
                 msg='Domain should be 2D or 3D.'
diff --git a/hysop_examples/examples/shear_layer/shear_layer.py b/hysop_examples/examples/shear_layer/shear_layer.py
index 63dfebfb0..7220b5adc 100644
--- a/hysop_examples/examples/shear_layer/shear_layer.py
+++ b/hysop_examples/examples/shear_layer/shear_layer.py
@@ -198,7 +198,11 @@ def compute(args):
 
     # Finally solve the problem
     problem.solve(simu, dry_run=args.dry_run,
-                  debug_dumper=debug_dumper, plot_freq=args.plot_freq)
+                  debug_dumper=debug_dumper,
+                  load_checkpoint=args.load_checkpoint,
+                  save_checkpoint=args.save_checkpoint,
+                  checkpoint_io_params=args.checkpoint_io_params,
+                  plot_freq=args.plot_freq)
 
     # Finalize
     problem.finalize()
diff --git a/hysop_examples/examples/taylor_green/taylor_green.py b/hysop_examples/examples/taylor_green/taylor_green.py
index b0c9d728e..942bfccac 100644
--- a/hysop_examples/examples/taylor_green/taylor_green.py
+++ b/hysop_examples/examples/taylor_green/taylor_green.py
@@ -309,9 +309,7 @@ def compute(args):
     # Finally solve the problem
     problem.solve(simu, dry_run=args.dry_run,
                   debug_dumper=debug_dumper,
-                  load_checkpoint=args.load_checkpoint,
-                  save_checkpoint=args.save_checkpoint,
-                  checkpoint_io_params=args.checkpoint_io_params)
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
diff --git a/hysop_examples/examples/taylor_green/taylor_green_cpuFortran.py b/hysop_examples/examples/taylor_green/taylor_green_cpuFortran.py
index b6a65931b..fdb613695 100644
--- a/hysop_examples/examples/taylor_green/taylor_green_cpuFortran.py
+++ b/hysop_examples/examples/taylor_green/taylor_green_cpuFortran.py
@@ -225,7 +225,9 @@ def compute(args):
     problem.initialize_field(vorti, formula=init_vorticity)
 
     # Finally solve the problem
-    problem.solve(simu, debug_dumper=debug_dumper)
+    problem.solve(simu, dry_run=args.dry_run,
+                  debug_dumper=debug_dumper,
+                  checkpoint_handler=args.checkpoint_handler)
 
     # Finalize
     problem.finalize()
@@ -288,6 +290,7 @@ if __name__=='__main__':
             self._check_positive(args, 'plot_freq', strict=True, allow_none=False)
 
         def _setup_parameters(self, args):
+            super(TaylorGreenArgParser, self)._setup_parameters(args)
             if (args.ndim != 3):
                 msg='This example only works for 3D domains.'
                 self.error(msg)
-- 
GitLab
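For reference, a save-then-restart cycle driven from Python, mirroring what hysop/core/tests/test_checkpoint.sh exercises; the example path and command-line flags come from this patch, while the interpreter name and working directory are assumptions:

import subprocess

example = 'hysop_examples/examples/taylor_green/taylor_green.py'

# First run: dump a checkpoint next to the other outputs (-S, added above).
subprocess.check_call(['python', example, '-S', 'checkpoint.tar'])

# Second run: restart from it, tolerating benign topology changes
# (--checkpoint-relax-constraints, also added above).
subprocess.check_call(['python', example, '-L', 'checkpoint.tar',
                       '--checkpoint-relax-constraints'])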