From b49037faf8c5862d4a9c1d0522914f580a16eb28 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr>
Date: Wed, 12 Dec 2018 00:15:26 +0100
Subject: [PATCH] fixed symbolic expr

---
 .../device/opencl/operator/external_force.py  | 25 ++++++++++---------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/hysop/backend/device/opencl/operator/external_force.py b/hysop/backend/device/opencl/operator/external_force.py
index b76699dde..fd482b0a6 100644
--- a/hysop/backend/device/opencl/operator/external_force.py
+++ b/hysop/backend/device/opencl/operator/external_force.py
@@ -86,12 +86,10 @@ class SymbolicExternalForce(ExternalForce):
         backward_transforms = {}
         for Si in fft_fields:
             Fi = tg.require_forward_transform(Si)
+            forward_transforms[Si]  = Fi
             if (Si in self.diffusion):
                 Bi = tg.require_backward_transform(Si)
-            else:
-                Bi = None
-            forward_transforms[Si]  = Fi
-            backward_transforms[Si] = Bi
+                backward_transforms[Si] = Bi
 
         force_backward_transforms = {}
         for Fi in op.force.fields:
@@ -145,12 +143,11 @@ class SymbolicExternalForce(ExternalForce):
         
         force_kernels = ()
         vorticity_kernels = ()
-        names = ('dWx', 'dWy', 'dWz')
-        assert len(op.vorticity.fields)==len(op.force.fields)==len(self.fft_expressions)<=len(names)
-        for (Fi,Wi,name,e) in zip(
+        assert len(op.vorticity.fields)==len(op.force.fields)==len(self.fft_expressions)
+        for (Fi,Wi,e) in zip(
                 op.force.fields,
                 op.vorticity.fields, 
-                names, self.fft_expressions):
+                self.fft_expressions):
             if (e==0):
                 force_kernels     += (None,)
                 vorticity_kernels += (None,)
@@ -158,7 +155,7 @@ class SymbolicExternalForce(ExternalForce):
             
             Fis    = Fi.s()
             Fi_hat = self.force_backward_transforms[Fi]
-            Fi_buf = Fi_hat.input_symbolic_array('{}_hat'.format(name))
+            Fi_buf = Fi_hat.input_symbolic_array('{}_hat'.format(Fi.name))
             Wn     = self.tg.push_expressions(Assignment(Fi_hat, e))
             
             msg='Could not extract transforms.'
@@ -168,8 +165,10 @@ class SymbolicExternalForce(ExternalForce):
                 raise RuntimeError(msg)
             assert len(transforms)>=1, msg
 
-            fft_buffers = {Ft: Ft.output_symbolic_array('{}_hat'.format(Ft.field.name))}
-            wavenumbers = {Wi: self.tg._indexed_wave_numbers[Wi] for Wi in Wn}
+            fft_buffers = { Ft: Ft.output_symbolic_array('{}_hat'.format(Ft.field.name)) 
+                                for Ft in self.forward_transforms.values() }
+            wavenumbers = { Wi: self.tg._indexed_wave_numbers[Wi] 
+                                for Wi in Wn }
 
             replace = {}
             replace.update(fft_buffers)
@@ -177,15 +176,17 @@ class SymbolicExternalForce(ExternalForce):
             expr = e.xreplace(replace)
             expr = Assignment(Fi_buf, expr)
             
-            kname = 'compute_{}'.format(name)
+            kname = 'compute_{}'.format(Fi.var_name)
             op.require_symbolic_kernel(kname, expr)
             force_kernels += (kname,)
+            print expr
 
             Wis = Wi.s()
             expr = Assignment(Wis, Wis + dts*Fis)
             kname = 'update_{}'.format(Wi.var_name)
             op.require_symbolic_kernel(kname, expr)
             vorticity_kernels += (kname,)
+            print expr
 
         assert len(diffusion_kernels) == len(self.diffusion)
         assert len(force_kernels) == op.vorticity.nb_components == len(vorticity_kernels)
-- 
GitLab