From f4981227e8f19e623666c44ad79578504d215776 Mon Sep 17 00:00:00 2001
From: Keck Jean-Baptiste <jean-baptiste.keck@imag.fr>
Date: Sat, 10 Dec 2016 14:56:08 +0100
Subject: [PATCH] directional stretching enhanced

---
 .../codegen/kernels/directional_stretching.py | 124 ++++++++++++------
 1 file changed, 86 insertions(+), 38 deletions(-)

diff --git a/hysop/codegen/kernels/directional_stretching.py b/hysop/codegen/kernels/directional_stretching.py
index 92eb2f36e..cd0119ad6 100644
--- a/hysop/codegen/kernels/directional_stretching.py
+++ b/hysop/codegen/kernels/directional_stretching.py
@@ -62,7 +62,8 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
         formulation  = StretchingFormulation.value(formulation)
         sformulation = StretchingFormulation.svalue(formulation)
         local_size_known = ('local_size' in known_vars)
-        is_conservative = (formulation==StretchingFormulation.CONSERVATIVE)
+        is_conservative  = (formulation==StretchingFormulation.CONSERVATIVE)
+        is_periodic      = (boundary=='periodic')
         
         if cached:
             storage = OpenClCodeGenerator.default_keywords['local']
@@ -100,6 +101,7 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
         self.storage          = storage
         self.local_size_known = local_size_known
         self.is_conservative  = is_conservative
+        self.is_periodic      = is_periodic
 
         self.gencode()
 
@@ -117,7 +119,7 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
         reqs = WriteOnceDict()
         
         compute_id = ComputeIndexFunction(typegen=typegen, dim=work_dim, itype='int', 
-                wrap=False)
+                wrap=(boundary=='periodic'))
         reqs['compute_id'] = compute_id
 
         mesh_base_struct = MeshBaseStruct(typegen=typegen, typedef='MeshBase_s')
@@ -207,7 +209,10 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
 
         formulation      = s.formulation
         is_conservative  = s.is_conservative
+        is_periodic      = s.is_periodic
         local_size_known = s.local_size_known
+
+        vtype = tg.vtype(ftype,work_dim)
        
         global_id     = s.vars['global_id']
         local_id      = s.vars['local_id']
@@ -243,8 +248,8 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
                 const=True,value=self.min_ghosts())
         local_work = CodegenVariable('lwork','int',tg,const=True)
         
+        cached_vars = ArgDict()
         if cached:
-            cached_vars = ArgDict()
             for i in xrange(work_dim):
                 Vi = self.svelocity+self.xyz[i]
                 if local_size_known:
@@ -268,21 +273,42 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
                     Wic = CodegenVariable(storage=storage,name=Wi+'c',ctype=ftype,typegen=tg,
                                     const=True, restrict=True,ptr=True,init=init)
                 cached_vars[Wi] = Wic 
+
+        _U   = self.svelocity
+        size = cache_ghosts.value
+        Ur  = CodegenArray(storage='__local',name=_U+'r',dim=1,ctype=vtype,typegen=tg,
+                            shape=(2*size,))
+        if is_periodic:
+            Ul  = CodegenArray(storage='__local',name=_U+'l',dim=1,ctype=vtype,typegen=tg,
+                                shape=(size,))
+        if is_conservative:
+            _W   = self.svorticity
+            size = cache_ghosts.value
+            Wr  = CodegenArray(storage='__local',name=_W+'r',dim=1,ctype=vtype,typegen=tg,
+                                shape=(2*size,))
+            if is_periodic:
+                Wl  = CodegenArray(storage='__local',name=_W+'l',dim=1,ctype=vtype,typegen=tg,
+                                    shape=(size,))
             
         @contextmanager
         def _work_iterate_(i):
             try:
+                fval  = local_id[0] if i==0 else global_id.fval(i)
+                gsize = local_work() if i==0 else global_size[i]
+                N     = grid_size[i]
+                ghosts = compute_grid_ghosts[i]
+                if i==0:
+                    N = '{}+2*{}'.format(N,cache_ghosts())
+                    ghosts = '({}-{})'.format(ghosts,cache_ghosts())
+
                 with s._for_('int {i}={fval}; {i}<{N}; {i}+={gsize}'.format(
-                        i=' ji'[i],
-                        fval=global_id.fval(i), 
-                        gsize=global_size[i],
-                        N=grid_size[i])) as ctx:
-                    s.append('{} = {}+{};'.format(global_id[i], ' ji'[i], 
-                        compute_grid_ghosts[i]))
-                    yield ctx
+                        i='kji'[i], fval=fval, gsize=gsize,N=N)) as ctx:
+                        
+                        s.append('{} = {}+{};'.format(global_id[i], 'kji'[i], ghosts))
+                        yield ctx
             except:
                 raise
-        nested_loops = [_work_iterate_(i) for i in xrange(dim-1,0,-1)]
+        nested_loops = [_work_iterate_(i) for i in xrange(dim-1,-1,-1)]
         
         with s._kernel_():
             s.jumpline()
@@ -310,30 +336,21 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
                     for varname,var in cached_vars.iteritems():
                         var.declare(al,align=True)
                 s.jumpline()
+        
+            if is_periodic:
+                Ul.declare(s)
+            Ur.declare(s)
+            if is_conservative:
+                if is_periodic:
+                    Wl.declare(s)
+                Wr.declare(s)
+            s.jumpline()
 
             global_id.declare(s,init=False)
             global_index.declare(s)
 
             s.jumpline()
                 
-            s.append('{} = get_group_id(0)*{} + {} + ({} - {});'.format(
-                global_id[0],local_work(),local_id[0],
-                compute_grid_ghosts[0],cache_ghosts()))
-
-            if boundary is None:
-                with s._if_('{} >= {}'.format(global_id[0],compute_grid_size[0]),compact=True):
-                    s.append('return;')
-            elif boundary=='periodic':
-                with s._if_('{} >= {}+{}'.format(global_id[0],compute_grid_size[0],
-                    cache_ghosts()),compact=True):
-                    s.append('return;')
-                s.append('{} = ({} + {}) % {};'.format(
-                    global_id[0], global_id[0], compute_grid_size[0], compute_grid_size[0]))
-            else:
-                raise NotImplemented()
-            
-            s.jumpline()
-
             with contextlib.nested(*nested_loops):
 
                 s.jumpline()
@@ -341,8 +358,7 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
                 init = compute_index(idx=global_id, size=compute_grid_size)
                 s.append('{} = {};'.format(global_index(), init))
                     
-                winit = ''
-                uinit = ''
+                winit, uinit = '',''
                 for i in xrange(work_dim):
                     Wi = self.svorticity+self.xyz[i]
                     Ui = self.svelocity+self.xyz[i]
@@ -352,8 +368,27 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
                 winit='({}{})({})'.format(ftype, work_dim, winit[:-1])
 
                 s.jumpline()
-                U.declare(s,init=uinit)
-                W.declare(s,init=winit)
+                s.append('{} {},{};'.format(U.ctype,U(),W()))
+                with s._if_('k == 0'):
+                    s.append('{} = {};'.format(U(), uinit))
+                    s.append('{} = {};'.format(W(), winit))
+                    if is_periodic:
+                        with s._if_('{} < {}'.format(local_id[0],cache_ghosts())):
+                            s.append('{} = {};'.format(Ul[local_id[0]], U()))
+                            if is_conservative:
+                                s.append('{} = {};'.format(Wl[local_id[0]], W()))
+                with s._else_():
+                    with s._if_('{} < 2*{}'.format(local_id[0],cache_ghosts())):
+                        s.append('{} = {};'.format(U(), Ur[local_id[0]]))
+                        s.append('{} = {};'.format(W(), Wr[local_id[0]]))
+                    if is_periodic:
+                        with s._elif_('{} >= {}'.format(global_id[0],grid_size[0])):
+                            _id = '{}-{}'.format(global_id[0],grid_size[0])
+                            s.append('{} = {};'.format(U(), Ul[_id]))
+                            s.append('{} = {};'.format(W(), Wl[_id]))
+                    with s._else_():
+                        s.append('{} = {};'.format(U(), uinit))
+                        s.append('{} = {};'.format(W(), winit))
 
                 s.jumpline()
 
@@ -363,13 +398,26 @@ class DirectionalStretchingKernel(KernelCodeGenerator):
                         Uic = cached_vars[Ui]
                         code = '{} = {};'.format(Uic[local_id[0]],U[i])
                         s.append(code)
-                    s.jumpline()
+                    #Wi  = self.svorticity+self.xyz[direction]
+                    #Wic = cached_vars[Wi]
+                    #code = '{} = {};'.format(Wic[local_id[0]],W[direction])
+                    #s.append(code)
+                
+                s.jumpline()
+                s.mem_fence(_local=True,read=True)
+                with s._if_('{} >= {}-2*{}'.format(local_id[0],local_size[0],cache_ghosts())):
+                    _id = '{}-{}+2*{}'.format(local_id[0],local_size[0],cache_ghosts())
+                    s.append('{} = {};'.format(Ur[_id], U()))
+                    s.append('{} = {};'.format(Wr[_id], W()))
+                s.mem_fence(_local=True,write=True)
+                s.jumpline()
+            
 
                 
-                cond = '({lid}>={ghosts}) && ({lid}<{lwork}+{ghosts})'.format(
+                cond = '({lid}>={ghosts}) && ({lid}<{L}-{ghosts})'.format(
                         lid=local_id[0],
                         ghosts=cache_ghosts(),
-                        lwork=local_work())
+                        L=local_size[0])
                 with s._if_(cond):
                     for i in xrange(work_dim):
                         Wi = self.svorticity+self.xyz[i]
@@ -395,10 +443,10 @@ if __name__ == '__main__':
         rk_scheme=ExplicitRungeKutta('RK2'),
         cached=True,
         symbolic_mode=True,
-        boundary=None,#'periodic',
+        boundary='periodic',
         known_vars=dict(
             mesh_info=mesh_info,
-            local_size=local_size[:dim]
+            #local_size=local_size[:dim]
         )
     )
     dsk.edit()
-- 
GitLab