Skip to content
Snippets Groups Projects
Commit 48970ccd authored by Jean-Baptiste Keck's avatar Jean-Baptiste Keck
Browse files

tmp support

parent f6c7f71d
No related branches found
No related tags found
1 merge request!16MPI operators
...@@ -11,7 +11,7 @@ Required version of mkl_fft is: https://gitlab.com/keckj/mkl_fft ...@@ -11,7 +11,7 @@ Required version of mkl_fft is: https://gitlab.com/keckj/mkl_fft
If MKL_THREADING_LAYER is not set, or is set to INTEL, FFT tests will fail. If MKL_THREADING_LAYER is not set, or is set to INTEL, FFT tests will fail.
""" """
import functools import functools, warnings
import numpy as np import numpy as np
import numba as nb import numba as nb
from mkl_fft import (ifft as mkl_ifft, from mkl_fft import (ifft as mkl_ifft,
...@@ -23,10 +23,15 @@ from hysop.tools.types import first_not_None ...@@ -23,10 +23,15 @@ from hysop.tools.types import first_not_None
from hysop.numerics.fft.host_fft import HostFFTPlanI, HostFFTI, HostArray from hysop.numerics.fft.host_fft import HostFFTPlanI, HostFFTI, HostArray
from hysop.numerics.fft.fft import \ from hysop.numerics.fft.fft import \
complex_to_float_dtype, float_to_complex_dtype, \ complex_to_float_dtype, float_to_complex_dtype, \
mk_view, mk_shape mk_view, mk_shape, simd_alignment
from hysop import __DEFAULT_NUMBA_TARGET__ from hysop import __DEFAULT_NUMBA_TARGET__
from hysop.tools.numba_utils import make_numba_signature, prange from hysop.tools.numba_utils import make_numba_signature, prange
from hysop.numerics.fft.fft import HysopFFTWarning, bytes2str
from hysop.tools.warning import HysopWarning
class HysopMKLFftWarning(HysopWarning):
pass
def setup_transform(x, axis, transform, inverse, type): def setup_transform(x, axis, transform, inverse, type):
shape = x.shape shape = x.shape
...@@ -421,12 +426,52 @@ class MklFFTPlan(HostFFTPlanI): ...@@ -421,12 +426,52 @@ class MklFFTPlan(HostFFTPlanI):
self.allocate() self.allocate()
self.fn(**self.kwds) self.fn(**self.kwds)
self.rescale() self.rescale()
def allocate(self, input_buf=None, output_buf=None):
@property
def required_buffer_size(self):
alignment = simd_alignment
if self._required_input_tmp:
sin, din = self._required_input_tmp['size'], self._required_input_tmp['dtype']
sin *= din.itemsize
Bin = ((sin+alignment-1)//alignment)*alignment
else:
Bin = 0
if self._required_output_tmp:
sout, dout = self._required_output_tmp['size'], self._required_output_tmp['dtype']
sout *= dout.itemsize
Bout = ((sout+alignment-1)//alignment)*alignment
else:
Bout = 0
return Bin+Bout
def allocate(self, buf=None):
"""Allocate plan extra memory, possibly with a custom buffer.""" """Allocate plan extra memory, possibly with a custom buffer."""
if self._allocated: if self._allocated:
msg='Plan was already allocated.' msg='Plan was already allocated.'
raise RuntimeError(msg) raise RuntimeError(msg)
if (buf is not None):
alignment = simd_alignment
if self._required_input_tmp:
sin, din = self._required_input_tmp['size'], self._required_input_tmp['dtype']
sin *= din.itemsize
Bin = ((sin+alignment-1)//alignment)*alignment
else:
Bin = 0
if self._required_output_tmp:
sout, dout = self._required_output_tmp['size'], self._required_output_tmp['dtype']
sout *= dout.itemsize
Bout = ((sout+alignment-1)//alignment)*alignment
else:
Bout = 0
assert buf.dtype.itemsize == 1
assert buf.size == Bin+Bout
input_buf = buf[:sin].view(dtype=din).reshape(ssin) if Bin else None
output_buf = buf[Bin:Bin+sout].view(dtype=dout).reshape(ssout) if Bout else None
else:
input_buf = None
output_buf = None
for (k, buf, required_tmp) in zip(('input', 'output'), for (k, buf, required_tmp) in zip(('input', 'output'),
(input_buf, output_buf), (input_buf, output_buf),
...@@ -445,7 +490,7 @@ class MklFFTPlan(HostFFTPlanI): ...@@ -445,7 +490,7 @@ class MklFFTPlan(HostFFTPlanI):
if self.planner.error_on_allocation: if self.planner.error_on_allocation:
raise RuntimeError(msg) raise RuntimeError(msg)
else: else:
warnings.warn(msg, HysopGpyFftWarning) warnings.warn(msg, HysopMKLFftWarning)
buf = self.planner.backend.empty(shape=shape, buf = self.planner.backend.empty(shape=shape,
dtype=dtype) dtype=dtype)
elif (buf.shape != shape) or (buf.dtype != dtype): elif (buf.shape != shape) or (buf.dtype != dtype):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment