From b49037faf8c5862d4a9c1d0522914f580a16eb28 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr> Date: Wed, 12 Dec 2018 00:15:26 +0100 Subject: [PATCH] fixed symbolic expr --- .../device/opencl/operator/external_force.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/hysop/backend/device/opencl/operator/external_force.py b/hysop/backend/device/opencl/operator/external_force.py index b76699dde..fd482b0a6 100644 --- a/hysop/backend/device/opencl/operator/external_force.py +++ b/hysop/backend/device/opencl/operator/external_force.py @@ -86,12 +86,10 @@ class SymbolicExternalForce(ExternalForce): backward_transforms = {} for Si in fft_fields: Fi = tg.require_forward_transform(Si) + forward_transforms[Si] = Fi if (Si in self.diffusion): Bi = tg.require_backward_transform(Si) - else: - Bi = None - forward_transforms[Si] = Fi - backward_transforms[Si] = Bi + backward_transforms[Si] = Bi force_backward_transforms = {} for Fi in op.force.fields: @@ -145,12 +143,11 @@ class SymbolicExternalForce(ExternalForce): force_kernels = () vorticity_kernels = () - names = ('dWx', 'dWy', 'dWz') - assert len(op.vorticity.fields)==len(op.force.fields)==len(self.fft_expressions)<=len(names) - for (Fi,Wi,name,e) in zip( + assert len(op.vorticity.fields)==len(op.force.fields)==len(self.fft_expressions) + for (Fi,Wi,e) in zip( op.force.fields, op.vorticity.fields, - names, self.fft_expressions): + self.fft_expressions): if (e==0): force_kernels += (None,) vorticity_kernels += (None,) @@ -158,7 +155,7 @@ class SymbolicExternalForce(ExternalForce): Fis = Fi.s() Fi_hat = self.force_backward_transforms[Fi] - Fi_buf = Fi_hat.input_symbolic_array('{}_hat'.format(name)) + Fi_buf = Fi_hat.input_symbolic_array('{}_hat'.format(Fi.name)) Wn = self.tg.push_expressions(Assignment(Fi_hat, e)) msg='Could not extract transforms.' @@ -168,8 +165,10 @@ class SymbolicExternalForce(ExternalForce): raise RuntimeError(msg) assert len(transforms)>=1, msg - fft_buffers = {Ft: Ft.output_symbolic_array('{}_hat'.format(Ft.field.name))} - wavenumbers = {Wi: self.tg._indexed_wave_numbers[Wi] for Wi in Wn} + fft_buffers = { Ft: Ft.output_symbolic_array('{}_hat'.format(Ft.field.name)) + for Ft in self.forward_transforms.values() } + wavenumbers = { Wi: self.tg._indexed_wave_numbers[Wi] + for Wi in Wn } replace = {} replace.update(fft_buffers) @@ -177,15 +176,17 @@ class SymbolicExternalForce(ExternalForce): expr = e.xreplace(replace) expr = Assignment(Fi_buf, expr) - kname = 'compute_{}'.format(name) + kname = 'compute_{}'.format(Fi.var_name) op.require_symbolic_kernel(kname, expr) force_kernels += (kname,) + print expr Wis = Wi.s() expr = Assignment(Wis, Wis + dts*Fis) kname = 'update_{}'.format(Wi.var_name) op.require_symbolic_kernel(kname, expr) vorticity_kernels += (kname,) + print expr assert len(diffusion_kernels) == len(self.diffusion) assert len(force_kernels) == op.vorticity.nb_components == len(vorticity_kernels) -- GitLab