From b58eee87cc6beea951b735f36d84b2269db23716 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr>
Date: Thu, 4 Jul 2019 12:31:30 +0200
Subject: [PATCH] better support for isolation files

---
 .../opencl/opencl_autotunable_kernel.py       | 62 +++++++++++++------
 hysop/backend/device/opencl/opencl_types.py   | 11 ++++
 2 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/hysop/backend/device/opencl/opencl_autotunable_kernel.py b/hysop/backend/device/opencl/opencl_autotunable_kernel.py
index a6e003b18..faedece39 100644
--- a/hysop/backend/device/opencl/opencl_autotunable_kernel.py
+++ b/hysop/backend/device/opencl/opencl_autotunable_kernel.py
@@ -7,6 +7,7 @@ from hysop.tools.numpywrappers import npw
 from hysop.tools.types import check_instance, first_not_None, to_tuple, to_list
 from hysop.tools.misc import upper_pow2_or_3
 from hysop.tools.units import bytes2str
+from hysop.tools.numerics import get_dtype
 from hysop.core.mpi import main_rank
 
 from hysop.backend.device.kernel_autotuner import KernelGenerationError
@@ -244,7 +245,7 @@ class OpenClAutotunableKernel(AutotunableKernel):
                     arg_isol = isolation_params[arg_name]
                 elif isinstance(arg_value, npw.ndarray):
                     assert arg_value.dtype == arg_types
-                    arg_isol = dict(count=arg_value.size*arg_value.dtype.itemsize, dtype=npw.uint8)
+                    arg_isol = dict(count=arg_value.size, dtype=arg_value.dtype)
                 elif isinstance(arg_value, npw.number):
                     arg_value = npw.asarray([arg_value], dtype=arg_types)
                     arg_isol = dict(count=1, dtype=arg_value.dtype)
@@ -258,29 +259,46 @@ class OpenClAutotunableKernel(AutotunableKernel):
                 except:
                     type_str = type(arg_types).__name__
                 msg+='\n# argument {} with type {}\n'.format(arg_name, type_str)
-                msg+=self.format_oclgrind_isolation_argument(arg_isol, arg_value)
+                msg+=self.format_oclgrind_isolation_argument(arg_name, arg_isol, arg_value)
                 msg+='\n'
             if self.autotuner_config.verbose:
                 print '  >Saving oclgrind kernel isolation file to \'{}\'.'.format(dump_file)
             f.write(msg)
         return dump_file
 
-    def format_oclgrind_isolation_argument(self, arg_isol, arg_value):
+    def format_oclgrind_isolation_argument(self, arg_name, arg_isol, arg_value):
+        from pyopencl.cltypes import vec_types, vec_type_to_scalar_and_count
+        from hysop.backend.device.opencl.opencl_types import cl_vec_types, cl_vec_type_to_scalar_and_count
         check_instance(arg_isol, dict)
         assert 'count'  in arg_isol
         assert 'dtype' in arg_isol
+        dtype = get_dtype(arg_isol['dtype'])
+        if (dtype == npw.void):
+            dtype = arg_value.dtype
+        if dtype in vec_types:
+            dtype, vect = vec_type_to_scalar_and_count(dtype)
+            dtype = npw.dtype(get_dtype(dtype))
+            if (vect==3):
+                vect=4
+        elif dtype in cl_vec_types:
+            dtype, vect = cl_vec_type_to_scalar_and_count(dtype)
+            dtype = npw.dtype(get_dtype(dtype))
+            if (vect==3):
+                vect=4
+        elif dtype is npw.complex64:
+            dtype, vect = npw.float32, 2
+        elif dtype is npw.complex128:
+            dtype, vect = npw.float64, 2
+        else:
+            dtype, vect = dtype, 1
+        dtype = npw.dtype(get_dtype(dtype))
+        itemsize = dtype.itemsize
+        dtype    = dtype.type
         count  = arg_isol['count']
-        dtype  = arg_isol['dtype']
         assert count >= 1
-        if isinstance(dtype, npw.dtype):
-            itemsize = dtype.itemsize
-            dtype    = dtype.type
-        else:
-            itemsize = dtype(0).itemsize
+        assert vect >= 1
+        count *= vect
         size = count * itemsize
-        assert issubclass(dtype, npw.generic)
-
-        arg = '<size={}'.format(size)
 
         typemap = {
             npw.int8:    'char',
@@ -293,9 +311,8 @@ class OpenClAutotunableKernel(AutotunableKernel):
             npw.uint64:  'ulong',
             npw.float32: 'float',
             npw.float64: 'double',
-            npw.complex64: 'float2',
-            npw.complex128: 'double2'
         }
+        arg = '<size={}'.format(size)
 
         dump_data = False
         dump_hex_data = False
@@ -309,26 +326,31 @@ class OpenClAutotunableKernel(AutotunableKernel):
         elif 'range' in arg_isol:
             slices = arg_isol['range']
             assert isinstance(slices, slice)
-            assert dtype in typemap.keys()
-            ranges = slices.indices(count)
+            assert dtype in typemap.keys(), dtype
+            ranges = list(slices.indices(count))
+            assert (ranges[1]-ranges[0])//ranges[2] in (count, count//vect)
+            if ((ranges[1]-ranges[0])//ranges[2] == count//vect): 
+                ranges[0]*=vect
+                ranges[1]*=vect
+            assert (ranges[1]-ranges[0])//ranges[2] == count, '{} != {}'.format((ranges[1]-ranges[0])//ranges[2], count)
             arg+=' range={}:{}:{}'.format(ranges[0], ranges[2], ranges[1]-1)
             arg+=' {}'.format(typemap[dtype])
         else:
-            if 'arg_value' in arg_isol:
+            if ('arg_value' in arg_isol):
                 arg_value = arg_isol['arg_value']
             assert isinstance(arg_value, npw.ndarray), type(arg_value)
-            if dtype in typemap:
+            if (dtype in typemap):
                 arg+=' {}'.format(typemap[dtype])
                 dump_data=True
             else:
                 arg+=' uint8 hex'
                 dump_hex_data=True
-        if ('dump' in arg_isol) and (arg_isol['dump'] is True):
+        if (('dump' in arg_isol) and (arg_isol['dump'] is True)):
             arg+= ' dump'
         arg+='>'
 
         if dump_data:
-            arg+= '\n' + ' '.join(str(x) for x in arg_value.flatten())
+            arg+= '\n' + ' '.join(str(x).replace(',','').replace('(','').replace(')','') for x in arg_value.flatten())
         if dump_hex_data:
             view = arg_value.ravel().view(dtype=npw.uint8)
             arg+= '\n' + ' '.join('{:02x}'.format(ord(x)) for x in arg_value.tobytes())
diff --git a/hysop/backend/device/opencl/opencl_types.py b/hysop/backend/device/opencl/opencl_types.py
index d474928bd..203d6e5df 100644
--- a/hysop/backend/device/opencl/opencl_types.py
+++ b/hysop/backend/device/opencl/opencl_types.py
@@ -140,6 +140,7 @@ vtype_int     = [np.int32,   vec.int2, vec.int3, vec.int4, vec.int8, vec.int16 ]
 vtype_uint    = [np.uint32,  vec.uint2, vec.uint3, vec.uint4, vec.uint8, vec.uint16 ]
 vtype_simple  = [np.float32, vec.float2, vec.float3, vec.float4, vec.float8, vec.float16 ]
 vtype_double  = [np.float64, vec.double2, vec.double3, vec.double4, vec.double8, vec.double16 ]
+cl_vec_types = vtype_int + vtype_uint + vtype_simple + vtype_double  # flat list of every int/uint/float/double entry above, scalar first entries included
 
 make_int     = [npmake(np.int32),   vec.make_int2, vec.make_int3,
                                     vec.make_int4, vec.make_int8,
@@ -218,6 +219,16 @@ def cl_type_to_dtype(cl_type):
     N = components(cl_type)
     return typen(btype,N)
 
+def cl_vec_type_to_scalar_and_count(cl_vec_type):  # return (first entry of the matching vtype_* list, matching vsizes entry) for a type listed in cl_vec_types
+    assert cl_vec_type in cl_vec_types  # reject anything outside the four vtype_* families
+    cvt = cl_vec_type
+    for vtypes in (vtype_int, vtype_uint, vtype_simple, vtype_double):  # look the type up family by family
+        if cvt in vtypes:
+            btype = vtypes[0]  # each vtype_* list starts with its scalar numpy type
+            count = vsizes[vtypes.index(cvt)]  # vsizes is defined outside this hunk; presumably component counts parallel to vtype_* -- TODO confirm
+            return (btype, count)
+    msg='cl_vec_types != U(vtype_*)'  # only reached when the type is in none of the four families
+    raise RuntimeError(msg)
 
 class TypeGen(object):
     def __init__(self, fbtype='float', float_dump_mode='dec'):
-- 
GitLab