Skip to content
Snippets Groups Projects
Commit 7da501ce authored by Keck Jean-Baptiste's avatar Keck Jean-Baptiste
Browse files

codegen refactoring

parent 0f0a9354
No related branches found
No related tags found
No related merge requests found
...@@ -120,8 +120,6 @@ class CodeGenerator(object): ...@@ -120,8 +120,6 @@ class CodeGenerator(object):
pass pass
return self return self
def to_file(self,folder,filename): def to_file(self,folder,filename):
dst_file = folder + '/' + filename dst_file = folder + '/' + filename
if not os.path.exists(folder): if not os.path.exists(folder):
...@@ -186,7 +184,6 @@ class CodeGenerator(object): ...@@ -186,7 +184,6 @@ class CodeGenerator(object):
if k>0: if k>0:
self.code = self.code[:k] self.code = self.code[:k]
return self return self
def _noop(self): def _noop(self):
...@@ -206,9 +203,9 @@ class CodeGenerator(object): ...@@ -206,9 +203,9 @@ class CodeGenerator(object):
def define(self,what,prepend=True): def define(self,what,prepend=True):
code = '#define {}'.format(what) code = '#define {}'.format(what)
self.append(code,simple=True) self.append(code,simple=True)
def include(self,*kargs): def include(self,*args):
code = [] code = []
for k in kargs: for k in args:
code.append('#include {}'.format(k)) code.append('#include {}'.format(k))
return self.append(code) return self.append(code)
...@@ -252,24 +249,13 @@ class CodeGenerator(object): ...@@ -252,24 +249,13 @@ class CodeGenerator(object):
else: else:
self.append(code) self.append(code)
def decl_vars(self,_type,_varnames,_inits=None,comment=None,cv_qualifier=None): def decl_vars(self, *variables):
if not isinstance(_varnames,list): assert len(set(var.base_ctype() for var in variables))==1
_varnames = [_varnames] base= variables[0].base_ctype()
if _inits is not None and not isinstance(_inits,list): svars=[]
_inits = [_inits] for var in variables:
svars.append(var.declare(multidecl=True))
declvar = '' return '{} {};'.format(base, ', '.join(svars))
if cv_qualifier:
declvar+=cv_qualifier+' '
declvar += '{} '.format(_type)
if _inits:
declvar += ','.join(['{}={}'.format(v,i) for v,i in zip(_varnames,_inits)])
else:
declvar += ','.join(_varnames)
declvar += ';'
if comment is not None:
declvar += ' /* {} */'.format(comment)
self.append(declvar)
class VarBlock(object): class VarBlock(object):
...@@ -354,7 +340,11 @@ class CodeGenerator(object): ...@@ -354,7 +340,11 @@ class CodeGenerator(object):
maxlen = lambda i: max([len(line[i]) for line in self._lines if len(line)>1]) maxlen = lambda i: max([len(line[i]) for line in self._lines if len(line)>1])
line_str = '' line_str = ''
for i in xrange(self._parts_count): for i in xrange(self._parts_count):
line_str+='{:'+str(maxlen(i))+'}' ml = maxlen(i)
if ml==0:
line_str+='{}'
else:
line_str+='{:'+str(ml)+'}'
code = [] code = []
for line in self._lines: for line in self._lines:
if len(line)>1: if len(line)>1:
......
...@@ -59,7 +59,7 @@ class StructCodeGenerator(OpenClCodeGenerator): ...@@ -59,7 +59,7 @@ class StructCodeGenerator(OpenClCodeGenerator):
i+=1 i+=1
def build_codegen_variable(self,name,**kargs): def build_codegen_variable(self,name,**kargs):
return CodegenStruct(varname=name, struct=self, **kargs) return CodegenStruct(name=name, struct=self, **kargs)
if __name__ == '__main__': if __name__ == '__main__':
......
This diff is collapsed.
...@@ -64,8 +64,7 @@ class DirectionalRemeshKernel(KernelCodeGenerator): ...@@ -64,8 +64,7 @@ class DirectionalRemeshKernel(KernelCodeGenerator):
return int(1+math.ceil(scalar_cfl)+remesh_kernel.n/2) return int(1+math.ceil(scalar_cfl)+remesh_kernel.n/2)
def __init__(self, typegen, work_dim, direction, ftype, def __init__(self, typegen, work_dim, direction, ftype,
nparticles, nscalars, sboundary, nparticles, nscalars, sboundary, is_inplace,
is_inplace,
scalar_cfl, remesh_kernel, scalar_cfl, remesh_kernel,
remesh_criteria_eps=None, remesh_criteria_eps=None,
use_atomics = False, use_atomics = False,
...@@ -80,19 +79,19 @@ class DirectionalRemeshKernel(KernelCodeGenerator): ...@@ -80,19 +79,19 @@ class DirectionalRemeshKernel(KernelCodeGenerator):
check_instance(sboundary[0],BoundaryCondition) check_instance(sboundary[0],BoundaryCondition)
check_instance(sboundary[1],BoundaryCondition) check_instance(sboundary[1],BoundaryCondition)
check_instance(remesh_kernel, RemeshKernel) check_instance(remesh_kernel, RemeshKernel)
known_vars = known_vars or dict()
itype = 'int'
vftype = tg.vtype(ftype, nparticles)
vitype = tg.vtype(itype, nparticles)
assert sboundary[0] in [BoundaryCondition.PERIODIC, BoundaryCondition.NONE] assert sboundary[0] in [BoundaryCondition.PERIODIC, BoundaryCondition.NONE]
assert sboundary[1] in [BoundaryCondition.PERIODIC, BoundaryCondition.NONE] assert sboundary[1] in [BoundaryCondition.PERIODIC, BoundaryCondition.NONE]
is_periodic = (sboundary[0]==BoundaryCondition.PERIODIC \ is_periodic = (sboundary[0]==BoundaryCondition.PERIODIC \
and sboundary[1]==BoundaryCondition.PERIODIC) and sboundary[1]==BoundaryCondition.PERIODIC)
known_vars = known_vars or dict()
local_size_known = ('local_size' in known_vars) local_size_known = ('local_size' in known_vars)
itype = 'int'
vftype = tg.vtype(ftype, nparticles)
vitype = tg.vtype(itype, nparticles)
name = DirectionalRemeshKernel.codegen_name(work_dim, direction, name = DirectionalRemeshKernel.codegen_name(work_dim, direction,
remesh_kernel, ftype, remesh_kernel, ftype,
nparticles,nscalars, remesh_criteria_eps, nparticles,nscalars, remesh_criteria_eps,
...@@ -155,28 +154,28 @@ class DirectionalRemeshKernel(KernelCodeGenerator): ...@@ -155,28 +154,28 @@ class DirectionalRemeshKernel(KernelCodeGenerator):
kargs = ArgDict() kargs = ArgDict()
self.position = OpenClArrayBackend.build_codegen_argument(kargs, name='position', self.position = OpenClArrayBackend.build_codegen_argument(kargs, name='position',
storage=self._global, ctype=ftype, typegen=typegen, storage=self._global, ctype=ftype, typegen=typegen,
restrict=True, const=True) ptr_restrict=True, ptr_const=True)
if is_inplace: if is_inplace:
self.scalars_in = tuple( self.scalars_in = tuple(
OpenClArrayBackend.build_codegen_argument(kargs, name=' s{}_in'.format(i), OpenClArrayBackend.build_codegen_argument(kargs, name=' s{}_in'.format(i),
storage=self._global, ctype=ftype, typegen=typegen, storage=self._global, ctype=ftype, typegen=typegen,
restrict=True, const=False) for i in xrange(nscalars)) ptr_restrict=True, ptr_const=False) for i in xrange(nscalars))
self.scalars_out = self.scalars_in self.scalars_out = self.scalars_in
else: else:
self.scalars_in = tuple( self.scalars_in = tuple(
OpenClArrayBackend.build_codegen_argument(kargs, name='s{}_in'.format(i), OpenClArrayBackend.build_codegen_argument(kargs, name='s{}_in'.format(i),
storage=self._global, ctype=ftype, typegen=typegen, storage=self._global, ctype=ftype, typegen=typegen,
restrict=True, const=True) for i in xrange(nscalars)) ptr_restrict=True, ptr_const=True) for i in xrange(nscalars))
self.scalars_out = tuple( self.scalars_out = tuple(
OpenClArrayBackend.build_codegen_argument(kargs, name='s{}_out'.format(i), OpenClArrayBackend.build_codegen_argument(kargs, name='s{}_out'.format(i),
storage=self._global, ctype=ftype, typegen=typegen, storage=self._global, ctype=ftype, typegen=typegen,
restrict=True, const=False) for i in xrange(nscalars)) ptr_restrict=True, ptr_const=False) for i in xrange(nscalars))
if debug_mode: if debug_mode:
kargs['dbg0'] = CodegenVariable(storage=self._global,name='dbg0',ctype=itype, kargs['dbg0'] = CodegenVariable(storage=self._global,name='dbg0',ctype=itype,
typegen=typegen, restrict=True,ptr=True,const=False,add_impl_const=True) typegen=typegen, ptr_restrict=True,ptr=True,const=False,add_impl_const=True)
kargs['dbg1'] = CodegenVariable(storage=self._global,name='dbg1',ctype=itype, kargs['dbg1'] = CodegenVariable(storage=self._global,name='dbg1',ctype=itype,
typegen=typegen, restrict=True,ptr=True,const=False,add_impl_const=True) typegen=typegen, ptr_restrict=True,ptr=True,const=False,add_impl_const=True)
kargs['position_mesh_info'] = kernel_reqs['MeshInfoStruct'].build_codegen_variable( kargs['position_mesh_info'] = kernel_reqs['MeshInfoStruct'].build_codegen_variable(
const=True, name='position_mesh_info') const=True, name='position_mesh_info')
...@@ -187,7 +186,7 @@ class DirectionalRemeshKernel(KernelCodeGenerator): ...@@ -187,7 +186,7 @@ class DirectionalRemeshKernel(KernelCodeGenerator):
if not local_size_known: if not local_size_known:
kargs['buffer'] = CodegenVariable(storage=self._local, ctype=ftype, kargs['buffer'] = CodegenVariable(storage=self._local, ctype=ftype,
add_impl_const=True, name='buffer', ptr=True, restrict=True, add_impl_const=True, name='buffer', ptr=True, ptr_restrict=True,
typegen=typegen, nl=False) typegen=typegen, nl=False)
return kargs return kargs
...@@ -256,7 +255,7 @@ class DirectionalRemeshKernel(KernelCodeGenerator): ...@@ -256,7 +255,7 @@ class DirectionalRemeshKernel(KernelCodeGenerator):
line_index = CodegenVariable(name='line_index', ctype=itype, typegen=tg) line_index = CodegenVariable(name='line_index', ctype=itype, typegen=tg)
line_offset = CodegenVariable(name='line_offset', ctype=itype, typegen=tg,const=True) line_offset = CodegenVariable(name='line_offset', ctype=itype, typegen=tg,const=True)
line_velocity = CodegenVariable(name='line_velocity', ctype=ftype, ptr=True, line_velocity = CodegenVariable(name='line_velocity', ctype=ftype, ptr=True,
storage='__global', restrict=True, const=True, add_impl_const=True, typegen=tg) storage='__global', ptr_restrict=True, ptr_const=True, const=True, typegen=tg)
position_global_id = CodegenVectorClBuiltin('pos_gid', itype, work_dim, typegen=tg) position_global_id = CodegenVectorClBuiltin('pos_gid', itype, work_dim, typegen=tg)
scalars_global_id = tuple(CodegenVectorClBuiltin('S{}_gid'.format(i), scalars_global_id = tuple(CodegenVectorClBuiltin('S{}_gid'.format(i),
...@@ -283,7 +282,7 @@ class DirectionalRemeshKernel(KernelCodeGenerator): ...@@ -283,7 +282,7 @@ class DirectionalRemeshKernel(KernelCodeGenerator):
buf = self.vars['buffer'] buf = self.vars['buffer']
for i in xrange(nscalars): for i in xrange(nscalars):
Si = CodegenVariable(name='S{}'.format(i),ctype=ftype,typegen=tg, Si = CodegenVariable(name='S{}'.format(i),ctype=ftype,typegen=tg,
restrict=True, ptr=True, storage=self._local, ptr_restrict=True, ptr=True, storage=self._local,
const=True, const=True,
init='{} + {}*{}'.format(buf,i,cache_width)) init='{} + {}*{}'.format(buf,i,cache_width))
cached_scalars.append(Si) cached_scalars.append(Si)
...@@ -380,9 +379,10 @@ class DirectionalRemeshKernel(KernelCodeGenerator): ...@@ -380,9 +379,10 @@ class DirectionalRemeshKernel(KernelCodeGenerator):
sgid.declare(al,align=True) sgid.declare(al,align=True)
s.jumpline() s.jumpline()
with s._align_() as al: #with s._align_() as al:
for var in cached_scalars: #for var in cached_scalars:
var.declare(al,align=True); #var.declare(al,align=True);
s.decl_vars(*cached_scalars)
s.jumpline() s.jumpline()
......
...@@ -2985,20 +2985,20 @@ class OpenClArrayBackend(ArrayBackend): ...@@ -2985,20 +2985,20 @@ class OpenClArrayBackend(ArrayBackend):
assert 'add_impl_const' not in kargs assert 'add_impl_const' not in kargs
assert 'init' not in kargs assert 'init' not in kargs
args[base] = CodegenVariable(name=base, args[base] = CodegenVariable(name=base, typegen=typegen,
typegen=typegen, ctype=ctype, ptr=ptr, const=const, ctype=ctype, ptr=ptr, const=const,
add_impl_const=True, nl=False, **kargs) add_impl_const=True, nl=False, **kargs)
args[offset] = CodegenVariable(name=offset, args[offset] = CodegenVariable(name=offset,
typegen=typegen, ctype=itype, typegen=typegen, ctype=itype,
add_impl_const=True,nl=True) add_impl_const=True, nl=True)
char_alias = args[base].alias(None, ctype='char', restrict=False, volatile=False).full_ctype() char_alias = args[base].full_ctype(ctype='char', cast=True)
ctype_alias = args[base].alias(None, ctype=ctype, restrict=False, volatile=False).full_ctype() ctype_alias = args[base].full_ctype(cast=True)
init = '({})(({})({})+{})'.format(ctype_alias, char_alias, base, offset) init = '({})(({})({})+{})'.format(ctype_alias, char_alias, base, offset)
var = CodegenVariable(name=name, var = CodegenVariable(name=name, typegen=typegen,
typegen=typegen, ctype=ctype, ptr=ptr, const=const, ctype=ctype, ptr=ptr, const=const,
add_impl_const=True, nl=False, add_impl_const=True, nl=False,
init=init, **kargs) init=init, **kargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment