Skip to content
Snippets Groups Projects
Commit 56336419 authored by Jean-Baptiste Keck's avatar Jean-Baptiste Keck
Browse files

numba plans

parent b52abca4
No related branches found
No related tags found
1 merge request: !16 "MPI operators"
...@@ -13,7 +13,8 @@ import numba as nb ...@@ -13,7 +13,8 @@ import numba as nb
from hysop import __FFTW_NUM_THREADS__, __FFTW_PLANNER_EFFORT__, __FFTW_PLANNER_TIMELIMIT__, \ from hysop import __FFTW_NUM_THREADS__, __FFTW_PLANNER_EFFORT__, __FFTW_PLANNER_TIMELIMIT__, \
__DEFAULT_NUMBA_TARGET__ __DEFAULT_NUMBA_TARGET__
from hysop.tools.types import first_not_None, check_instance from hysop.tools.types import first_not_None, check_instance
from hysop.tools.numba_utils import make_numba_signature from hysop.tools.numba_utils import HAS_NUMBA, bake_numba_copy, bake_numba_accumulate, bake_numba_transpose
from hysop.tools.decorators import static_vars
from hysop.backend.host.host_array_backend import HostArrayBackend from hysop.backend.host.host_array_backend import HostArrayBackend
from hysop.backend.host.host_array import HostArray from hysop.backend.host.host_array import HostArray
from hysop.numerics.fft.fft import FFTQueueI, FFTPlanI, FFTI from hysop.numerics.fft.fft import FFTQueueI, FFTPlanI, FFTI
...@@ -35,7 +36,7 @@ def can_exec_hptt(src, dst): ...@@ -35,7 +36,7 @@ def can_exec_hptt(src, dst):
try: try:
import hptt import hptt
HAS_HPTT=True HAS_HPTT=False
# required version is: https://gitlab.com/keckj/hptt # required version is: https://gitlab.com/keckj/hptt
except ImportError: except ImportError:
HAS_HPTT=False HAS_HPTT=False
...@@ -44,12 +45,6 @@ except ImportError: ...@@ -44,12 +45,6 @@ except ImportError:
msg='Failed to import HPTT module, falling back to slow numpy transpose. Required version is available at https://gitlab.com/keckj/hptt.' msg='Failed to import HPTT module, falling back to slow numpy transpose. Required version is available at https://gitlab.com/keckj/hptt.'
warnings.warn(msg, HysopPerformanceWarning) warnings.warn(msg, HysopPerformanceWarning)
try:
import numba as nb
HAS_NUMBA=True
except ImportError:
HAS_NUMBA=False
class DummyEvent(object): class DummyEvent(object):
@classmethod @classmethod
def wait(cls): def wait(cls):
...@@ -138,11 +133,17 @@ class HostFFTI(FFTI): ...@@ -138,11 +133,17 @@ class HostFFTI(FFTI):
def plan_copy(self, tg, src, dst):
    """Build a zero-argument callable that copies src into dst.

    src and dst are passed through self.ensure_callable and are only
    evaluated when the returned function is executed. tg is not used here.

    The returned function picks the first available implementation:
    HPTT (when HAS_HPTT is set and can_exec_hptt accepts the arrays),
    then a numba kernel, then a plain numpy assignment.
    """
    src = self.ensure_callable(src)
    dst = self.ensure_callable(dst)

    @static_vars(numba_copy=None)
    def exec_copy(src=src, dst=dst):
        src_array, dst_array = src(), dst()

        if HAS_HPTT and can_exec_hptt(src_array, dst_array):
            # Identity permutation with beta=0.0: used here as a plain copy.
            hptt.tensorTransposeAndUpdate(perm=range(src_array.ndim),
                                          alpha=1.0, A=src_array,
                                          beta=0.0, B=dst_array)
            return

        if not HAS_NUMBA:
            dst_array[...] = src_array
            return

        # The kernel is baked on the first call with the arrays seen then,
        # stored on the function object and reused on every later call.
        if (exec_copy.numba_copy is None):
            exec_copy.numba_copy = bake_numba_copy(src=src_array, dst=dst_array)
        exec_copy.numba_copy()

    return exec_copy
...@@ -150,11 +151,17 @@ class HostFFTI(FFTI): ...@@ -150,11 +151,17 @@ class HostFFTI(FFTI):
def plan_accumulate(self, tg, src, dst):
    """Build a zero-argument callable that adds src into dst (dst += src).

    src and dst are passed through self.ensure_callable and are only
    evaluated when the returned function is executed. tg is not used here.

    The returned function picks the first available implementation:
    HPTT (when HAS_HPTT is set and can_exec_hptt accepts the arrays),
    then a numba kernel, then a plain numpy in-place addition.

    Returns
    -------
    exec_accumulate: callable
        The planned accumulation. The previous code returned the undefined
        name exec_copy, which raised a NameError on every call.
    """
    src = self.ensure_callable(src)
    dst = self.ensure_callable(dst)

    @static_vars(numba_accumulate=None)
    def exec_accumulate(src=src, dst=dst):
        src, dst = src(), dst()
        if HAS_HPTT and can_exec_hptt(src, dst):
            # Identity permutation; beta=1.0 presumably keeps the current
            # content of dst so that src is added to it (confirm with HPTT).
            hptt.tensorTransposeAndUpdate(perm=range(src.ndim),
                                          alpha=1.0, A=src, beta=1.0, B=dst)
        elif HAS_NUMBA:
            # The kernel is baked on the first call with the arrays seen
            # then, stored on the function object and reused afterwards.
            if (exec_accumulate.numba_accumulate is None):
                exec_accumulate.numba_accumulate = bake_numba_accumulate(src=src, dst=dst)
            exec_accumulate.numba_accumulate()
        else:
            dst[...] += src
    return exec_accumulate
...@@ -162,11 +169,17 @@ class HostFFTI(FFTI): ...@@ -162,11 +169,17 @@ class HostFFTI(FFTI):
def plan_transpose(self, tg, src, dst, axes):
    """Build a zero-argument callable that writes src permuted by axes to dst.

    src and dst are passed through self.ensure_callable and are only
    evaluated when the returned function is executed. tg is not used here.

    The returned function picks the first available implementation:
    HPTT (when HAS_HPTT is set and can_exec_hptt accepts the arrays),
    then a numba kernel, then numpy.transpose followed by an assignment.
    """
    src = self.ensure_callable(src)
    dst = self.ensure_callable(dst)

    @static_vars(numba_transpose=None)
    def exec_transpose(src=src, dst=dst, axes=axes):
        src_array, dst_array = src(), dst()

        if HAS_HPTT and can_exec_hptt(src_array, dst_array):
            hptt.tensorTransposeAndUpdate(perm=axes,
                                          alpha=1.0, A=src_array,
                                          beta=0.0, B=dst_array)
            return

        if not HAS_NUMBA:
            dst_array[...] = np.transpose(a=src_array, axes=axes)
            return

        # The kernel is baked on the first call with the arrays seen then,
        # stored on the function object and reused on every later call.
        if (exec_transpose.numba_transpose is None):
            exec_transpose.numba_transpose = bake_numba_transpose(
                src=src_array, dst=dst_array, axes=axes)
        exec_transpose.numba_transpose()

    return exec_transpose
