Skip to content
Snippets Groups Projects
Commit f4981227 authored by Keck Jean-Baptiste
Browse files

directional stretching enhanced

parent 5e1781ce
No related branches found
No related tags found
No related merge requests found
......@@ -62,7 +62,8 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
formulation = StretchingFormulation.value(formulation)
sformulation = StretchingFormulation.svalue(formulation)
local_size_known = ('local_size' in known_vars)
is_conservative = (formulation==StretchingFormulation.CONSERVATIVE)
is_conservative = (formulation==StretchingFormulation.CONSERVATIVE)
is_periodic = (boundary=='periodic')
if cached:
storage = OpenClCodeGenerator.default_keywords['local']
......@@ -100,6 +101,7 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
self.storage = storage
self.local_size_known = local_size_known
self.is_conservative = is_conservative
self.is_periodic = is_periodic
self.gencode()
......@@ -117,7 +119,7 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
reqs = WriteOnceDict()
compute_id = ComputeIndexFunction(typegen=typegen, dim=work_dim, itype='int',
wrap=False)
wrap=(boundary=='periodic'))
reqs['compute_id'] = compute_id
mesh_base_struct = MeshBaseStruct(typegen=typegen, typedef='MeshBase_s')
......@@ -207,7 +209,10 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
formulation = s.formulation
is_conservative = s.is_conservative
is_periodic = s.is_periodic
local_size_known = s.local_size_known
vtype = tg.vtype(ftype,work_dim)
global_id = s.vars['global_id']
local_id = s.vars['local_id']
......@@ -243,8 +248,8 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
const=True,value=self.min_ghosts())
local_work = CodegenVariable('lwork','int',tg,const=True)
cached_vars = ArgDict()
if cached:
cached_vars = ArgDict()
for i in xrange(work_dim):
Vi = self.svelocity+self.xyz[i]
if local_size_known:
......@@ -268,21 +273,42 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
Wic = CodegenVariable(storage=storage,name=Wi+'c',ctype=ftype,typegen=tg,
const=True, restrict=True,ptr=True,init=init)
cached_vars[Wi] = Wic
_U = self.svelocity
size = cache_ghosts.value
Ur = CodegenArray(storage='__local',name=_U+'r',dim=1,ctype=vtype,typegen=tg,
shape=(2*size,))
if is_periodic:
Ul = CodegenArray(storage='__local',name=_U+'l',dim=1,ctype=vtype,typegen=tg,
shape=(size,))
if is_conservative:
_W = self.svorticity
size = cache_ghosts.value
Wr = CodegenArray(storage='__local',name=_W+'r',dim=1,ctype=vtype,typegen=tg,
shape=(2*size,))
if is_periodic:
Wl = CodegenArray(storage='__local',name=_W+'l',dim=1,ctype=vtype,typegen=tg,
shape=(size,))
@contextmanager
def _work_iterate_(i):
try:
fval = local_id[0] if i==0 else global_id.fval(i)
gsize = local_work() if i==0 else global_size[i]
N = grid_size[i]
ghosts = compute_grid_ghosts[i]
if i==0:
N = '{}+2*{}'.format(N,cache_ghosts())
ghosts = '({}-{})'.format(ghosts,cache_ghosts())
with s._for_('int {i}={fval}; {i}<{N}; {i}+={gsize}'.format(
i=' ji'[i],
fval=global_id.fval(i),
gsize=global_size[i],
N=grid_size[i])) as ctx:
s.append('{} = {}+{};'.format(global_id[i], ' ji'[i],
compute_grid_ghosts[i]))
yield ctx
i='kji'[i], fval=fval, gsize=gsize,N=N)) as ctx:
s.append('{} = {}+{};'.format(global_id[i], 'kji'[i], ghosts))
yield ctx
except:
raise
nested_loops = [_work_iterate_(i) for i in xrange(dim-1,0,-1)]
nested_loops = [_work_iterate_(i) for i in xrange(dim-1,-1,-1)]
with s._kernel_():
s.jumpline()
......@@ -310,30 +336,21 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
for varname,var in cached_vars.iteritems():
var.declare(al,align=True)
s.jumpline()
if is_periodic:
Ul.declare(s)
Ur.declare(s)
if is_conservative:
if is_periodic:
Wl.declare(s)
Wr.declare(s)
s.jumpline()
global_id.declare(s,init=False)
global_index.declare(s)
s.jumpline()
s.append('{} = get_group_id(0)*{} + {} + ({} - {});'.format(
global_id[0],local_work(),local_id[0],
compute_grid_ghosts[0],cache_ghosts()))
if boundary is None:
with s._if_('{} >= {}'.format(global_id[0],compute_grid_size[0]),compact=True):
s.append('return;')
elif boundary=='periodic':
with s._if_('{} >= {}+{}'.format(global_id[0],compute_grid_size[0],
cache_ghosts()),compact=True):
s.append('return;')
s.append('{} = ({} + {}) % {};'.format(
global_id[0], global_id[0], compute_grid_size[0], compute_grid_size[0]))
else:
raise NotImplemented()
s.jumpline()
with contextlib.nested(*nested_loops):
s.jumpline()
......@@ -341,8 +358,7 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
init = compute_index(idx=global_id, size=compute_grid_size)
s.append('{} = {};'.format(global_index(), init))
winit = ''
uinit = ''
winit, uinit = '',''
for i in xrange(work_dim):
Wi = self.svorticity+self.xyz[i]
Ui = self.svelocity+self.xyz[i]
......@@ -352,8 +368,27 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
winit='({}{})({})'.format(ftype, work_dim, winit[:-1])
s.jumpline()
U.declare(s,init=uinit)
W.declare(s,init=winit)
s.append('{} {},{};'.format(U.ctype,U(),W()))
with s._if_('k == 0'):
s.append('{} = {};'.format(U(), uinit))
s.append('{} = {};'.format(W(), winit))
if is_periodic:
with s._if_('{} < {}'.format(local_id[0],cache_ghosts())):
s.append('{} = {};'.format(Ul[local_id[0]], U()))
if is_conservative:
s.append('{} = {};'.format(Wl[local_id[0]], W()))
with s._else_():
with s._if_('{} < 2*{}'.format(local_id[0],cache_ghosts())):
s.append('{} = {};'.format(U(), Ur[local_id[0]]))
s.append('{} = {};'.format(W(), Wr[local_id[0]]))
if is_periodic:
with s._elif_('{} >= {}'.format(global_id[0],grid_size[0])):
_id = '{}-{}'.format(global_id[0],grid_size[0])
s.append('{} = {};'.format(U(), Ul[_id]))
s.append('{} = {};'.format(W(), Wl[_id]))
with s._else_():
s.append('{} = {};'.format(U(), uinit))
s.append('{} = {};'.format(W(), winit))
s.jumpline()
......@@ -363,13 +398,26 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
Uic = cached_vars[Ui]
code = '{} = {};'.format(Uic[local_id[0]],U[i])
s.append(code)
s.jumpline()
#Wi = self.svorticity+self.xyz[direction]
#Wic = cached_vars[Wi]
#code = '{} = {};'.format(Wic[local_id[0]],W[direction])
#s.append(code)
s.jumpline()
s.mem_fence(_local=True,read=True)
with s._if_('{} >= {}-2*{}'.format(local_id[0],local_size[0],cache_ghosts())):
_id = '{}-{}+2*{}'.format(local_id[0],local_size[0],cache_ghosts())
s.append('{} = {};'.format(Ur[_id], U()))
s.append('{} = {};'.format(Wr[_id], W()))
s.mem_fence(_local=True,write=True)
s.jumpline()
cond = '({lid}>={ghosts}) && ({lid}<{lwork}+{ghosts})'.format(
cond = '({lid}>={ghosts}) && ({lid}<{L}-{ghosts})'.format(
lid=local_id[0],
ghosts=cache_ghosts(),
lwork=local_work())
L=local_size[0])
with s._if_(cond):
for i in xrange(work_dim):
Wi = self.svorticity+self.xyz[i]
......@@ -395,10 +443,10 @@ if __name__ == '__main__':
rk_scheme=ExplicitRungeKutta('RK2'),
cached=True,
symbolic_mode=True,
boundary=None,#'periodic',
boundary='periodic',
known_vars=dict(
mesh_info=mesh_info,
local_size=local_size[:dim]
#local_size=local_size[:dim]
)
)
dsk.edit()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment