From 10fe1631b6a44bf52dcd9f03e74e12d38e58ce16 Mon Sep 17 00:00:00 2001 From: Jean-Matthieu Etancelin <jean-matthieu.etancelin@univ-pau.fr> Date: Tue, 15 Mar 2022 18:42:09 +0100 Subject: [PATCH] fix opencl custom --- .../backend/device/opencl/operator/custom.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/hysop/backend/device/opencl/operator/custom.py b/hysop/backend/device/opencl/operator/custom.py index cde464bb8..094bcaa92 100644 --- a/hysop/backend/device/opencl/operator/custom.py +++ b/hysop/backend/device/opencl/operator/custom.py @@ -6,12 +6,14 @@ from hysop.backend.device.opencl.autotunable_kernels.custom_symbolic import Open from hysop.backend.device.opencl.opencl_kernel_launcher import OpenClKernelListLauncher from hysop.backend.device.opencl.opencl_copy_kernel_launchers import OpenClCopyBufferRectLauncher from hysop.operator.base.custom import CustomOperatorBase +from hysop.parameters.scalar_parameter import ScalarParameter from hysop.backend.device.codegen.structs.mesh_info import MeshBaseStruct, MeshInfoStruct from hysop.backend.device.codegen.base.variables import dtype_to_ctype from hysop.backend.device.codegen.base.opencl_codegen import OpenClCodeGenerator from pyopencl.elementwise import ElementwiseKernel +import pyopencl as cl class OpenClCustomOperator(CustomOperatorBase, OpenClOperator): @@ -34,7 +36,6 @@ class OpenClCustomOperator(CustomOperatorBase, OpenClOperator): cg.require('mis', MeshInfoStruct(self.typegen, typedef='{}MeshInfo{}D_s'.format(self.typegen.fbtype[0], dim), mbs_typedef=mbs.typedef, vsize=dim)) - kernel_args = [] for f in self.discrete_fields: fn = f.continuous_fields()[0].name mesh_info = MeshInfoStruct.create_from_mesh(fn+"_mesh", self.typegen, f.mesh, @@ -44,7 +45,22 @@ class OpenClCustomOperator(CustomOperatorBase, OpenClOperator): int iz = i/({fn}_mesh.local_mesh.resolution.x*{fn}_mesh.local_mesh.resolution.y); int iy = (i-({fn}_mesh.local_mesh.resolution.x*{fn}_mesh.local_mesh.resolution.y)*iz)/({fn}_mesh.local_mesh.resolution.x); return (int3)(i % {fn}_mesh.local_mesh.resolution.x,iy,iz);}}""") - kernel_args.append(f"{self.typegen.fbtype} *{fn}") + kernel_args = [] + for f in self.dinvar: + if f not in self.doutvar: + fn = f.continuous_fields()[0].name + kernel_args.append(f"const {self.typegen.fbtype} * {fn}") + for p in self.dinparam: + if isinstance(p, ScalarParameter): + kernel_args.append(f"const {self.typegen.fbtype} {p.name}") + self._cl_dinparam.append(p) + else: + kernel_args.append(f"const {self.typegen.fbtype} *{p.name}") + for f in self.doutvar: + fn = f.continuous_fields()[0].name + kernel_args.append(f"{self.typegen.fbtype} * {fn}") + for p in self.doutparam: + kernel_args.append(f"{self.typegen.fbtype} *{p.name}") self.__elementwise = ElementwiseKernel( self.cl_env.context, ",".join(kernel_args), @@ -54,6 +70,11 @@ class OpenClCustomOperator(CustomOperatorBase, OpenClOperator): @op_apply def apply(self, **kwds): super().apply(**kwds) - self.__elementwise(*(_.sbuffer for _ in self.dinvar + self.dinparam + self.doutvar + self.doutparam + self.extra_args)) + args = (tuple(_.sbuffer for _ in self.dinvar) + + tuple(cl.array.to_device(self.cl_env.default_queue, _._value) for _ in self.dinparam) + + tuple(_.sbuffer for _ in self.doutvar) + + tuple(cl.array.to_device(self.cl_env.default_queue, _._value) for _ in self.doutparam) + + tuple(cl.array.to_device(self.cl_env.default_queue, _._value) for _ in self.extra_args)) + self.__elementwise(*args) for gh_exch in self.ghost_exchanger: gh_exch.exchange_ghosts() -- GitLab