From 10fe1631b6a44bf52dcd9f03e74e12d38e58ce16 Mon Sep 17 00:00:00 2001
From: Jean-Matthieu Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Tue, 15 Mar 2022 18:42:09 +0100
Subject: [PATCH] fix opencl custom

---
 .../backend/device/opencl/operator/custom.py  | 27 ++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/hysop/backend/device/opencl/operator/custom.py b/hysop/backend/device/opencl/operator/custom.py
index cde464bb8..094bcaa92 100644
--- a/hysop/backend/device/opencl/operator/custom.py
+++ b/hysop/backend/device/opencl/operator/custom.py
@@ -6,12 +6,14 @@ from hysop.backend.device.opencl.autotunable_kernels.custom_symbolic import Open
 from hysop.backend.device.opencl.opencl_kernel_launcher import OpenClKernelListLauncher
 from hysop.backend.device.opencl.opencl_copy_kernel_launchers import OpenClCopyBufferRectLauncher
 from hysop.operator.base.custom import CustomOperatorBase
+from hysop.parameters.scalar_parameter import ScalarParameter
 
 from hysop.backend.device.codegen.structs.mesh_info import MeshBaseStruct, MeshInfoStruct
 from hysop.backend.device.codegen.base.variables import dtype_to_ctype
 from hysop.backend.device.codegen.base.opencl_codegen import OpenClCodeGenerator
 
 from pyopencl.elementwise import ElementwiseKernel
+import pyopencl as cl
 
 
 class OpenClCustomOperator(CustomOperatorBase, OpenClOperator):
@@ -34,7 +36,6 @@ class OpenClCustomOperator(CustomOperatorBase, OpenClOperator):
         cg.require('mis', MeshInfoStruct(self.typegen,
                                          typedef='{}MeshInfo{}D_s'.format(self.typegen.fbtype[0], dim),
                                          mbs_typedef=mbs.typedef, vsize=dim))
-        kernel_args = []
         for f in self.discrete_fields:
             fn = f.continuous_fields()[0].name
             mesh_info = MeshInfoStruct.create_from_mesh(fn+"_mesh", self.typegen, f.mesh,
@@ -44,7 +45,22 @@ class OpenClCustomOperator(CustomOperatorBase, OpenClOperator):
               int iz = i/({fn}_mesh.local_mesh.resolution.x*{fn}_mesh.local_mesh.resolution.y);
               int iy = (i-({fn}_mesh.local_mesh.resolution.x*{fn}_mesh.local_mesh.resolution.y)*iz)/({fn}_mesh.local_mesh.resolution.x);
               return (int3)(i % {fn}_mesh.local_mesh.resolution.x,iy,iz);}}""")
-            kernel_args.append(f"{self.typegen.fbtype} *{fn}")
+        kernel_args = []
+        for f in self.dinvar:
+            if f not in self.doutvar:
+                fn = f.continuous_fields()[0].name
+                kernel_args.append(f"const {self.typegen.fbtype} * {fn}")
+        for p in self.dinparam:
+            if isinstance(p, ScalarParameter):
+                kernel_args.append(f"const {self.typegen.fbtype} {p.name}")
+                self._cl_dinparam.append(p)
+            else:
+                kernel_args.append(f"const {self.typegen.fbtype} *{p.name}")
+        for f in self.doutvar:
+            fn = f.continuous_fields()[0].name
+            kernel_args.append(f"{self.typegen.fbtype} * {fn}")
+        for p in self.doutparam:
+            kernel_args.append(f"{self.typegen.fbtype} *{p.name}")
         self.__elementwise = ElementwiseKernel(
             self.cl_env.context,
             ",".join(kernel_args),
@@ -54,6 +70,11 @@ class OpenClCustomOperator(CustomOperatorBase, OpenClOperator):
     @op_apply
     def apply(self, **kwds):
         super().apply(**kwds)
-        self.__elementwise(*(_.sbuffer for _ in self.dinvar + self.dinparam + self.doutvar + self.doutparam + self.extra_args))
+        args = (tuple(_.sbuffer for _ in self.dinvar)
+                + tuple(cl.array.to_device(self.cl_env.default_queue, _._value) for _ in self.dinparam)
+                + tuple(_.sbuffer for _ in self.doutvar)
+                + tuple(cl.array.to_device(self.cl_env.default_queue, _._value) for _ in self.doutparam)
+                + tuple(cl.array.to_device(self.cl_env.default_queue, _._value) for _ in self.extra_args))
+        self.__elementwise(*args)
         for gh_exch in self.ghost_exchanger:
             gh_exch.exchange_ghosts()
-- 
GitLab