diff --git a/hysop/backend/device/opencl/operator/spatial_filtering.py b/hysop/backend/device/opencl/operator/spatial_filtering.py index 7266b3a4b6f76c8cd3cf2cc7af9614e2d3d9dc06..e1c70a607f7b694542c863aed704c58dc6021bc6 100644 --- a/hysop/backend/device/opencl/operator/spatial_filtering.py +++ b/hysop/backend/device/opencl/operator/spatial_filtering.py @@ -40,9 +40,9 @@ class OpenClSpectralLowpassFilter(SpectralLowpassFilterBase, OpenClOperator): fill_ones, _ = kernel_generator.elementwise_kernel('scale', expr) fill_ones(queue=self.cl_env.default_queue) - self.Ft() + self.Ft(simulation=False) kl(queue=self.cl_env.default_queue) # here we apply unscaled filter - self.Bt() + self.Bt(simulation=False) # Here we get the coefficient scaling = 1.0 / self.Bt.output_buffer[(0,)*self.FOUT.ndim].get() @@ -69,9 +69,9 @@ class OpenClSpectralLowpassFilter(SpectralLowpassFilterBase, OpenClOperator): def apply(self, **kwds): """Apply spectral filter (which is just a square window centered on low frequencies).""" super(OpenClSpectralLowpassFilter, self).apply(**kwds) - evt = self.Ft() + evt = self.Ft(**kwds) evt = self.filter() - evt = self.Bt() + evt = self.Bt(**kwds) evt = self.scale() if (self.exchange_ghosts is not None): evt = self.exchange_ghosts() diff --git a/hysop/numerics/fft/fft.py b/hysop/numerics/fft/fft.py index e6ef39eb1e6c97438ca3640fd9aa8c345e9acfb7..f62f4ffd1d46f9ae65ac1bcaccdd05d0b40b6119 100644 --- a/hysop/numerics/fft/fft.py +++ b/hysop/numerics/fft/fft.py @@ -24,6 +24,7 @@ from hysop.tools.types import first_not_None, check_instance from hysop.tools.numerics import float_to_complex_dtype, complex_to_float_dtype from hysop.tools.units import bytes2str from hysop.tools.warning import HysopWarning +from hysop.tools.spectral_utils import SpectralTransformUtils as STU from hysop.core.arrays.array import Array from hysop.core.arrays.array_backend import ArrayBackend @@ -706,4 +707,22 @@ class FFTI(object): def plan_fill_zeros(self, tg, a, slices): """Plan to fill every input slices of input array a with zeroes.""" pass + + #@abstractmethod + def plan_compute_energy(self, tg, src, dst, transforms): + """Plan to compute energy from src to energy.""" + assert src.ndim == len(transforms) + assert dst.ndim == 1 + is_C2C = () + K2 = () + for (Ni, Ti) in zip(src.shape, transforms): + C2C = STU.is_C2C(Ti) + Ki = Ni//2 if C2C else Ni-1 + is_C2C += (C2C,) + K2 += (Ki**2,) + # for C2C we need to check j = (i<(N+1)//2 ? i : N-i) + max_wavenumber = int(round(sum(K2)**0.5, 0)) + msg='Destination buffer should have size {} but has size {}.'.format(max_wavenumber+1, dst.size) + assert (dst.size == max_wavenumber+1), msg + return is_C2C diff --git a/hysop/operator/base/spectral_operator.py b/hysop/operator/base/spectral_operator.py index cb0a2c821831a75beb1574f59efb3853b9502dc6..1e30197b25785c7896eb4c5f056ec6e20f5e66b1 100644 --- a/hysop/operator/base/spectral_operator.py +++ b/hysop/operator/base/spectral_operator.py @@ -626,7 +626,8 @@ class SpectralTransformGroup(object): symbolic_transform=transforms[idx], custom_output_buffer=custom_output_buffer, action=action, - compute_energy=compute_energy, plot_energy=plot_energy) + compute_energy=compute_energy, + plot_energy=plot_energy) self._forward_transforms[(f,axes,transform_tag)] = planned_transforms[idx] else: assert (field,axes,transform_tag) not in self._forward_transforms, msg.format(field.name, axes, transform_tag) @@ -638,7 +639,8 @@ class SpectralTransformGroup(object): symbolic_transform=transforms, custom_output_buffer=custom_output_buffer, action=action, - compute_energy=compute_energy, plot_energy=plot_energy) + compute_energy=compute_energy, + plot_energy=plot_energy) self._forward_transforms[(field,axes,transform_tag)] = planned_transforms return planned_transforms @@ -733,7 +735,8 @@ class SpectralTransformGroup(object): custom_input_buffer=custom_input_buffer, matching_forward_transform=matching_forward_transform, action=action, - compute_energy=compute_energy, plot_energy=plot_energy) + compute_energy=compute_energy, + plot_energy=plot_energy) self._backward_transforms[(f,axes,transform_tag)] = planned_transforms[idx] else: assert (field,axes,transform_tag) not in self._backward_transforms, msg.format(field.name, axes, transform_tag) @@ -745,7 +748,8 @@ class SpectralTransformGroup(object): custom_input_buffer=custom_input_buffer, matching_forward_transform=matching_forward_transform, action=action, - compute_energy=compute_energy, plot_energy=plot_energy) + compute_energy=compute_energy, + plot_energy=plot_energy) self._backward_transforms[(field,axes,transform_tag)] = planned_transforms return planned_transforms @@ -819,7 +823,7 @@ class PlannedSpectralTransform(object): check_instance(symbolic_transform, AppliedSpectralTransform) check_instance(action, SpectralTransformAction) check_instance(compute_energy, int, allow_none=True) - check_instance(plot_energy, IOParams, allow_none=True) + check_instance(plot_energy, IOParams, allow_none=True) msg='Can only specify one parameter between compute_energy and plot_energy.' assert (compute_energy is None) or (plot_energy is None), msg assert custom_input_buffer in (None, 'B0', 'B1', 'auto'), custom_input_buffer @@ -836,8 +840,15 @@ class PlannedSpectralTransform(object): self._do_compute_energy = (compute_energy is not None) or (plot_energy is not None) self._do_plot_energy = (plot_energy is not None) - self._energy_ioparams = first_not_None(compute_energy, plot_energy) + if (plot_energy is not None): + self._compute_energy_frequency = plot_energy.frequency + elif (compute_energy is not None): + self._compute_energy_frequency = compute_energy + else: + self._compute_energy_frequency = None + self._plot_energy_ioparams = plot_energy del plot_energy + del compute_energy field = self.s.field is_forward = self.s.is_forward @@ -1540,13 +1551,14 @@ class PlannedSpectralTransform(object): assert len(shape) == len(transforms) K2 = () for (tr, Ni) in zip(transforms, shape): - Ki = Ni//2 if STU.is_C2C(tr) else Ni + Ki = Ni//2 if STU.is_C2C(tr) else Ni-1 K2 += (Ki*Ki,) max_wavenumber = int(round(sum(K2)**0.5, 0)) energy_nbytes = compute_nbytes(max_wavenumber+1, self.dfield.dtype) requests[ENERGY_tag] = energy_nbytes self.max_wavenumber = max_wavenumber self.energy_nbytes = energy_nbytes + self.permuted_transforms = transforms return requests @@ -1864,21 +1876,68 @@ SPECTRAL TRANSFORM SETUP # allocate fft plans FFTI.allocate_plans(op, fft_plans, tmp_buffer=TMP) + + # build kernels to compute energy if required + if self._do_compute_energy: + compute_energy_queue = FFTI.new_queue(tg=self, name='compute_energy') + spectral_buffer = self.output_buffer if self.is_forward else self.input_buffer + compute_energy_queue += FFTI.plan_compute_energy(tg=tg, + src=spectral_buffer, dst=energy_buffer, + transforms=self.permuted_transforms) + else: + compute_energy_queue = None self._queue = queue + self._compute_energy_queue = compute_energy_queue self._ready = True - def __call__(self, simu=None): + def __call__(self, **kwds): assert (self._ready) assert (self._queue is not None) + evt = self._pre_transform_actions(**kwds) + evt = self._queue.execute() + evt = self._post_transform_actions(**kwds) + return evt + + def _pre_transform_actions(self, simulation, **kwds): + evt = None if self.is_backward and self._do_compute_energy: - evt = self.compute_energy(simu=simu) + evt = self.compute_energy(simulation=simulation) if self._do_plot_energy: - evt = self.plot_energy(simu=simu) - evt = self._queue.execute() + evt = self.plot_energy(simulation=simulation) + return evt + + def _post_transform_actions(self, simulation, **kwds): + evt = None if self.is_forward and self._do_compute_energy: - evt = self.compute_energy(simu=simu) + evt = self.compute_energy(simulation=simulation) if self._do_plot_energy: - evt = self.plot_energy(simu=simu) - return evt + evt = self.plot_energy(simulation=simulation) + + def compute_energy(self, simulation): + if (simulation is False): + return + msg='No simulation was passed in {}.__call__().'.format(type(self)) + assert (simulation is not None), msg + frequency = self._compute_energy_frequency + ite = simulation.current_iteration + should_compute = (frequency>0) and (ite % frequency == 0) + should_compute |= simulation.is_time_of_interest + if should_compute: + return self._compute_energy_queue() + + def plot_energy(self, simulation): + if (simulation is False): + return + msg='No simulation was passed in {}.__call__().'.format(type(self)) + assert (simulation is not None), msg + frequency = self._plot_energy_ioparams.frequency + ite = simulation.current_iteration + should_plot = (frequency>0) and (ite % frequency == 0) + should_plot |= simulation.is_time_of_interest + if should_plot: + return self._plot_energy() + + def _plot_energy(self): + pass