......
import numba as nb
import numpy as np import numpy as np
from numba import prange from hysop import __DEFAULT_NUMBA_TARGET__
from hysop.core.arrays.array import Array from hysop.core.arrays.array import Array
try:
import numba as nb
from numba import prange
HAS_NUMBA=True
except ImportError:
HAS_NUMBA=False
def make_numba_signature(*args, **kwds): def make_numba_signature(*args, **kwds):
raise_on_cl_array = kwds.pop('raise_on_cl_array', True) raise_on_cl_array = kwds.pop('raise_on_cl_array', True)
if kwds: if kwds:
...@@ -94,3 +100,169 @@ def make_numba_signature(*args, **kwds): ...@@ -94,3 +100,169 @@ def make_numba_signature(*args, **kwds):
numba_args += (na,) numba_args += (na,)
return nb.void(*numba_args), ','.join(numba_layout) return nb.void(*numba_args), ','.join(numba_layout)
def bake_numba_copy(dst, src, target=None):
    """Compile an element-wise copy kernel and bind it to dst and src.

    Parameters
    ----------
    dst, src:
        Destination and source arrays, given to make_numba_signature to
        build the kernel signature and layout. Arrays of 1 to 4 dimensions
        are supported (the rank is read from dst.ndim).
    target:
        Numba target passed to guvectorize; defaults to
        __DEFAULT_NUMBA_TARGET__ when None.

    Returns
    -------
    A function that, called without arguments, runs the kernel on the
    dst and src given here.

    Raises
    ------
    NotImplementedError
        If dst.ndim is not 1, 2, 3 or 4.
    """
    target = __DEFAULT_NUMBA_TARGET__ if (target is None) else target
    signature, layout = make_numba_signature(dst, src)

    # One plain loop nest per supported rank; only the selected one is
    # compiled below.
    def copy_1d(out, inp):
        for i in xrange(0, out.shape[0]):
            out[i] = inp[i]

    def copy_2d(out, inp):
        for i in prange(0, out.shape[0]):
            for j in xrange(0, out.shape[1]):
                out[i,j] = inp[i,j]

    def copy_3d(out, inp):
        for i in prange(0, out.shape[0]):
            for j in prange(0, out.shape[1]):
                for k in xrange(0, out.shape[2]):
                    out[i,j,k] = inp[i,j,k]

    def copy_4d(out, inp):
        for i in prange(0, out.shape[0]):
            for j in prange(0, out.shape[1]):
                for k in prange(0, out.shape[2]):
                    for l in xrange(0, out.shape[3]):
                        out[i,j,k,l] = inp[i,j,k,l]

    kernels = {1: copy_1d, 2: copy_2d, 3: copy_3d, 4: copy_4d}
    rank = dst.ndim
    if rank not in kernels:
        raise NotImplementedError(rank)

    copy = nb.guvectorize([signature], layout,
                          target=target, nopython=True, cache=True)(kernels[rank])

    def _exec_copy(copy=copy, dst=dst, src=src):
        copy(dst,src)
    return _exec_copy
def bake_numba_accumulate(dst, src, target=None):
    """Compile an element-wise `dst += src` kernel bound to dst and src.

    Parameters
    ----------
    dst, src:
        Destination and source arrays, given to make_numba_signature to
        build the kernel signature and layout. Arrays of 1 to 4 dimensions
        are supported (the rank is read from dst.ndim).
    target:
        Numba target passed to guvectorize; defaults to
        __DEFAULT_NUMBA_TARGET__ when None.

    Returns
    -------
    A function that, called without arguments, runs the kernel on the
    dst and src given here.

    Raises
    ------
    NotImplementedError
        If dst.ndim is not 1, 2, 3 or 4.
    """
    target = __DEFAULT_NUMBA_TARGET__ if (target is None) else target
    signature, layout = make_numba_signature(dst, src)

    # One plain loop nest per supported rank; only the selected one is
    # compiled below.
    def accumulate_1d(out, inp):
        for i in xrange(0, out.shape[0]):
            out[i] += inp[i]

    def accumulate_2d(out, inp):
        for i in prange(0, out.shape[0]):
            for j in xrange(0, out.shape[1]):
                out[i,j] += inp[i,j]

    def accumulate_3d(out, inp):
        for i in prange(0, out.shape[0]):
            for j in prange(0, out.shape[1]):
                for k in xrange(0, out.shape[2]):
                    out[i,j,k] += inp[i,j,k]

    def accumulate_4d(out, inp):
        for i in prange(0, out.shape[0]):
            for j in prange(0, out.shape[1]):
                for k in prange(0, out.shape[2]):
                    for l in xrange(0, out.shape[3]):
                        out[i,j,k,l] += inp[i,j,k,l]

    kernels = {1: accumulate_1d, 2: accumulate_2d,
               3: accumulate_3d, 4: accumulate_4d}
    rank = dst.ndim
    if rank not in kernels:
        raise NotImplementedError(rank)

    accumulate = nb.guvectorize([signature], layout,
                                target=target, nopython=True, cache=True)(kernels[rank])

    def _exec_accumulate(accumulate=accumulate, dst=dst, src=src):
        accumulate(dst,src)
    return _exec_accumulate
