Skip to content
Snippets Groups Projects
Commit 15065d8b authored by EXT Jean-Matthieu Etancelin's avatar EXT Jean-Matthieu Etancelin
Browse files

improve convergence for iterative methods

parent e0845138
No related branches found
No related tags found
1 merge request!16MPI operators
......@@ -5,7 +5,7 @@ Convergence python backend.
from hysop.constants import HYSOP_REAL
from hysop.backend.host.host_operator import HostOperator
from hysop.operator.base.convergence import ConvergenceBase
from hysop.tools.decorators import debug
from hysop.tools.decorators import debug
from hysop.core.graph.graph import op_apply
from hysop.tools.numpywrappers import npw
from hysop.constants import ResidualError
......@@ -25,17 +25,20 @@ class PythonConvergence(ConvergenceBase, HostOperator):
super(PythonConvergence, self).setup(**kwds)
self.field_buffers = self.dField.compute_buffers
self._tmp_convergence = npw.zeros((self.field.nb_components),
self._tmp_convergence = npw.zeros((1+self.field.nb_components),
dtype=self.convergence.dtype)
self._tmp_reduce = npw.zeros((self.field.nb_components),
self._tmp_reduce = npw.zeros((1+self.field.nb_components),
dtype=self.convergence.dtype)
old = [npw.zeros(_.shape) for _ in self.field_buffers]
self.dField_old = tuple(old)
self.__compute_error_absolute = lambda ui, ui_old: npw.max(npw.abs(ui - ui_old))
self.__compute_error_relative = lambda ui, ui_old, max_ui: npw.max(
npw.abs(ui - ui_old))/max_ui
if self._residual_computation == ResidualError.ABSOLUTE:
self.__compute_error = lambda ui, ui_old, max_ui: npw.max(npw.abs(ui - ui_old))
self.__compute_error = self.__compute_error_absolute
elif self._residual_computation == ResidualError.RELATIVE:
self.__compute_error = lambda ui, ui_old, max_ui: npw.max(npw.abs(ui - ui_old))/max_ui
self.__compute_error = self.__compute_error_relative
else:
raise RuntimeError('Unknown residual computation method.')
......@@ -44,22 +47,21 @@ class PythonConvergence(ConvergenceBase, HostOperator):
@op_apply
def apply(self, **kwds):
u = self.field_buffers
u = self.field_buffers
u_old = self.dField_old
self._tmp_convergence[...] = 0.
for (i, (ui, ui_old)) in enumerate(zip(u, u_old)):
max_ui = npw.max(npw.abs(ui))
if max_ui < self._large_zero:
self._tmp_convergence[i] = self._eps
else:
self._tmp_convergence[i] = self.__compute_error(ui, ui_old, max_ui)
self._tmp_convergence[i] = self.__compute_error_absolute(ui, ui_old)
ui_old[...] = ui
self._tmp_convergence[-1] = npw.sum(self._tmp_convergence)
if self._residual_computation == ResidualError.RELATIVE:
max_u = npw.max([npw.max(npw.abs(_)) for _ in u])
self._tmp_convergence /= max_u
self.mpi_params.comm.Allreduce(sendbuf=self._tmp_convergence,
recvbuf=self._tmp_reduce,
op=MPI.MAX)
self.convergence.value = self._tmp_reduce
self.convergence.value = self._tmp_reduce[-1]
@classmethod
def supports_mpi(cls):
......
......@@ -7,7 +7,6 @@ from hysop import dprint, vprint
from hysop.tools.decorators import debug
from hysop.core.graph.graph import ready
from hysop.constants import HYSOP_REAL, HYSOP_INTEGER
from hysop.simulation import eps
from hysop.tools.numpywrappers import npw
from hysop.core.mpi import main_rank, main_size, main_comm
import numpy as np
......@@ -89,10 +88,12 @@ class IterativeMethod(Problem):
# create a pseudo-time step parameter if not given.
if (dt is None):
dt = ScalarParameter(name='pseudo-dt', dtype=HYSOP_REAL, min_value=eps,
initial_value=eps, quiet=True)
dt = ScalarParameter(name='pseudo-dt', dtype=HYSOP_REAL,
min_value=np.finfo(HYSOP_REAL).eps,
initial_value=np.finfo(HYSOP_REAL).eps,
quiet=True)
else:
dt.value = eps
dt.value = np.finfo(HYSOP_REAL).eps
self.dt0, self.dt = dt0, dt
self.state_print = state_print
self.max_iter = max_iter
......@@ -115,7 +116,7 @@ class IterativeMethod(Problem):
@debug
@ready
def apply(self, simulation, report_freq=0, **kwds):
def apply(self, simulation, report_freq=0, dbg=None, **kwds):
vprint('=== Entering iterative method...')
self.stop_criteria.value = self._stop_criteria_reset
......@@ -137,9 +138,9 @@ class IterativeMethod(Problem):
while not loop.is_over:
if loop.current_iteration % self.state_print == 0:
loop.print_state()
super(IterativeMethod, self).apply(simulation=loop, **kwds)
loop.advance(dbg=kwds['dbg'])
if (loop.current_iteration % report_freq) == 0:
super(IterativeMethod, self).apply(simulation=loop, dbg=dbg, **kwds)
loop.advance(dbg=dbg)
if report_freq > 0 and (loop.current_iteration % report_freq) == 0:
self.profiler_report()
avg_time = main_comm.allreduce(tm.interval) / main_size
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment