"""OpenCL type names, literal formatting helpers and pyopencl vector type tables.

The module exposes:

* lists of OpenCL builtin type names (scalar and vector),
* small string helpers to split/compose vector type names,
* helpers that format python scalars as OpenCL source literals,
* :class:`TypeGen` / :class:`OpenClTypeGen`, which bundle those helpers for a
  given floating point base type.
"""

import string

from hysop import __KERNEL_DEBUG__
from hysop.constants import np, it
from hysop.backend.opencl import cl, clArray
from hysop.tools.numerics import MPZ, MPQ, MPFR, F2Q

# Supported vector widths (1 means scalar).
vsizes = [1, 2, 3, 4, 8, 16]
base_types = ['float', 'signed', 'unsigned']
float_base_types = ['half', 'float', 'double']
signed_base_types = ['char', 'short', 'int', 'long']
unsigned_base_types = ['uchar', 'ushort', 'uint', 'ulong']

float_types = []
signed_types = []
unsigned_types = []

# Explicit mapping kind -> (base type names, list to fill); replaces a lookup
# of the module level lists through eval().
_types_by_kind = {
    'float':    (float_base_types,    float_types),
    'signed':   (signed_base_types,   signed_types),
    'unsigned': (unsigned_base_types, unsigned_types),
}
for b in base_types:
    b_base_types, b_types = _types_by_kind[b]
    for f, c in it.product(b_base_types, vsizes):
        if c == 1 and f == 'half':
            # Scalar 'half' is not listed, only its vector variants are.
            continue
        ftype = f if c == 1 else f + str(c)
        b_types.append(ftype)
integer_types = signed_types + unsigned_types
builtin_types = integer_types + float_types


# OpenCL extension required by each floating point base type (None: no extension).
float_base_type_require = {
    'half':   'cl_khr_fp16',
    'float':  None,
    'double': 'cl_khr_fp64'
}

FLT_DIG = {
    'half':   3,   # = HALF_DIG
    'float':  6,   # = FLT_DIG
    'double': 15   # = DBL_DIG
}
FLT_MANT_DIG = {
    'half':   11,  # = HALF_MANT_DIG
    'float':  24,  # = FLT_MANT_DIG
    'double': 53   # = DBL_MANT_DIG
}
# Suffix appended to a floating point literal of the given base type.
FLT_LITERAL = {
    'half':   'h',
    'float':  'f',
    'double': ''
}
FLT_BYTES = {
    'half':   2,
    'float':  4,
    'double': 8
}

try:
    _INTEGER_TYPES = (int, long, np.integer, MPZ)  # noqa: F821 (python 2 only name)
except NameError:
    # 'long' does not exist on python 3, plain int covers it.
    _INTEGER_TYPES = (int, np.integer, MPZ)


def basetype(fulltype):
    """Return fulltype with every ASCII digit removed ('float4' -> 'float')."""
    # A character filter works for str and unicode on python 2 and 3, unlike
    # the python 2 only str.translate(None, chars) form.
    return ''.join(ch for ch in fulltype if ch not in string.digits)


def components(fulltype):
    """Return the number of components of a type name ('float4' -> 4, 'float' -> 1)."""
    comp = ''.join(ch for ch in fulltype if ch not in string.ascii_letters)
    return 1 if comp == '' else int(comp)


def mangle_vtype(fulltype):
    """Return the first letter of the base type followed by the component count."""
    return basetype(fulltype)[0] + str(components(fulltype))


def vtype(basetype, N):
    """Return the type name built from basetype and N components (no suffix if N is 1)."""
    return basetype + ('' if N == 1 else str(N))


def itype(fulltype):
    """Return the 'int' type name with as many components as fulltype."""
    N = components(fulltype)
    return 'int' + ('' if N == 1 else str(N))


def uitype(fulltype):
    """Return the 'uint' type name with as many components as fulltype."""
    N = components(fulltype)
    return 'uint' + ('' if N == 1 else str(N))


def np_dtype(fulltype):
    """Return the result of pyopencl's get_or_register_dtype for fulltype."""
    return cl.tools.get_or_register_dtype(fulltype)


# Component names for each addressing mode.
_COMPONENT_NAMES = {
    'hex': '0123456789abcdef',
    'HEX': '0123456789ABCDEF',
    'pos': 'xyzw',
}


def vtype_component_adressing(i, mode='hex'):
    """
    Return the name of the i-th vector component in the given addressing mode.

    mode is one of 'hex', 'HEX' or 'pos'; any other value raises ValueError.
    An index i outside of the mode's component names raises IndexError.
    """
    if mode not in _COMPONENT_NAMES:
        raise ValueError('Bad vtype component adressing mode!')
    return _COMPONENT_NAMES[mode][i]


def vtype_access(i, N, mode='hex'):
    """
    Return the accessor suffix of component i of a N-component type.

    The result is empty for N == 1. For hexadecimal modes the component name
    is prefixed with 's'.
    """
    assert i < N
    if N == 1:
        return ''
    prefix = 's' if mode.lower() == 'hex' else ''
    return prefix + vtype_component_adressing(i, mode)


def _finite_float(f, fbtype):
    """Convert f to a python float, raising ValueError if it is inf or nan."""
    value = float(f)
    if not np.isfinite(value):
        msg = 'Cannot format non finite value {} as a {} literal.'
        raise ValueError(msg.format(value, fbtype))
    return value


def float_to_hex_str(f, fbtype):
    """
    Format f as a signed hexadecimal floating point literal of type fbtype.

    The mantissa returned by float.hex() is truncated (not rounded) to the
    number of hexadecimal digits derived from FLT_MANT_DIG[fbtype], and the
    literal suffix FLT_LITERAL[fbtype] is appended.

    Raises ValueError if f is not finite and KeyError if fbtype is unknown.
    """
    value = _finite_float(f, fbtype)

    # float.hex() yields '[-]0x<mantissa>p<exponent>'.
    sign, _, digits = value.hex().partition('0x')
    mantissa, exponent = digits.split('p')

    # +2 = leading one or zero and decimal point characters (1.abde... or 0.abcde...)
    nhex = (FLT_MANT_DIG[fbtype] - 1 + 3)//4 + 2

    sign = '+' if sign == '' else sign
    return sign + '0x' + mantissa[:nhex] + 'p' + exponent + FLT_LITERAL[fbtype]


def float_to_dec_str(f, fbtype):
    """
    Format f as a decimal floating point literal of type fbtype.

    The fractional part of repr(float(f)) is truncated so that about
    FLT_DIG[fbtype]+1 digits are kept; a decimal exponent, if any, is preserved.
    Strictly positive values get an explicit '+' sign and FLT_LITERAL[fbtype]
    is appended.

    Raises ValueError if f is not finite and KeyError if fbtype is unknown.
    """
    value = _finite_float(f, fbtype)

    # repr() may use scientific notation ('1e-05', '1.5e+20'): split the
    # exponent off first so that digit truncation can never remove it.
    mantissa, _, exponent = repr(value).partition('e')
    int_part, _, frac_part = mantissa.partition('.')

    offset = 1 if int_part[0] in ('-', '+') else 0
    # Clamp at zero, a negative bound would slice from the end of the string.
    ndigits = max(FLT_DIG[fbtype] - len(int_part) + offset + 1, 0)

    literal = int_part + '.' + frac_part[:ndigits]
    if exponent:
        literal += 'e' + exponent
    return ('+' if f > 0 else '') + literal + FLT_LITERAL[fbtype]


#pyopencl specific
vec = clArray.vec


def npmake(dtype):
    """Return a function wrapping a scalar into a one element array of dtype."""
    return lambda scalar: np.array([scalar], dtype=dtype)


# Tables below are indexed like vsizes: entry k matches vsizes[k] components.
vtype_int    = [np.int32,   vec.int2,    vec.int3,    vec.int4,    vec.int8,    vec.int16   ]
vtype_uint   = [np.uint32,  vec.uint2,   vec.uint3,   vec.uint4,   vec.uint8,   vec.uint16  ]
vtype_simple = [np.float32, vec.float2,  vec.float3,  vec.float4,  vec.float8,  vec.float16 ]
vtype_double = [np.float64, vec.double2, vec.double3, vec.double4, vec.double8, vec.double16]

make_int    = [npmake(np.int32),   vec.make_int2,    vec.make_int3,
               vec.make_int4,      vec.make_int8,    vec.make_int16]
make_uint   = [npmake(np.uint32),  vec.make_uint2,   vec.make_uint3,
               vec.make_uint4,     vec.make_uint8,   vec.make_uint16]
make_simple = [npmake(np.float32), vec.make_float2,  vec.make_float3,
               vec.make_float4,    vec.make_float8,  vec.make_float16]
make_double = [npmake(np.float64), vec.make_double2, vec.make_double3,
               vec.make_double4,   vec.make_double8, vec.make_double16]


def _vector_type(table, n):
    """Return the entry of table matching n components (ValueError if n not in vsizes)."""
    return table[vsizes.index(n)]


def simplen(n):
    """Return the single precision type with n components."""
    return _vector_type(vtype_simple, n)


def doublen(n):
    """Return the double precision type with n components."""
    return _vector_type(vtype_double, n)


def intn(n):
    """Return the signed 32 bits integer type with n components."""
    return _vector_type(vtype_int, n)


def uintn(n):
    """Return the unsigned 32 bits integer type with n components."""
    return _vector_type(vtype_uint, n)


_typen = {
    'float':  simplen,
    'simple': simplen,
    'double': doublen,
    'int':    intn,
    'uint':   uintn
}


def typen(btype, n):
    """Return the type for base type name btype with n components (KeyError if unknown)."""
    return _typen[btype](n)


def cl_type_to_dtype(cl_type):
    """Return the type matching an OpenCL type name such as 'float4'."""
    return typen(basetype(cl_type), components(cl_type))


def _make_vector(makers, vals, n, dval):
    """Pad vals with dval up to n values and pass them to the n-component maker."""
    vals = (vals,) if np.isscalar(vals) else tuple(vals)
    vals += (dval,)*(n - len(vals))
    return makers[vsizes.index(n)](*vals)


def make_simplen(vals, n, dval=0):
    """Build a n-component single precision value from vals, padded with dval."""
    return _make_vector(make_simple, vals, n, dval)


def make_doublen(vals, n, dval=0):
    """Build a n-component double precision value from vals, padded with dval."""
    return _make_vector(make_double, vals, n, dval)


def make_intn(vals, n, dval=0):
    """Build a n-component signed integer value from vals, padded with dval."""
    return _make_vector(make_int, vals, n, dval)


def make_uintn(vals, n, dval=0):
    """Build a n-component unsigned integer value from vals, padded with dval."""
    return _make_vector(make_uint, vals, n, dval)


class TypeGen(object):
    """
    Format python scalars as source code literals for a floating point base type.

    fbtype is the floating point base type name used to format floats and
    float_dump_mode selects decimal ('dec', 'decimal') or hexadecimal
    ('hex', 'hexadecimal') floating point literals.
    """

    def __init__(self, fbtype='float', float_dump_mode='dec'):
        """Raise ValueError if float_dump_mode is unknown."""
        self.float_base_types = float_base_types
        self.FLT_BYTES = FLT_BYTES
        self.FLT_DIG = FLT_DIG
        self.FLT_MANT_DIG = FLT_MANT_DIG
        self.FLT_LITERAL = FLT_LITERAL

        self.np_dtype = np_dtype

        self.float_to_dec_str = float_to_dec_str
        self.float_to_hex_str = float_to_hex_str

        self.fbtype = fbtype

        self.float_dump_mode = float_dump_mode
        if float_dump_mode in ['hex', 'hexadecimal']:
            self.float_to_str = float_to_hex_str
        elif float_dump_mode in ['dec', 'decimal']:
            self.float_to_str = float_to_dec_str
        else:
            raise ValueError('Unknown float_dump_mode \'{}\''.format(float_dump_mode))

    def dump(self, val):
        """
        Return val formatted as a source code literal.

        Booleans become 'true'/'false', floats are formatted with the
        configured float dump mode, non zero integers are signed and
        parenthesized, MPQ values become a quotient of float literals when
        __KERNEL_DEBUG__ is set (else their float value) and anything else is
        converted with str().

        Raises ValueError if val is a list, tuple, dict or numpy array.
        """
        if isinstance(val, (list, tuple, dict, np.ndarray)):
            raise ValueError('Value is not a scalar, got {}.'.format(val))
        # bool has to be tested before integers: python's bool is a subclass
        # of int and would otherwise be dumped as '(+1)' or '0'.
        if isinstance(val, (bool, np.bool_)):
            return 'true' if val else 'false'
        if isinstance(val, (float, np.floating, MPFR)):
            return '({})'.format(self.float_to_str(val, self.fbtype))
        if isinstance(val, _INTEGER_TYPES):
            if val == 0:
                return '0'
            if val > 0:
                return '({}{})'.format('+', str(val))
            # Drop the leading minus of str(val), the sign is emitted explicitly.
            return '({}{})'.format('-', str(val)[1:])
        if isinstance(val, MPQ):
            if __KERNEL_DEBUG__:
                return '({}.0{f}/{}.0{f})'.format(val.numerator, val.denominator,
                                                  f=FLT_LITERAL[self.fbtype])
            return self.dump(float(val))
        return str(val)


# struct type generation (type size and struct field offsets) is different for each device
# depending on architecture and compiler implementation and features.
# /!\ do not use the same opencl typegen instance for two different devices that are
# not equivalent.
class OpenClTypeGen(TypeGen):
    """TypeGen bound to an OpenCL device, context and platform."""

    @staticmethod
    def devicelessTypegen():
        """
        Sometimes we do not need structs and code generation is device independent.
        """
        return OpenClTypeGen(device=None, context=None, platform=None)

    def __init__(self, device, context, platform,
                 fbtype='float', float_dump_mode='dec'):
        """Raise ValueError if fbtype is neither 'float' nor 'double'."""
        super(OpenClTypeGen, self).__init__(fbtype, float_dump_mode)

        self.device = device
        self.context = context
        self.platform = platform

        self.vsizes = vsizes
        self.signed_base_types = signed_base_types
        self.unsigned_base_types = unsigned_base_types
        self.integer_base_types = signed_base_types + unsigned_base_types

        self.float_types = float_types
        self.signed_types = signed_types
        self.unsigned_types = unsigned_types
        self.integer_types = integer_types
        self.builtin_types = builtin_types

        self.float_base_type_require = float_base_type_require

        self.basetype = basetype
        self.components = components
        self.vtype = vtype
        self.itype = itype
        self.uitype = uitype
        self.np_dtype = np_dtype

        self.vtype_component_adressing = vtype_component_adressing
        self.vtype_access = vtype_access
        self.mangle_vtype = mangle_vtype
        self.float_to_dec_str = float_to_dec_str
        self.float_to_hex_str = float_to_hex_str

        #pyopencl specifics
        self.intn = intn
        self.uintn = uintn
        self.simplen = simplen
        self.doublen = doublen
        self.typen = typen

        self.make_intn = make_intn
        self.make_uintn = make_uintn
        self.make_simplen = make_simplen
        self.make_doublen = make_doublen

        if fbtype == 'float':
            self.floatn = simplen
            self.make_floatn = make_simplen
        elif fbtype == 'double':
            self.floatn = doublen
            self.make_floatn = make_doublen
        # elif fbtype == 'half':
        #     self.floatn = halfn
        #     self.make_floatn = make_halfn
        else:
            raise ValueError('Unknown fbtype \'{}\''.format(fbtype))

    def device_has_ftype(self, device):
        """Return True if device supports the extension required by self.fbtype."""
        dev_exts = device.extensions.split()
        req = self.float_base_type_require[self.fbtype]
        # req is the whole extension name: compare it, not its first character.
        return (req is None) or (req in dev_exts)

    def cl_requirements(self):
        """Return a one element list holding the extension required by fbtype (or None)."""
        return [self.float_base_type_require[self.fbtype]]

    def dtype_from_str(self, stype):
        """
        Return the type matching the type name stype.

        The placeholders 'ftype' and 'fbtype' found in stype are first replaced
        by the floating point base type of this generator.
        """
        stype = stype.replace('ftype', self.fbtype).replace('fbtype', self.fbtype)
        return typen(basetype(stype), components(stype))

    def __repr__(self):
        # platform and device are None for deviceless generators.
        platform_name = getattr(self.platform, 'name', None)
        device_name = getattr(self.device, 'name', None)
        return '{}_{}_{}_{}'.format(platform_name, device_name,
                                    self.fbtype, self.float_dump_mode)