diff --git a/hysop/numerics/fft/_mkl_fft.py b/hysop/numerics/fft/_mkl_fft.py index e5f604799a595ab927b5af8769eea0739c786670..59b1eab11a37094d9967fa98ac7f0323eebb2872 100644 --- a/hysop/numerics/fft/_mkl_fft.py +++ b/hysop/numerics/fft/_mkl_fft.py @@ -11,7 +11,7 @@ Required version of mkl_fft is: https://gitlab.com/keckj/mkl_fft If MKL_THREADING_LAYER is not set, or is set to INTEL, FFT tests will fail. """ -import functools +import functools, warnings import numpy as np import numba as nb from mkl_fft import (ifft as mkl_ifft, @@ -23,10 +23,15 @@ from hysop.tools.types import first_not_None from hysop.numerics.fft.host_fft import HostFFTPlanI, HostFFTI, HostArray from hysop.numerics.fft.fft import \ complex_to_float_dtype, float_to_complex_dtype, \ - mk_view, mk_shape + mk_view, mk_shape, simd_alignment from hysop import __DEFAULT_NUMBA_TARGET__ from hysop.tools.numba_utils import make_numba_signature, prange +from hysop.numerics.fft.fft import HysopFFTWarning, bytes2str +from hysop.tools.warning import HysopWarning + +class HysopMKLFftWarning(HysopWarning): + pass def setup_transform(x, axis, transform, inverse, type): shape = x.shape @@ -421,12 +426,52 @@ class MklFFTPlan(HostFFTPlanI): self.allocate() self.fn(**self.kwds) self.rescale() - - def allocate(self, input_buf=None, output_buf=None): + + + @property + def required_buffer_size(self): + alignment = simd_alignment + if self._required_input_tmp: + sin, din = self._required_input_tmp['size'], self._required_input_tmp['dtype'] + sin *= din.itemsize + Bin = ((sin+alignment-1)//alignment)*alignment + else: + Bin = 0 + if self._required_output_tmp: + sout, dout = self._required_output_tmp['size'], self._required_output_tmp['dtype'] + sout *= dout.itemsize + Bout = ((sout+alignment-1)//alignment)*alignment + else: + Bout = 0 + return Bin+Bout + + def allocate(self, buf=None): """Allocate plan extra memory, possibly with a custom buffer.""" if self._allocated: msg='Plan was already allocated.' raise RuntimeError(msg) + + if (buf is not None): + alignment = simd_alignment + if self._required_input_tmp: + sin, din = self._required_input_tmp['size'], self._required_input_tmp['dtype'] + sin *= din.itemsize + Bin = ((sin+alignment-1)//alignment)*alignment + else: + Bin = 0 + if self._required_output_tmp: + sout, dout = self._required_output_tmp['size'], self._required_output_tmp['dtype'] + sout *= dout.itemsize + Bout = ((sout+alignment-1)//alignment)*alignment + else: + Bout = 0 + assert buf.dtype.itemsize == 1 + assert buf.size == Bin+Bout + input_buf = buf[:sin].view(dtype=din).reshape(ssin) if Bin else None + output_buf = buf[Bin:Bin+sout].view(dtype=dout).reshape(ssout) if Bout else None + else: + input_buf = None + output_buf = None for (k, buf, required_tmp) in zip(('input', 'output'), (input_buf, output_buf), @@ -445,7 +490,7 @@ class MklFFTPlan(HostFFTPlanI): if self.planner.error_on_allocation: raise RuntimeError(msg) else: - warnings.warn(msg, HysopGpyFftWarning) + warnings.warn(msg, HysopMKLFftWarning) buf = self.planner.backend.empty(shape=shape, dtype=dtype) elif (buf.shape != shape) or (buf.dtype != dtype):