def bake_numba_transpose(src, dst, axes, target=None):
    """Compile a kernel writing src permuted by axes into dst, bound to both.

    These are straightforward (not cache-optimized) loop nests.

    Parameters
    ----------
    src, dst:
        Source and destination arrays with the same number of dimensions
        and dtype, where dst.shape == tuple(src.shape[i] for i in axes).
        Arrays of 1 to 3 dimensions are supported.
    axes:
        Permutation of the axes of src. Any sequence is accepted; it is
        converted to a tuple before being matched.
    target:
        Numba target passed to guvectorize; defaults to
        __DEFAULT_NUMBA_TARGET__ when None.

    Returns
    -------
    A function that, called without arguments, performs the transposition
    on the src and dst given here.

    Raises
    ------
    NotImplementedError
        If src has more than 3 dimensions (or none), or if axes is not a
        supported permutation for that rank.
    """
    if (target is None):
        target = __DEFAULT_NUMBA_TARGET__
    signature, layout = make_numba_signature(dst, src)

    # Accept lists and other sequences: the lookups below compare tuples.
    axes = tuple(axes)

    assert src.ndim == dst.ndim
    assert dst.shape == tuple(src.shape[i] for i in axes)
    assert dst.dtype == src.dtype
    ndim = src.ndim

    def identity(dst, src):
        # Identity permutation: dst must still receive the content of src,
        # as the HPTT (beta=0.0) and numpy fallbacks of the callers do.
        # The previous no-op left dst untouched.
        dst[...] = src

    def swap_10(dst, src):
        for i in prange(0, src.shape[0]):
            for j in xrange(0, src.shape[1]):
                dst[j,i] = src[i,j]

    def perm_021(dst, src):
        for i in prange(0, src.shape[0]):
            for j in prange(0, src.shape[1]):
                for k in xrange(0, src.shape[2]):
                    dst[i,k,j] = src[i,j,k]

    def perm_102(dst, src):
        for i in prange(0, src.shape[0]):
            for j in prange(0, src.shape[1]):
                for k in xrange(0, src.shape[2]):
                    dst[j,i,k] = src[i,j,k]

    def perm_120(dst, src):
        for i in prange(0, src.shape[0]):
            for j in prange(0, src.shape[1]):
                for k in xrange(0, src.shape[2]):
                    dst[j,k,i] = src[i,j,k]

    def perm_210(dst, src):
        for i in prange(0, src.shape[0]):
            for j in prange(0, src.shape[1]):
                for k in xrange(0, src.shape[2]):
                    dst[k,j,i] = src[i,j,k]

    def perm_201(dst, src):
        for i in prange(0, src.shape[0]):
            for j in prange(0, src.shape[1]):
                for k in xrange(0, src.shape[2]):
                    dst[k,i,j] = src[i,j,k]

    # Non-identity permutations, keyed by axes. Keys of length 2 and 3
    # cannot collide, and the rank is checked against len(axes) below.
    kernels = {
        (1,0):   swap_10,
        (0,2,1): perm_021,
        (1,0,2): perm_102,
        (1,2,0): perm_120,
        (2,1,0): perm_210,
        (2,0,1): perm_201,
    }

    if ndim not in (1, 2, 3):
        raise NotImplementedError(ndim)

    if (ndim == 1) or (axes == tuple(range(ndim))):
        # The previous code wrote `transpose == noop` (a comparison) for the
        # 2D and 3D identity cases, which left `transpose` unbound.
        transpose = identity
    elif (len(axes) == ndim) and (axes in kernels):
        transpose = nb.guvectorize([signature], layout,
                                   target=target, nopython=True, cache=True)(kernels[axes])
    else:
        raise NotImplementedError(axes)

    def _exec_transpose(transpose=transpose, dst=dst, src=src):
        transpose(dst,src)
    return _exec_transpose
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment