From 10b41b00df3d457cc02ddcdb709b10378cad3363 Mon Sep 17 00:00:00 2001
From: Keck Jean-Baptiste <jean-baptiste.keck@imag.fr>
Date: Sat, 8 Apr 2017 01:21:09 +0200
Subject: [PATCH] fixed opencl kernel codegen

---
 hysop/__init__.py                             |   5 +-
 hysop/__init__.py.in                          |   1 +
 hysop/backend/codegen/base/variables.py       |   2 +-
 .../codegen/functions/apply_stencil.py        |  17 +-
 .../codegen/functions/stretching_rhs.py       |   4 +-
 .../tests/test_directional_stretching.py      |   2 +-
 hysop/backend/opencl/opencl_printer.py        | 151 ++++++++++++++++++
 hysop/backend/opencl/opencl_types.py          |  33 +++-
 hysop/numerics/stencil/stencil.py             |  10 +-
 hysop/operator/hdf_io.py                      |   2 +
 hysop/tools/sympy_utils.py                    |  10 --
 setup.py.in                                   |   4 +-
 12 files changed, 201 insertions(+), 40 deletions(-)
 create mode 100644 hysop/backend/opencl/opencl_printer.py

diff --git a/hysop/__init__.py b/hysop/__init__.py
index 9fa51f813..0d65de499 100644
--- a/hysop/__init__.py
+++ b/hysop/__init__.py
@@ -16,7 +16,7 @@ __FFTW_ENABLED__   = "ON"   is "ON"
 __SCALES_ENABLED__ = "ON" is "ON"
 __OPTIMIZE__       = "OFF"       is "ON"
 
-__VERBOSE__        = True
+__VERBOSE__        = "ON"   in ["1", "3"]
 __DEBUG__          = "ON"   in ["2", "3"]
 __KERNEL_DEBUG__   = "ON"   in ["4", "3"]
 __PROFILE__        = "OFF" in ["0", "1"]
@@ -25,7 +25,7 @@ __ENABLE_LONG_TESTS__ = "OFF" is "ON"
 
 # OpenCL
 __DEFAULT_PLATFORM_ID__ = 1
-__DEFAULT_DEVICE_ID__   = 0
+__DEFAULT_DEVICE_ID__   = 1
 
 
 
@@ -93,4 +93,5 @@ default_path = IO.default_path()
 cache_path   = IO.default_cache_path()
 msg_io =  '\n*Default path for all i/o is \'{}\'.'.format(default_path)
 msg_io += '\n*Default path for caching is \'{}\'.'.format(cache_path)
+msg_io += '\n'
 vprint(msg_io)
diff --git a/hysop/__init__.py.in b/hysop/__init__.py.in
index e46d03399..c7e73ba7c 100644
--- a/hysop/__init__.py.in
+++ b/hysop/__init__.py.in
@@ -92,4 +92,5 @@ default_path = IO.default_path()
 cache_path   = IO.default_cache_path()
 msg_io =  '\n*Default path for all i/o is \'{}\'.'.format(default_path)
 msg_io += '\n*Default path for caching is \'{}\'.'.format(cache_path)
+msg_io += '\n'
 vprint(msg_io)
diff --git a/hysop/backend/codegen/base/variables.py b/hysop/backend/codegen/base/variables.py
index 9681eba4a..166d4014d 100644
--- a/hysop/backend/codegen/base/variables.py
+++ b/hysop/backend/codegen/base/variables.py
@@ -449,7 +449,7 @@ class CodegenVector(CodegenVariable):
     def __getitem__(self,i):
         return self.sval(i)
     
-    def __str__(self):
+    def __repr__(self):
         if self.is_symbolic():
             return '{}({})'.format(self.name,self.ctype)
         else:
diff --git a/hysop/backend/codegen/functions/apply_stencil.py b/hysop/backend/codegen/functions/apply_stencil.py
index cd65cb631..ab145228f 100644
--- a/hysop/backend/codegen/functions/apply_stencil.py
+++ b/hysop/backend/codegen/functions/apply_stencil.py
@@ -13,7 +13,7 @@ from hysop.numerics.stencil.stencil import Stencil
 class ApplyStencilFunction(OpenClFunctionCodeGenerator):
     
     def __init__(self,typegen,stencil,ftype,
-            symbol2var=None,
+            symbol2vars=None,
             components=1, vectorize=True,
             extra_inputs=[],
             scalar_inputs=[],
@@ -27,10 +27,10 @@ class ApplyStencilFunction(OpenClFunctionCodeGenerator):
             known_args=None):
         
         check_instance(stencil,Stencil)
-        check_instance(symbol2var, dict, keys=sm.Symbol, values=CodegenVariable, allow_none=True)
-        assert set(symbol2var.keys())==stencil.variables()
+        check_instance(symbol2vars, dict, keys=sm.Symbol, values=CodegenVariable, allow_none=True)
+        assert set(symbol2vars.keys())==stencil.variables()
 
-        extra_inputs  = set(extra_inputs + symbol2var.values())
+        extra_inputs  = set(extra_inputs + symbol2vars.values())
         scalar_inputs = set(scalar_inputs)
         vector_inputs = set(vector_inputs)
 
@@ -96,7 +96,7 @@ class ApplyStencilFunction(OpenClFunctionCodeGenerator):
         self.has_custom_id = has_custom_id
         self.custom_id = custom_id
         self.op = op
-        self.symbol2var = symbol2var
+        self.symbol2vars = symbol2vars
 
         self.gencode()
 
@@ -174,8 +174,7 @@ class ApplyStencilFunction(OpenClFunctionCodeGenerator):
                             operands['vinput{}'.format(j)] = s.vars[vn]()
                         for j,vn in enumerate(s.scalar_inputs):
                             operands['sinput{}'.format(j)] = s.vars[vn]()
-                        for (off,coeff) in stencil.iteritems(include_factor=False,
-                                                                  svars=s.symbol2var):
+                        for (off,coeff) in stencil.iteritems(include_factor=False):
                             if coeff=='0':
                                 continue
                             if not has_custom_id:
@@ -190,7 +189,7 @@ class ApplyStencilFunction(OpenClFunctionCodeGenerator):
                                 operands['id'] = '{}+{}'.format(offset(),strided)
                             else:
                                 operands['id'] = s.custom_id.format(offset=tg.dump(off[0]))
-                            code = '{} += {} $* {};'.format(_res,tg.dump(coeff),
+                            code = '{} += {} $* {};'.format(_res,tg.dump_expr(coeff,symbol2vars=s.symbol2vars),
                                     s.op.format(**operands))
                             al.append(code)
                         if vectorized:
@@ -200,7 +199,7 @@ class ApplyStencilFunction(OpenClFunctionCodeGenerator):
             for mult in s.multipliers:
                 mul+='{}*'.format(s.vars[mult]())
             if stencil.has_factor():
-                mul+='{}*'.format(stencil.format_factor(s.symbol2var))
+                mul+='{}*'.format(tg.dump_expr(stencil.factor, symbol2vars=s.symbol2vars))
 
             ret = 'return {}{};'.format(mul,res())
             s.append(ret)
diff --git a/hysop/backend/codegen/functions/stretching_rhs.py b/hysop/backend/codegen/functions/stretching_rhs.py
index 47b040e00..2152bb279 100644
--- a/hysop/backend/codegen/functions/stretching_rhs.py
+++ b/hysop/backend/codegen/functions/stretching_rhs.py
@@ -160,11 +160,11 @@ class DirectionalStretchingRhsFunction(OpenClFunctionCodeGenerator):
         
         stencil = self.build_stencil(order)
         stencil.replace_symbols({stencil.dx:1/inv_dx_s})
-        symbol2var = {inv_dx_s:inv_dx_var}
+        symbol2vars = {inv_dx_s:inv_dx_var}
 
         apply_stencil = ApplyStencilFunction(typegen=typegen,
                 stencil=stencil,
-                symbol2var=symbol2var,
+                symbol2vars=symbol2vars,
                 ftype=ftype, itype=itype, 
                 data_storage=storage, 
                 vectorize=vectorize_u,
diff --git a/hysop/backend/codegen/kernels/tests/test_directional_stretching.py b/hysop/backend/codegen/kernels/tests/test_directional_stretching.py
index 235e864cc..1d4373d3e 100644
--- a/hysop/backend/codegen/kernels/tests/test_directional_stretching.py
+++ b/hysop/backend/codegen/kernels/tests/test_directional_stretching.py
@@ -557,7 +557,7 @@ class TestDirectionalStretching(object):
 
 
 if __name__ == '__main__':
-    TestDirectionalStretching.setup_class(do_extra_tests=True, enable_error_plots=False)
+    TestDirectionalStretching.setup_class(do_extra_tests=False, enable_error_plots=False)
     test = TestDirectionalStretching()
     
     test.test_stretching_gradUW_Euler()
diff --git a/hysop/backend/opencl/opencl_printer.py b/hysop/backend/opencl/opencl_printer.py
new file mode 100644
index 000000000..eda115dc1
--- /dev/null
+++ b/hysop/backend/opencl/opencl_printer.py
@@ -0,0 +1,151 @@
+
+
+import sympy as sm
+from sympy.printing.ccode import CCodePrinter
+from hysop.tools.types import check_instance
+from hysop.backend.opencl.opencl_types import OpenClTypeGen
+
+# /!\ TODO complete known_functions list with OpenCL builtins
+# - keys are sympy function names (beware to capital letters)
+# - values are either strings or list of tuples (predicate(inputs),string)
+#   that corresponds to OpenCL function builtins.
+# Here are some attributes that can be checked in predicates:
+#  is_zero
+#  is_finite 	 is_integer 	 
+#  is_negative 	 is_positive 	 
+#  is_rational 	 is_real 	 
+known_functions = {
+    'Abs': [(lambda x: x.is_integer, 'abs'),'fabs'],
+    'min': [(lambda x,y: x.is_integer and y.is_integer, 'min'),'fmin'],
+    'max': [(lambda x,y: x.is_integer and y.is_integer, 'max'),'fmax'],
+    'sqrt': 'sqrt',
+    'gamma': 'tgamma',
+    
+    'sin': 'sin',
+    'cos': 'cos',
+    'tan': 'tan',
+    'asin': 'asin',
+    'acos': 'acos',
+    'atan': 'atan',
+    'atan2': 'atan2',
+    
+    'sinh': 'sinh',
+    'cosh': 'cosh',
+    'tanh': 'tanh',
+    'asinh': 'asinh',
+    'acosh': 'acosh',
+    'atanh': 'atanh',
+    
+    'exp': 'exp',
+    'log': 'log',
+    'erf': 'erf',
+    'floor': 'floor',
+    'ceiling': 'ceil',
+}
+
+# OpenCl 2.2 reserved keywords (see opencl documentation)
+reserved_words = [
+    
+    # C++14 keywords
+    'alignas', 'continue', 'friend', 'register', 'true', 
+    'alignof', 'decltype', 'goto', 'reinterpret_cast', 'try', 
+    'asm', 'default', 'if', 'return', 'typedef', 
+    'auto', 'delete', 'inline', 'short', 'typeid', 
+    'bool', 'do', 'int', 'signed', 'typename', 
+    'break', 'double', 'long', 'sizeof', 'union', 
+    'case', 'dynamic_cast', 'mutable', 'static', 'unsigned', 
+    'catch', 'else', 'namespace', 'static_assert', 'using', 
+    'char', 'enum', 'new', 'static_cast', 'virtual', 
+    'char16_t', 'explicit', 'noexcept', 'struct', 'void', 
+    'char32_t', 'export', 'nullptr', 'switch', 'volatile', 
+    'class', 'extern', 'operator', 'template', 'wchar_t', 
+    'const', 'false', 'private', 'this', 'while', 
+    'constexpr', 'float', 'protected', 'thread_local', 
+    'const_cast', 'for', 'public', 'throw'
+    'override', 'final',
+    
+    # OpenCl data types
+    'uchar', 'ushort', 'uint', 'ulong', 'half',
+    'bool2', 'char2', 'uchar2', 'short2', 'ushort2', 'int2', 'uint2', 'long2', 'ulong2', 'half2', 'float2', 'double2',
+    'bool3', 'char3', 'uchar3', 'short3', 'ushort3', 'int3', 'uint3', 'long3', 'ulong3', 'half3', 'float3', 'double3',
+    'bool4', 'char4', 'uchar4', 'short4', 'ushort4', 'int4', 'uint4', 'long4', 'ulong4', 'half4', 'float4', 'double4',
+    'bool8', 'char8', 'uchar8', 'short8', 'ushort8', 'int8', 'uint8', 'long8', 'ulong8', 'half8', 'float8', 'double8',
+    'bool16', 'char16', 'uchar16', 'short16', 'ushort16', 'int16', 'uint16', 'long16', 'ulong16', 'half16', 'float16', 'double16',
+    
+    # function qualifiers
+    'kernel', '__kernel',
+    
+    # access qualifiers
+    'read_only', 'write_only', 'read_write',
+    '__read_only', '__write_only', '__read_write',
+]
+
+
+class OpenClPrinter(CCodePrinter):
+    """
+    A printer to convert sympy expressions to strings of opencl code
+    """
+    printmethod = '_clcode'
+    language = 'OpenCL'
+
+    _default_settings = {
+        'order': None,
+        'full_prec': 'auto',
+        'precision': None,
+        'user_functions': {},
+        'human': True,
+        'contract': True,
+        'dereference': set(),
+        'error_on_reserved': True,
+        'reserved_word_suffix': None,
+    }
+    
+    def __init__(self, typegen, symbol2vars=None, **settings):
+        check_instance(typegen, OpenClTypeGen)
+        check_instance(symbol2vars, dict, keys=sm.Symbol, allow_none=True)
+        
+        super(OpenClPrinter,self).__init__(settings=settings)
+
+        self.known_functions = dict(known_functions)
+        self.reserved_words  = set(reserved_words)
+        self.typegen     = typegen
+        self.symbol2vars = symbol2vars
+
+    def dump_symbol(self, expr):
+        symbol2vars = self.symbol2vars
+        if expr in symbol2vars:
+            return self._print(symbol2vars[expr])
+        else:
+            return super(OpenClPrinter,self)._print_Symbol(expr)
+    def dump_rational(self, expr):
+        return self.typegen.dump(expr)
+    def dump_float(self, expr):
+        return self.typegen.dump(expr)
+    
+
+    def _print_Symbol(self, expr):
+        return self.dump_symbol(expr)
+    def _print_Rational(self, expr):
+        return self.dump_rational(expr)
+    def _print_PythonRational(self, expr):
+        return self.dump_rational(expr)
+    def _print_Fraction(self, expr):
+        return self.dump_rational(expr)
+    def _print_mpq(self, expr):
+        return self.dump_rational(expr)
+    def _print_Float(self, expr):
+        return self.dump_float(expr)
+
+    # last resort printer (if _print_CLASS is not found)
+    def emptyPrinter(self,expr):
+        return self.typegen.dump(expr)
+
+def dump_clcode(expr, typegen, **kargs):
+    """Return OpenCL representation of the given expression as a string."""
+    p = OpenClPrinter(typegen=typegen, **kargs)
+    s = p.doprint(expr)
+    return s
+
+def print_clcode(expr, typegen, **kargs):
+    """Prints OpenCL representation of the given expression."""
+    print dump_clcode(expr,typegen=typegen,**kargs)
diff --git a/hysop/backend/opencl/opencl_types.py b/hysop/backend/opencl/opencl_types.py
index 93cf3b87b..04e992842 100644
--- a/hysop/backend/opencl/opencl_types.py
+++ b/hysop/backend/opencl/opencl_types.py
@@ -2,7 +2,7 @@
 import string
 
 from hysop import __KERNEL_DEBUG__
-from hysop.constants import np, it
+from hysop.deps import sm, np, it
 from hysop.backend.opencl import cl, clArray
 from hysop.tools.numerics import MPZ, MPQ, MPFR, F2Q
 
@@ -234,12 +234,21 @@ class TypeGen(object):
             return sval
         elif isinstance(val, (bool,np.bool_)):
             return 'true' if val else 'false'
-        elif isinstance(val, MPQ):
-            if __KERNEL_DEBUG__:
-                return '({}.0{f}/{}.0{f})'.format(val.numerator,val.denominator,
-                                                  f=FLT_LITERAL[self.fbtype])
-            else:
+        elif isinstance(val, (MPQ, sm.Rational)):
+            if not __KERNEL_DEBUG__:
                 return self.dump(float(val))
+            if isinstance(val, MPQ):
+                if val.denominator==1:
+                    return str(val.numerator)
+                else:
+                    return '({}.0{f}/{}.0{f})'.format(val.numerator,val.denominator,
+                                                  f=FLT_LITERAL[self.fbtype])
+            elif isinstance(val, sm.Rational):
+                if val.q == 1:
+                    return str(val.p)
+                else:
+                    return '({}.0{f}/{}.0{f})'.format(val.p,val.q,
+                                                      f=FLT_LITERAL[self.fbtype])
         else:
             return val.__str__()
 
@@ -313,6 +322,7 @@ class OpenClTypeGen(TypeGen):
             # self.make_floatn = make_halfn
         else:
             raise ValueError('Unknown fbtype \'{}\''.format(fbtype))
+        
 
     def device_has_ftype(self,device):
         dev_exts = device.extensions.split(' ')
@@ -326,6 +336,17 @@ class OpenClTypeGen(TypeGen):
         btype = basetype(stype) 
         N     = components(stype)
         return typen(btype,N)
+    
+    def dump_expr(self, expr, symbol2vars=None, **printer_settings):
+        """
+        Print sympy expression expr as OpenCL code.
+        Sympy symbols may be replaced using symbol2vars dictionnary.
+        This dumper uses OpenClTypeGen.dump for floats and quotients.
+        See hysop.backend.opencl.opencl_printer.OpenClPrinter
+        """
+        from hysop.backend.opencl.opencl_printer import OpenClPrinter
+        printer = OpenClPrinter(typegen=self,symbol2vars=symbol2vars,**printer_settings)
+        return printer.doprint(expr)
 
     def __repr__(self):
         return '{}_{}_{}_{}'.format(self.platform.name,self.device.name, 
diff --git a/hysop/numerics/stencil/stencil.py b/hysop/numerics/stencil/stencil.py
index 95af6627d..11ea9b256 100644
--- a/hysop/numerics/stencil/stencil.py
+++ b/hysop/numerics/stencil/stencil.py
@@ -9,7 +9,7 @@
 """
 
 from hysop.deps import sm, sp, it, np, hashlib
-from hysop.tools.sympy_utils import recurse_expression_tree, expr2str
+from hysop.tools.sympy_utils import recurse_expression_tree
 
 class Stencil(object):
     """
@@ -91,9 +91,6 @@ class Stencil(object):
 
         self._update_attributes()
 
-    def format_factor(self, svars):
-        return expr2str(self.factor,svars)
-
     def has_factor(self):
         return (self.factor!=1)
     
@@ -201,7 +198,7 @@ class Stencil(object):
             raise RuntimeError('Stencil is not 2d !')
         return sp.sparse.coo_matrix(self.coeffs,shape=self.shape,dtype=self.dtype)
 
-    def iteritems(self,svars={},include_factor=True):
+    def iteritems(self,include_factor=True):
         """
         Return an (offset,coefficient) iterator iterating on all **non zero** coefficients.
         Offset is taken from origin.
@@ -212,10 +209,9 @@ class Stencil(object):
             Zipped offset and coefficient iterator.
         """
         factor = self.factor if include_factor else 1
-        svars = dict(zip(svars.keys(),[str(v) for v in svars.values()]))
         def mapfun(x):
             offset = x-self.origin
-            value = expr2str(factor*self.coeffs[x],svars)
+            value = factor*self.coeffs[x]
             return (offset,value)
         iterator = np.ndindex(self.shape)
         iterator = it.imap(mapfun, iterator)
diff --git a/hysop/operator/hdf_io.py b/hysop/operator/hdf_io.py
index d388f8f5c..5da72f92d 100755
--- a/hysop/operator/hdf_io.py
+++ b/hysop/operator/hdf_io.py
@@ -108,6 +108,8 @@ class HDF_IO(ComputationalGraphOperator):
         super(HDF_IO, self).initialize()
     
     def discretize(self):
+        if not self.initialized:
+            self.initialize()
         super(HDF_IO, self).discretize()
         self.topology = self.variables.values()[0]
 
diff --git a/hysop/tools/sympy_utils.py b/hysop/tools/sympy_utils.py
index f005e9a52..02b57d8c2 100644
--- a/hysop/tools/sympy_utils.py
+++ b/hysop/tools/sympy_utils.py
@@ -180,13 +180,3 @@ def recurse_expression_tree(op, expr):
         for arg in expr.args:
             recurse_expression_tree(op, arg)
 
-def expr2str(expr, svars, dumper=str):
-    svars = dict(zip(svars.keys(), [dumper(v) for v in svars.values()]))
-    expr = copy.deepcopy(expr)
-    def op(expr):
-        if isinstance(expr,sm.Symbol) and (expr in svars):
-            expr.name = svars[expr]
-        print expr.__class__
-    recurse_expression_tree(op,expr)
-    expr = remove_pows(expr)
-    return str(expr)
diff --git a/setup.py.in b/setup.py.in
index 8d9b9f87a..fc161eabb 100644
--- a/setup.py.in
+++ b/setup.py.in
@@ -248,9 +248,9 @@ else:
                              where="@CMAKE_SOURCE_DIR@")
 
 if "@WITH_GPU@" is "ON":
-    packages.append('hysop.opencl')
+    packages.append('hysop.backend.opencl')
     if with_test:
-        packages.append('hysop.opencl.tests')
+        packages.append('hysop.backend.opencl.tests')
 
 # Enable this to get debug info
 DISTUTILS_DEBUG = 1
-- 
GitLab