Skip to content
Snippets Groups Projects
Commit dd1a480b authored by Keck Jean-Baptiste's avatar Keck Jean-Baptiste
Browse files

copy kernel rect

parent 19cc5533
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@ from hysop import vprint, dprint
from hysop.deps import np
from hysop.tools.decorators import debug
from hysop.tools.types import check_instance, first_not_None
from hysop.tools.misc import prod
from hysop.tools.numpywrappers import npw
from hysop.backend.device.opencl import cl, clArray
from hysop.backend.device.opencl.opencl_kernel_launcher import OpenClKernelLauncher
......@@ -27,7 +28,7 @@ class OpenClCopyKernelLauncher(OpenClKernelLauncher):
assert 'is_blocking' not in kwds
enqueue_copy_kwds['dest'] = dst
enqueue_copy_kwds['src'] = src
enqueue_copy_kwds['src'] = src
if isinstance(src, np.ndarray) or isinstance(dst, np.ndarray):
enqueue_copy_kwds['is_blocking'] = False
......@@ -200,3 +201,171 @@ class OpenClCopyDevice2DeviceLauncher(OpenClCopyBufferLauncher):
super(OpenClCopyDevice2DeviceLauncher, self).__init__(varname=varname, src=src, dst=dst,
src_device_offset=src_device_offset, dst_device_offset=dst_device_offset,
byte_count=byte_count)
class OpenClCopyBufferRectLauncher(OpenClCopyKernelLauncher):
"""
Non-blocking OpenCL copy kernel between host buffers and/or opencl device
rectangle subregions of buffers (OpenCL 1.1 and newer).
"""
def __init__(self, varname, src, dst,
             region, src_origin, dst_origin,
             src_pitches, dst_pitches,
             **kwds):
    """
    Initialize a (HOST <-> DEVICE) or a (DEVICE <-> DEVICE) rectangle
    subregions copy kernel.

    Parameters
    ----------
    varname: str
        Name of the variable copied for logging purposes.
    src: cl.MemoryObjectHolder or np.ndarray
        The source buffer.
    dst: cl.MemoryObjectHolder or np.ndarray
        The destination buffer.
    region: tuple of ints
        The 3D region to copy in terms of bytes for the
        first dimension and of elements for the two last dimensions.
    src_origin: tuple of ints
        The 3D offset in number of elements of the region associated with src buffer.
        The final src offset in bytes is computed from src_origin and src_pitches.
    dst_origin: tuple of ints
        The 3D offset in number of elements of the region associated with dst buffer.
        The final dst offset in bytes is computed from dst_origin and dst_pitches.
    src_pitches: tuple of ints
        The 2D pitches used to compute src offsets in bytes for
        the second and the third dimension.
    dst_pitches: tuple of ints
        The 2D pitches used to compute dst offsets in bytes for
        the second and the third dimension.
    kwds: dict
        Extra keyword arguments forwarded to the parent launcher.
        Must not contain 'name', which is generated here.

    Raises
    ------
    RuntimeError
        If both src and dst are host arrays (host to host copy).
    ValueError
        If the (src, dst) type combination is not handled.
    """
    check_instance(src, (cl.MemoryObjectHolder, np.ndarray))
    check_instance(dst, (cl.MemoryObjectHolder, np.ndarray))
    check_instance(region, tuple, values=(int, np.integer), size=3)
    check_instance(src_origin, tuple, values=(int, np.integer), size=3)
    check_instance(dst_origin, tuple, values=(int, np.integer), size=3)
    check_instance(src_pitches, tuple, values=(int, np.integer), size=2)
    check_instance(dst_pitches, tuple, values=(int, np.integer), size=2)

    enqueue_copy_kwds = {}
    enqueue_copy_kwds['region'] = region

    # The keyword names expected by the copy call differ between the
    # device <-> device case (src_*/dst_*) and the host <-> device
    # cases (host_*/buffer_*).
    if isinstance(src, np.ndarray) and \
            isinstance(dst, np.ndarray):
        msg = 'Host to host copy is not supported.'
        raise RuntimeError(msg)
    elif isinstance(src, cl.MemoryObjectHolder) and \
            isinstance(dst, cl.MemoryObjectHolder):
        enqueue_copy_kwds['src_origin'] = src_origin
        enqueue_copy_kwds['src_pitches'] = src_pitches
        enqueue_copy_kwds['dst_origin'] = dst_origin
        enqueue_copy_kwds['dst_pitches'] = dst_pitches
    elif isinstance(src, cl.MemoryObjectHolder) and \
            isinstance(dst, np.ndarray):
        # device -> host: the host side is the destination
        enqueue_copy_kwds['host_origin'] = dst_origin
        enqueue_copy_kwds['host_pitches'] = dst_pitches
        enqueue_copy_kwds['buffer_origin'] = src_origin
        enqueue_copy_kwds['buffer_pitches'] = src_pitches
    elif isinstance(src, np.ndarray) and \
            isinstance(dst, cl.MemoryObjectHolder):
        # host -> device: the host side is the source
        enqueue_copy_kwds['host_origin'] = src_origin
        enqueue_copy_kwds['host_pitches'] = src_pitches
        enqueue_copy_kwds['buffer_origin'] = dst_origin
        enqueue_copy_kwds['buffer_pitches'] = dst_pitches
    else:
        msg = 'The impossible happened.\n *src={}\n *dst={}'
        msg = msg.format(type(src), type(dst))
        raise ValueError(msg)

    assert 'name' not in kwds
    name = 'enqueue_copy_rect_{}__{}_to_{}'.format(varname,
            'host' if isinstance(src, np.ndarray) else 'device',
            'host' if isinstance(dst, np.ndarray) else 'device')
    apply_msg = '{}<<<{}>>>'.format(name, region)

    # super() must name this class: naming the sibling
    # OpenClCopyBufferLauncher fails because self is not one of its instances.
    super(OpenClCopyBufferRectLauncher, self).__init__(dst=dst, src=src,
            enqueue_copy_kwds=enqueue_copy_kwds,
            name=name, apply_msg=apply_msg, **kwds)
@classmethod
def _format_slices(cls, a, slices):
    """
    Normalize a slicing specification against array `a` and compute
    the size of the selected subregion.

    Parameters
    ----------
    a: np.ndarray or Array
        Array the slices apply to (its shape, dtype and ndim are read).
    slices: None, Ellipsis or tuple of int, slice and at most one Ellipsis
        Slicing specification. An empty or missing specification
        selects the whole array.

    Returns
    -------
    sdim: int
        Number of dimensions of the compacted copy region.
        NOTE(review): the original code returned an undefined `sdim`;
        len(region) is assumed here -- confirm against from_slices.
    slices: tuple
        The slices with the Ellipsis expanded, one entry per axis of `a`.
    dtype:
        The dtype of `a`.
    nbytes: int
        Number of selected elements multiplied by the item size.

    Raises
    ------
    ValueError
        If a slice selects no element or more elements than the axis holds.
    IndexError
        If an integer index is out of bounds for its axis.
    """
    check_instance(a, (np.ndarray, Array))
    shape = a.shape
    dtype = a.dtype
    ndim = a.ndim

    if (not slices) or (slices is Ellipsis):
        slices = (Ellipsis,)
    check_instance(slices, tuple)

    # expand ellipsis into as many full slices as there are missing axes
    if (Ellipsis in slices):
        nellipsis = slices.count(Ellipsis)
        msg = 'Only one Ellipsis can be passed.'
        assert nellipsis == 1, msg
        eid = slices.index(Ellipsis)
        missing_count = ndim - (len(slices) - 1)
        missing_slices = tuple(slice(None) for _ in range(missing_count))
        slices = slices[:eid] + missing_slices + slices[eid+1:]
    check_instance(slices, tuple, values=(int, np.integer, slice), size=ndim)

    # compute one (start, stop, step) triplet per axis
    indices = []
    for slc, si in zip(slices, shape):
        if isinstance(slc, slice):
            indices.append(slc.indices(si))
        else:
            idx = slc + si if slc < 0 else slc
            if not (0 <= idx < si):
                msg = 'index {} is out of bounds for axis of size {}'
                raise IndexError(msg.format(slc, si))
            indices.append((idx, idx + 1, 1))

    # compute nelems
    # NOTE(review): the ceil division below assumes positive steps.
    nelems = tuple((idx[1] - idx[0] + idx[2] - 1) // idx[2] for idx in indices)
    estart = tuple(idx[0] for idx in indices)

    # Compact the per-axis description, going from the first to the last
    # axis. Lists are used because the last entries are updated in place.
    # NOTE(review): offset, origin and pitches are computed but not
    # returned, exactly as in the original -- presumably work in progress.
    offset = 0
    region, origin, pitches = [], [], []
    for ne, es, si in zip(nelems, estart, shape):
        offset *= si
        pitches = [p * si for p in pitches]
        if (ne <= 0) or (ne > si):
            msg = 'ne={}, si={}'.format(ne, si)
            raise ValueError(msg)
        elif (not region) and (ne == 1):
            # leading single-element axes only shift the linear offset
            offset += es
        elif (not region) or (ne < si):
            # partially selected axis: open a new region dimension
            origin.append(es)
            region.append(ne)
            pitches.append(1)
        else:
            # fully selected axis: merge it into the previous dimension
            assert ne == si, 'ne={}, si={}'.format(ne, si)
            region[-1] *= ne
            origin[-1] *= si
            pitches[-1] //= si

    sdim = len(region)
    nelems = prod(nelems)
    nbytes = nelems * dtype.itemsize
    return sdim, slices, dtype, nbytes
@classmethod
def from_slices(src, dst, src_slices=None, dst_slices=None):
src_sdim, src_slices, src_dtype, src_bytes = cls._format_slices(src, src_slices)
dst_sdim, dst_slices, dst_dtype, dst_bytes = cls._format_slices(dst, dst_slices)
if (src_ndim != dst_ndim):
msg='Dimension mismatch between source and destination slices:'
msg+='\n src_slices: {}'
msg+='\n dst_slices: {}'
msg=msg.format(src_slices, dst_slices)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment