diff --git a/hysop/numerics/fft/host_fft.py b/hysop/numerics/fft/host_fft.py
index 20ebd2b834ceaeb2c9880e8b075c3f4749dbfce4..3bedc0d87c8d2404a4f46ad13ef6a38c13ee5565 100644
--- a/hysop/numerics/fft/host_fft.py
+++ b/hysop/numerics/fft/host_fft.py
@@ -13,7 +13,8 @@ import numba as nb
 from hysop import __FFTW_NUM_THREADS__, __FFTW_PLANNER_EFFORT__, __FFTW_PLANNER_TIMELIMIT__, \
                   __DEFAULT_NUMBA_TARGET__
 from hysop.tools.types import first_not_None, check_instance
-from hysop.tools.numba_utils import make_numba_signature
+from hysop.tools.numba_utils import HAS_NUMBA, bake_numba_copy, bake_numba_accumulate, bake_numba_transpose
+from hysop.tools.decorators import static_vars
 from hysop.backend.host.host_array_backend import HostArrayBackend
 from hysop.backend.host.host_array import HostArray
 from hysop.numerics.fft.fft import FFTQueueI, FFTPlanI, FFTI
@@ -35,7 +36,7 @@ def can_exec_hptt(src, dst):
 
 try:
     import hptt
-    HAS_HPTT=True
+    HAS_HPTT=False  # NOTE(review): disables HPTT even when the import succeeds, confirm this is intended
     # required version is: https://gitlab.com/keckj/hptt
 except ImportError:
     HAS_HPTT=False
@@ -44,12 +45,6 @@ except ImportError:
     msg='Failed to import HPTT module, falling back to slow numpy transpose. Required version is available at https://gitlab.com/keckj/hptt.'
     warnings.warn(msg, HysopPerformanceWarning)
 
-try:
-    import numba as nb
-    HAS_NUMBA=True
-except ImportError:
-    HAS_NUMBA=False
-
 class DummyEvent(object):
     @classmethod
     def wait(cls):
@@ -138,11 +133,17 @@ class HostFFTI(FFTI):
     def plan_copy(self, tg, src, dst):
         src = self.ensure_callable(src)
         dst = self.ensure_callable(dst)
+
+        @static_vars(numba_copy=None)
         def exec_copy(src=src, dst=dst):
             src, dst = src(), dst()
             if HAS_HPTT and can_exec_hptt(src, dst):
                 hptt.tensorTransposeAndUpdate(perm=range(src.ndim), alpha=1.0, A=src,
                                               beta=0.0, B=dst)
+            elif HAS_NUMBA:
+                if (exec_copy.numba_copy is None):  # bake the kernel once, on first call
+                    exec_copy.numba_copy = bake_numba_copy(src=src, dst=dst)
+                exec_copy.numba_copy(dst=dst, src=src)  # run on the current arrays
             else:
                 dst[...] = src
         return exec_copy
@@ -150,11 +151,17 @@ class HostFFTI(FFTI):
     def plan_accumulate(self, tg, src, dst):
         src = self.ensure_callable(src)
         dst = self.ensure_callable(dst)
+
+        @static_vars(numba_accumulate=None)
         def exec_accumulate(src=src, dst=dst):
             src, dst = src(), dst()
             if HAS_HPTT and can_exec_hptt(src, dst):
                 hptt.tensorTransposeAndUpdate(perm=range(src.ndim), alpha=1.0, A=src,
                                               beta=1.0, B=dst)
+            elif HAS_NUMBA:
+                if (exec_accumulate.numba_accumulate is None):  # bake the kernel once, on first call
+                    exec_accumulate.numba_accumulate = bake_numba_accumulate(src=src, dst=dst)
+                exec_accumulate.numba_accumulate(dst=dst, src=src)  # run on the current arrays
             else:
                 dst[...] += src
-        return exec_copy
+        return exec_accumulate
@@ -162,11 +169,17 @@ class HostFFTI(FFTI):
     def plan_transpose(self, tg, src, dst, axes):
         src = self.ensure_callable(src)
         dst = self.ensure_callable(dst)
+
+        @static_vars(numba_transpose=None)
         def exec_transpose(src=src, dst=dst, axes=axes):
             src, dst = src(), dst()
             if HAS_HPTT and can_exec_hptt(src, dst):
                 hptt.tensorTransposeAndUpdate(perm=axes, alpha=1.0, A=src,
                                               beta=0.0, B=dst)
+            elif HAS_NUMBA:
+                if (exec_transpose.numba_transpose is None):  # bake the kernel once, on first call
+                    exec_transpose.numba_transpose = bake_numba_transpose(src=src, dst=dst, axes=axes)
+                exec_transpose.numba_transpose(dst=dst, src=src)  # run on the current arrays
             else:
                 dst[...] = np.transpose(a=src, axes=axes)
         return exec_transpose
diff --git a/hysop/tools/numba_utils.py b/hysop/tools/numba_utils.py
index e5095b8435a7fae2ae10959c06816f890ed78759..58a73f603acb00a0e61227e4781e82a6fdbbdae2 100644
--- a/hysop/tools/numba_utils.py
+++ b/hysop/tools/numba_utils.py
@@ -1,9 +1,15 @@
-import numba as nb
 import numpy as np
-from numba import prange
 
+from hysop import __DEFAULT_NUMBA_TARGET__
 from hysop.core.arrays.array import Array
 
+try:
+    import numba as nb
+    from numba import prange
+    HAS_NUMBA=True
+except ImportError:
+    HAS_NUMBA=False
+
 def make_numba_signature(*args, **kwds):
     raise_on_cl_array = kwds.pop('raise_on_cl_array', True)
     if kwds:
@@ -94,3 +100,201 @@ def make_numba_signature(*args, **kwds):
         numba_args += (na,)
 
     return nb.void(*numba_args), ','.join(numba_layout)
+
+
+def bake_numba_copy(dst, src, target=None):
+    """
+    Bake a numba kernel that copies src into dst.
+
+    The kernel signature is derived from dst and src with make_numba_signature
+    and the loop nest is chosen from dst.ndim (1 to 4 dimensions).
+    When target is None, __DEFAULT_NUMBA_TARGET__ is used.
+
+    Returns a callable whose kernel, dst and src arguments default to the
+    baked ones, so it can be called without argument or with other arrays.
+    Raises NotImplementedError(dst.ndim) for unsupported dimensions.
+    """
+    if (target is None):
+        target = __DEFAULT_NUMBA_TARGET__
+    signature, layout = make_numba_signature(dst, src)
+    if (dst.ndim == 1):
+        @nb.guvectorize([signature], layout,
+            target=target, nopython=True, cache=True)
+        def copy(dst, src):
+            for i in xrange(0, dst.shape[0]):
+                dst[i] = src[i]
+    elif (dst.ndim == 2):
+        @nb.guvectorize([signature], layout,
+            target=target, nopython=True, cache=True)
+        def copy(dst, src):
+            for i in prange(0, dst.shape[0]):
+                for j in xrange(0, dst.shape[1]):
+                    dst[i,j] = src[i,j]
+    elif (dst.ndim == 3):
+        @nb.guvectorize([signature], layout,
+            target=target, nopython=True, cache=True)
+        def copy(dst, src):
+            for i in prange(0, dst.shape[0]):
+                for j in prange(0, dst.shape[1]):
+                    for k in xrange(0, dst.shape[2]):
+                        dst[i,j,k] = src[i,j,k]
+    elif (dst.ndim == 4):
+        @nb.guvectorize([signature], layout,
+            target=target, nopython=True, cache=True)
+        def copy(dst, src):
+            for i in prange(0, dst.shape[0]):
+                for j in prange(0, dst.shape[1]):
+                    for k in prange(0, dst.shape[2]):
+                        for l in xrange(0, dst.shape[3]):
+                            dst[i,j,k,l] = src[i,j,k,l]
+    else:
+        raise NotImplementedError(dst.ndim)
+    def _exec_copy(copy=copy, dst=dst, src=src):
+        copy(dst,src)
+    return _exec_copy
+
+
+def bake_numba_accumulate(dst, src, target=None):
+    """
+    Bake a numba kernel that adds src to dst in place (dst += src).
+
+    The kernel signature is derived from dst and src with make_numba_signature
+    and the loop nest is chosen from dst.ndim (1 to 4 dimensions).
+    When target is None, __DEFAULT_NUMBA_TARGET__ is used.
+
+    Returns a callable whose kernel, dst and src arguments default to the
+    baked ones, so it can be called without argument or with other arrays.
+    Raises NotImplementedError(dst.ndim) for unsupported dimensions.
+    """
+    if (target is None):
+        target = __DEFAULT_NUMBA_TARGET__
+    signature, layout = make_numba_signature(dst, src)
+    if (dst.ndim == 1):
+        @nb.guvectorize([signature], layout,
+            target=target, nopython=True, cache=True)
+        def accumulate(dst, src):
+            for i in xrange(0, dst.shape[0]):
+                dst[i] += src[i]
+    elif (dst.ndim == 2):
+        @nb.guvectorize([signature], layout,
+            target=target, nopython=True, cache=True)
+        def accumulate(dst, src):
+            for i in prange(0, dst.shape[0]):
+                for j in xrange(0, dst.shape[1]):
+                    dst[i,j] += src[i,j]
+    elif (dst.ndim == 3):
+        @nb.guvectorize([signature], layout,
+            target=target, nopython=True, cache=True)
+        def accumulate(dst, src):
+            for i in prange(0, dst.shape[0]):
+                for j in prange(0, dst.shape[1]):
+                    for k in xrange(0, dst.shape[2]):
+                        dst[i,j,k] += src[i,j,k]
+    elif (dst.ndim == 4):
+        @nb.guvectorize([signature], layout,
+            target=target, nopython=True, cache=True)
+        def accumulate(dst, src):
+            for i in prange(0, dst.shape[0]):
+                for j in prange(0, dst.shape[1]):
+                    for k in prange(0, dst.shape[2]):
+                        for l in xrange(0, dst.shape[3]):
+                            dst[i,j,k,l] += src[i,j,k,l]
+    else:
+        raise NotImplementedError(dst.ndim)
+    def _exec_accumulate(accumulate=accumulate, dst=dst, src=src):
+        accumulate(dst,src)
+    return _exec_accumulate
+
+
+def bake_numba_transpose(src, dst, axes, target=None):
+    """
+    Bake a numba kernel that writes the transposition of src into dst.
+
+    axes is the permutation such that dst.shape[k] == src.shape[axes[k]],
+    it can be any sequence of integers.
+    Identity permutations are baked as plain copies so that dst is always
+    written. Other permutations are supported for 2D and 3D arrays only.
+    When target is None, __DEFAULT_NUMBA_TARGET__ is used.
+
+    Returns a callable whose kernel, dst and src arguments default to the
+    baked ones, so it can be called without argument or with other arrays.
+    Raises NotImplementedError for unsupported dimensions or permutations.
+    """
+    # inefficient permutations
+
+    if (target is None):
+        target = __DEFAULT_NUMBA_TARGET__
+
+    # axes may be given as a list, the dispatch below compares against tuples
+    axes = tuple(axes)
+
+    assert src.ndim == dst.ndim
+    assert dst.shape == tuple(src.shape[i] for i in axes)
+    assert dst.dtype == src.dtype
+    ndim = src.ndim
+
+    if (axes == tuple(range(ndim))):
+        # nothing to permute, but dst still has to receive the content of src
+        return bake_numba_copy(dst=dst, src=src, target=target)
+
+    signature, layout = make_numba_signature(dst, src)
+
+    if (ndim == 2):
+        if axes == (1,0):
+            @nb.guvectorize([signature], layout,
+                target=target, nopython=True, cache=True)
+            def transpose(dst, src):
+                for i in prange(0, src.shape[0]):
+                    for j in xrange(0, src.shape[1]):
+                        dst[j,i] = src[i,j]
+        else:
+            raise NotImplementedError(axes)
+    elif (ndim == 3):
+        if axes == (0,2,1):
+            @nb.guvectorize([signature], layout,
+                target=target, nopython=True, cache=True)
+            def transpose(dst, src):
+                for i in prange(0, src.shape[0]):
+                    for j in prange(0, src.shape[1]):
+                        for k in xrange(0, src.shape[2]):
+                            dst[i,k,j] = src[i,j,k]
+        elif axes == (1,0,2):
+            @nb.guvectorize([signature], layout,
+                target=target, nopython=True, cache=True)
+            def transpose(dst, src):
+                for i in prange(0, src.shape[0]):
+                    for j in prange(0, src.shape[1]):
+                        for k in xrange(0, src.shape[2]):
+                            dst[j,i,k] = src[i,j,k]
+        elif axes == (1,2,0):
+            @nb.guvectorize([signature], layout,
+                target=target, nopython=True, cache=True)
+            def transpose(dst, src):
+                for i in prange(0, src.shape[0]):
+                    for j in prange(0, src.shape[1]):
+                        for k in xrange(0, src.shape[2]):
+                            dst[j,k,i] = src[i,j,k]
+        elif axes == (2,1,0):
+            @nb.guvectorize([signature], layout,
+                target=target, nopython=True, cache=True)
+            def transpose(dst, src):
+                for i in prange(0, src.shape[0]):
+                    for j in prange(0, src.shape[1]):
+                        for k in xrange(0, src.shape[2]):
+                            dst[k,j,i] = src[i,j,k]
+        elif axes == (2,0,1):
+            @nb.guvectorize([signature], layout,
+                target=target, nopython=True, cache=True)
+            def transpose(dst, src):
+                for i in prange(0, src.shape[0]):
+                    for j in prange(0, src.shape[1]):
+                        for k in xrange(0, src.shape[2]):
+                            dst[k,i,j] = src[i,j,k]
+        else:
+            raise NotImplementedError(axes)
+    else:
+        raise NotImplementedError(ndim)
+
+    def _exec_transpose(transpose=transpose, dst=dst, src=src):
+        transpose(dst,src)
+    return _exec_transpose