From 723fcc805ad3a98a9fe6610e2b7e72e5f1b535aa Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr>
Date: Wed, 12 Dec 2018 16:56:56 +0100
Subject: [PATCH] external force

---
 examples/sediment_deposit/sediment_deposit.py | 38 ++++-----
 .../device/opencl/opencl_array_backend.py     |  2 +-
 .../device/opencl/operator/external_force.py  | 83 +++++++++++++------
 hysop/core/graph/computational_operator.py    | 19 ++++-
 hysop/fields/cartesian_discrete_field.py      | 13 ++-
 hysop/fields/continuous_field.py              |  9 +-
 hysop/operator/adapt_timestep.py              |  2 +-
 hysop/operator/base/external_force.py         | 15 ++--
 8 files changed, 122 insertions(+), 59 deletions(-)

diff --git a/examples/sediment_deposit/sediment_deposit.py b/examples/sediment_deposit/sediment_deposit.py
index 364b3e56e..6284d4ba3 100644
--- a/examples/sediment_deposit/sediment_deposit.py
+++ b/examples/sediment_deposit/sediment_deposit.py
@@ -3,10 +3,10 @@ import scipy as sp
 import sympy as sm
 import numba as nb
 
-TANK_RATIO      = 10
-SEDIMENT_COUNT  = 250*TANK_RATIO
-SEDIMENT_RADIUS = 1e-2
-DISCRETIZATION  = 128
+TANK_RATIO      = 3
+SEDIMENT_COUNT  = 2048*TANK_RATIO
+SEDIMENT_RADIUS = 0.5e-2
+DISCRETIZATION  = 512
 
 # initialize vorticity
 def init_vorticity(data, coords, component=None):
@@ -60,7 +60,7 @@ def init_sediment(data, coords, nblobs, rblob):
         assert (rblob>=dy), 'Sediment radius < dy.'
         
         Bx = 1*np.random.rand(nblobs)
-        By = TANK_RATIO*np.random.rand(nblobs)
+        By = 1*np.random.rand(nblobs)
         Ix = np.floor(Bx/dx).astype(np.int32)
         Iy = np.floor(By/dy).astype(np.int32)
         Px = Bx - Ix*dx
@@ -96,7 +96,6 @@ def init_sediment(data, coords, nblobs, rblob):
         vprint('  *Initializing sediments of radius {} with {} random blobs.'.format(rblob, nblobs))
         data[...]  = 0.0
         iter_blobs(*args)
-        #data[:int(0.05*Ny),:] = 1.0
 
         # we cache initialization
         np.savez_compressed(file=cache_file, data=data)
@@ -149,7 +148,7 @@ def compute(args):
         raise NotImplementedError(msg)
 
     nu_S = ScalarParameter(name='nu_S', dtype=args.dtype, const=True, initial_value=1e-10)
-    nu_W = ScalarParameter(name='nu_W', dtype=args.dtype, const=True, initial_value=1.00)
+    nu_W = ScalarParameter(name='nu_W', dtype=args.dtype, const=True, initial_value=1e-2)
 
     lboundaries = (BoxBoundaryCondition.SYMMETRIC, BoxBoundaryCondition.SYMMETRIC)
     rboundaries = (BoxBoundaryCondition.SYMMETRIC, BoxBoundaryCondition.SYMMETRIC)
@@ -194,8 +193,8 @@ def compute(args):
     t, dt = TimeParameters(dtype=args.dtype)
     velo  = VelocityField(domain=box, dtype=args.dtype)
     vorti = VorticityField(velocity=velo)
-    S = Field(domain=box, name='S', dtype=args.dtype,
-            lboundaries=S_lboundaries, rboundaries=S_rboundaries)
+    S = Field(domain=box, name='S', dtype=args.dtype)
+            #lboundaries=S_lboundaries, rboundaries=S_rboundaries)
     
     # Symbolic fields
     frame = velo.domain.frame
@@ -246,7 +245,7 @@ def compute(args):
                             **extra_op_kwds)
 
     #> External force rot(-rho*g) = rot(-(1+S)) = rot(-S)
-    g = 1.0
+    g = 9.81
     Fext = SymbolicExternalForce(name='S', Fext=(0,-g*Ss),
                                    diffusion = {S: nu_S})
     external_force = SpectralExternalForce(name='Fext', 
@@ -286,8 +285,7 @@ def compute(args):
 
     ### Adaptive timestep operator
     adapt_dt = AdaptiveTimeStep(dt, equivalent_CFL=True,
-                                    name='merge_dt', pretty_name='dt',
-                                    max_dt=1e0)
+                                    name='merge_dt', pretty_name='dt')
     dt_cfl = adapt_dt.push_cfl_criteria(cfl=args.cfl, 
                                         Finf=min_max_U.Finf,
                                         equivalent_CFL=True, 
@@ -295,7 +293,7 @@ def compute(args):
     dt_advec = adapt_dt.push_advection_criteria(lcfl=args.lcfl, Finf=min_max_W.Finf,
                                                  criteria=AdvectionCriteria.W_INF,
                                                  name='dt_lcfl', pretty_name='LCFL')
-    dt_force = adapt_dt.push_cst_criteria(cst=1,
+    dt_force = adapt_dt.push_cst_criteria(cst=10000,
                                         Finf=external_force.Finf,
                                         name='dt_force', pretty_name='FEXT')
 
@@ -316,11 +314,11 @@ def compute(args):
 
     problem = Problem(method=method)
     problem.insert(poisson, 
+                   dump_fields,
                    min_max_U, min_max_W, adapt_dt,
                    splitting, 
-                   external_force,
-                   dump_fields,
-                   compute_mean_fields)
+                   compute_mean_fields,
+                   external_force)
     problem.build(args)
     
     # If a visu_rank was provided, and show_graph was set,
@@ -384,10 +382,10 @@ if __name__=='__main__':
     parser.set_defaults(impl='cl', ndim=2, 
                         npts=(TANK_RATIO*DISCRETIZATION+1,DISCRETIZATION+1),
                         box_origin=(0.0,), box_length=(1.0,), 
-                        tstart=0.0, tend=100000.0, 
-                        dt=1.0, cfl=8.0, lcfl=0.99,
-                        dump_times=tuple(float(x) for x in range(0,100000,1000)),
-                        dump_freq=0)
+                        tstart=0.0, tend=20.0,
+                        dt=1e-6, cfl=32.0, lcfl=0.90,
+                        #dump_times=tuple(float(x) for x in range(0,100000,1000)),
+                        dump_freq=10)
 
     parser.run(compute)
 
diff --git a/hysop/backend/device/opencl/opencl_array_backend.py b/hysop/backend/device/opencl/opencl_array_backend.py
index 28fc56823..b9f70a2ff 100644
--- a/hysop/backend/device/opencl/opencl_array_backend.py
+++ b/hysop/backend/device/opencl/opencl_array_backend.py
@@ -3163,7 +3163,7 @@ class OpenClArrayBackend(ArrayBackend):
         """
         Compute the arithmetic mean along the specified axis.
         """
-        return a.sum(a=a, axis=axis, dtype=dtype, out=out, queue=queue) / float(a.size)
+        return a.sum(axis=axis, dtype=dtype, out=out, queue=queue) / float(a.size)
 
     def std(self, a, axis=None, dtype=None, out=None, ddof=0, queue=None):
         """
diff --git a/hysop/backend/device/opencl/operator/external_force.py b/hysop/backend/device/opencl/operator/external_force.py
index fd482b0a6..7fe49b285 100644
--- a/hysop/backend/device/opencl/operator/external_force.py
+++ b/hysop/backend/device/opencl/operator/external_force.py
@@ -63,6 +63,7 @@ class SymbolicExternalForce(ExternalForce):
         super(SymbolicExternalForce, self).__init__(name=name, dim=dim, Fext=Fext, **kwds)
         
         diffusion = first_not_None(diffusion, {})
+        diffusion = {k:v for (k,v) in diffusion.iteritems() if (v is not None)}
         for (k,v) in diffusion.iteritems():
             assert k in self.input_fields(), k.short_description()
             assert isinstance(v, ScalarParameter)
@@ -153,7 +154,6 @@ class SymbolicExternalForce(ExternalForce):
                 vorticity_kernels += (None,)
                 continue
             
-            Fis    = Fi.s()
             Fi_hat = self.force_backward_transforms[Fi]
             Fi_buf = Fi_hat.input_symbolic_array('{}_hat'.format(Fi.name))
             Wn     = self.tg.push_expressions(Assignment(Fi_hat, e))
@@ -165,7 +165,7 @@ class SymbolicExternalForce(ExternalForce):
                 raise RuntimeError(msg)
             assert len(transforms)>=1, msg
 
-            fft_buffers = { Ft: Ft.output_symbolic_array('{}_hat'.format(Ft.field.name)) 
+            fft_buffers = { Ft.s: Ft.output_symbolic_array('{}_hat'.format(Ft.field.name)) 
                                 for Ft in self.forward_transforms.values() }
             wavenumbers = { Wi: self.tg._indexed_wave_numbers[Wi] 
                                 for Wi in Wn }
@@ -179,14 +179,13 @@ class SymbolicExternalForce(ExternalForce):
             kname = 'compute_{}'.format(Fi.var_name)
             op.require_symbolic_kernel(kname, expr)
             force_kernels += (kname,)
-            print expr
 
+            Fis = Fi.s()
             Wis = Wi.s()
             expr = Assignment(Wis, Wis + dts*Fis)
             kname = 'update_{}'.format(Wi.var_name)
             op.require_symbolic_kernel(kname, expr)
             vorticity_kernels += (kname,)
-            print expr
 
         assert len(diffusion_kernels) == len(self.diffusion)
         assert len(force_kernels) == op.vorticity.nb_components == len(vorticity_kernels)
@@ -199,8 +198,9 @@ class SymbolicExternalForce(ExternalForce):
     
     def get_mem_requests(self, op):
         requests = {}
-        for (Ft, Bt) in zip(self.forward_transforms.values(), 
-                            self.backward_transforms.values()):
+        for Fi in self.forward_transforms.keys(): 
+            Ft = self.forward_transforms[Fi]
+            Bt = self.backward_transforms.get(Fi, None)
             if (Bt is not None):
                 assert (Ft.backend is Bt.backend)
                 assert (Ft.output_dtype == Bt.input_dtype), (Ft.output_dtype, Bt.input_dtype)
@@ -215,18 +215,20 @@ class SymbolicExternalForce(ExternalForce):
         return requests
     
     def pre_setup(self, op, work):
-        for (Ft, Bt) in zip(self.forward_transforms.values(), 
-                            self.backward_transforms.values()):
+        for Fi in self.forward_transforms.keys(): 
+            Ft = self.forward_transforms[Fi]
+            Bt = self.backward_transforms.get(Fi, None)
             dtmp, = work.get_buffer(op, '{}_hat'.format(Ft.field.name))
             Ft.configure_output_buffer(dtmp)
             if (Bt is not None):
                 Bt.configure_input_buffer(dtmp)
 
     def post_setup(self, op, work):
-        diffusion_kernels = {}
-        force_kernels     = {}
-        vorticity_kernels = {}
-        ghost_exchangers  = {}
+        diffusion_kernels  = {}
+        force_kernels      = {}
+        compute_statistics = {}
+        vorticity_kernels  = {}
+        ghost_exchangers   = {}
 
         queue = self.tg.backend.cl_env.default_queue
         def build_launcher(knl, update_params):
@@ -236,31 +238,61 @@ class SymbolicExternalForce(ExternalForce):
             return kernel_launcher
         
         for (field, kname) in self.diffusion_kernel_names.iteritems():
-            dfield = op.get_discrete_field(field)
+            dfield = op.get_input_discrete_field(field)
             knl, update_params = op.symbolic_kernels[kname]
             diffusion_kernels[field] = build_launcher(knl, update_params)
-            ghost_exchangers[field] = dfield.build_ghost_exchanger(queue=queue)
+            ghost_exchangers[field] = functools.partial(dfield.build_ghost_exchanger(),
+                                                                            queue=queue)
         
+        if (op.Fmin is not None):
+            min_values = npw.asarray(op.Fmin()).copy()
+        if (op.Fmax is not None):
+            max_values = npw.asarray(op.Fmax()).copy()
+
         for i, (kname0, kname1) in enumerate(zip(
             self.force_kernel_names, self.vorticity_kernel_names)):
             if (kname0 is None):
                 assert (kname1 is None)
                 continue
-            Wi  = op.vorticity[i]
-            Fi  = op.force[i]
-            dWi = op.dW[i]
+            Wi  = op.vorticity.fields[i]
+            Fi  = op.force.fields[i]
+            dWi = op.dW.dfields[i]
+            dFi = op.dF.dfields[i]
             
             knl, update_params = op.symbolic_kernels[kname0]
             force_kernels[(Fi,Wi)]  = build_launcher(knl, update_params)
             
-            knl, update_params    = op.symbolic_kernels[kname1]
+            knl, update_params = op.symbolic_kernels[kname1]
             vorticity_kernels[(Fi,Wi)] = build_launcher(knl, update_params)
 
-            ghost_exchangers[Wi] = dWi.build_ghost_exchanger(queue=queue)
+            ghost_exchangers[Wi] = functools.partial(dWi.build_ghost_exchanger(), queue=queue)
+            
+            def compute_statistic(op=op, queue=queue, dFi=dFi, 
+                                min_values=min_values, max_values=max_values):
+                if (op.Fmin is not None):
+                    min_values[i] = dFi.sdata.min(queue=queue).get()
+                if (op.Fmax is not None):
+                    max_values[i] = dFi.sdata.max(queue=queue).get()
+            compute_statistics[Fi] = compute_statistic
+        
+        def update_statistics(op=op, min_values=min_values, max_values=max_values):
+            if (op.Fmin is not None):
+                op.Fmin.value = min_values
+            if (op.Fmax is not None):
+                op.Fmax.value = max_values
+            if (op.Finf is not None):
+                op.Finf.value = npw.maximum(npw.abs(min_values), npw.abs(max_values))
+                
+        assert len(diffusion_kernels) == len(self.diffusion) == len(self.backward_transforms)
+        assert len(vorticity_kernels) == len(force_kernels) == len(self.force_backward_transforms)
+        assert len(ghost_exchangers) == len(diffusion_kernels) + len(vorticity_kernels)
 
-        def compute_statistics():
-            pass
+        self.diffusion_kernels  = diffusion_kernels
+        self.force_kernels      = force_kernels
+        self.vorticity_kernels  = vorticity_kernels
+        self.ghost_exchangers   = ghost_exchangers
         self.compute_statistics = compute_statistics
+        self.update_statistics  = update_statistics
 
     def apply(self, op, **kwds):
         for (field, Ft) in self.forward_transforms.iteritems():
@@ -271,13 +303,14 @@ class SymbolicExternalForce(ExternalForce):
                 evt = self.ghost_exchangers[field]()
         
         for (Fi,Wi) in self.force_kernels.keys():
-            evt = self.force_kernel[Wi]()
+            evt = self.force_kernels[(Fi,Wi)]()
             evt = self.force_backward_transforms[Fi]()
-            if op.compute_statistics:
-                evt = self.compute_statistics()
-            evt = self.vorticity_kernels[Wi]()
+            evt = self.compute_statistics[Fi]()
+            evt = self.vorticity_kernels[(Fi,Wi)]()
             evt = self.ghost_exchangers[Wi]()
 
+        self.update_statistics()
+
     def _extract_objects(self, obj_type):
         objs = set()
         for e in self.Fext:
diff --git a/hysop/core/graph/computational_operator.py b/hysop/core/graph/computational_operator.py
index 3adf180d7..509ddca9e 100644
--- a/hysop/core/graph/computational_operator.py
+++ b/hysop/core/graph/computational_operator.py
@@ -497,10 +497,18 @@ class ComputationalGraphOperator(ComputationalGraphNode):
         computed and returned.
         """
         requests = OperatorMemoryRequests(self)
+        delayed_requests = {}
         for dfield in self.discrete_fields:
             if dfield.is_tmp:
-                req_id = 'tmp_{}_{}'.format(dfield.name, dfield.tag)
-                requests.push_mem_request(req_id, dfield.dfield.memory_request)
+                if (dfield.mem_tag is not None):
+                    req_id = dfield.mem_tag
+                    try:
+                        requests.push_mem_request(req_id, dfield.dfield.memory_request)
+                    except ValueError:
+                        pass
+                else:
+                    req_id = 'tmp_{}_{}'.format(dfield.name, dfield.tag)
+                    requests.push_mem_request(req_id, dfield.dfield.memory_request)
         return requests
 
     @debug
@@ -520,8 +528,11 @@ class ComputationalGraphOperator(ComputationalGraphNode):
 
     def allocate_tmp_fields(self, work):
         for dfield in self.discrete_fields:
-            if dfield.is_tmp:
-                req_id = 'tmp_{}_{}'.format(dfield.name, dfield.tag)
+            if dfield.is_tmp and (dfield._dfield._data is None):
+                if (dfield.mem_tag is not None):
+                    req_id = dfield.mem_tag
+                else:
+                    req_id = 'tmp_{}_{}'.format(dfield.name, dfield.tag)
                 data = work.get_buffer(self, req_id)
                 dfield.dfield.honor_memory_request(data)
     
diff --git a/hysop/fields/cartesian_discrete_field.py b/hysop/fields/cartesian_discrete_field.py
index d0570a694..77c616e81 100644
--- a/hysop/fields/cartesian_discrete_field.py
+++ b/hysop/fields/cartesian_discrete_field.py
@@ -722,6 +722,8 @@ class CartesianDiscreteScalarFieldView(CartesianDiscreteScalarFieldViewContainer
     def _get_is_tmp(self):
         """Is this DiscreteScalarField temporary ?"""
         return self._dfield.is_tmp
+    def _get_mem_tag(self):
+        return self._dfield.mem_tag
     
     def _get_global_lboundaries(self):
         """Return global left boundaries."""
@@ -1209,6 +1211,7 @@ class CartesianDiscreteScalarFieldView(CartesianDiscreteScalarFieldViewContainer
     ghosts             = property(_get_ghosts)
     space_step         = property(_get_space_step)
     is_tmp             = property(_get_is_tmp)
+    mem_tag            = property(_get_mem_tag)
     coords             = property(_get_coords)
     mesh_coords        = property(_get_mesh_coords)
     
@@ -1334,8 +1337,9 @@ class CartesianDiscreteScalarField(CartesianDiscreteScalarFieldView, DiscreteSca
             from hysop.core.memory.memory_request import MemoryRequest
             memory_request = MemoryRequest(backend=obj.backend, 
                     dtype=obj.dtype, shape=obj.resolution)
-            obj._memory_request = memory_request
+            obj._memory_request    = memory_request
             obj._memory_request_id = obj.name
+            obj._mem_tag = field.mem_tag
         return obj
 
     def _handle_data(self, data):
@@ -1368,6 +1372,10 @@ class CartesianDiscreteScalarField(CartesianDiscreteScalarFieldView, DiscreteSca
     @property
     def is_tmp(self):
         return False
+    
+    @property
+    def mem_tag(self):
+        return self._field.mem_tag
 
     def __eq__(self, other):
         return id(self) == id(other)
@@ -1378,8 +1386,9 @@ class CartesianDiscreteScalarField(CartesianDiscreteScalarFieldView, DiscreteSca
 class TmpCartesianDiscreteScalarField(CartesianDiscreteScalarField):
     @debug
     def __new__(cls, **kwds):
-        return super(TmpCartesianDiscreteScalarField, cls).__new__(cls, allocate_data=False,
+        obj = super(TmpCartesianDiscreteScalarField, cls).__new__(cls, allocate_data=False,
                 register_discrete_field=False, **kwds)
+        return obj
 
     @debug
     def __init__(self, **kwds):
diff --git a/hysop/fields/continuous_field.py b/hysop/fields/continuous_field.py
index d60f6776f..8b5eeca9d 100644
--- a/hysop/fields/continuous_field.py
+++ b/hysop/fields/continuous_field.py
@@ -503,7 +503,7 @@ class ScalarField(NamedScalarContainerI, FieldContainerI):
                 var_name=None, latex_name=None,
                 initial_values=None, dtype=HYSOP_REAL,
                 lboundaries=None, rboundaries=None,
-                is_tmp=False, **kwds):
+                is_tmp=False, mem_tag=None, **kwds):
         """
         Create or get an existing continuous ScalarField (scalar or vector) on a specific domain.
 
@@ -564,6 +564,9 @@ class ScalarField(NamedScalarContainerI, FieldContainerI):
         check_instance(var_name, str, allow_none=True)
         check_instance(is_tmp, bool)
 
+        if (mem_tag is not None):
+            assert is_tmp, 'Can only specify mem_tag for temporary fields.'
+
         # Data type of the field
         if (dtype==npw.bool) or (dtype==bool):
             import warnings
@@ -615,6 +618,7 @@ class ScalarField(NamedScalarContainerI, FieldContainerI):
         obj._dtype  = dtype
         obj._initial_values = initial_values
         obj._is_tmp = is_tmp
+        obj._mem_tag = mem_tag
         obj._lboundaries = lboundaries
         obj._rboundaries = rboundaries
         obj._periodicity = periodicity
@@ -794,6 +798,8 @@ class ScalarField(NamedScalarContainerI, FieldContainerI):
     def _get_is_tmp(self):
         """Is this ScalarField a temporary field ?"""
         return self._is_tmp
+    def _get_mem_tag(self):
+        return self._mem_tag
     
     dtype = property(_get_dtype)
     initial_values = property(_get_initial_values)
@@ -803,6 +809,7 @@ class ScalarField(NamedScalarContainerI, FieldContainerI):
     boundaries = property(_get_boundaries)
     periodicity = property(_get_periodicity)
     is_tmp = property(_get_is_tmp)
+    mem_tag = property(_get_mem_tag)
     
     @property
     def is_tensor(self):
diff --git a/hysop/operator/adapt_timestep.py b/hysop/operator/adapt_timestep.py
index 72265f377..fe66e08ec 100755
--- a/hysop/operator/adapt_timestep.py
+++ b/hysop/operator/adapt_timestep.py
@@ -509,7 +509,7 @@ class AdaptiveTimeStep(ComputationalGraphNodeGenerator):
         
         parameter = self._build_parameter(parameter=parameter, quiet=quiet,
                 name=param_name, pretty_name=param_pretty_name, 
-                basename=name)
+                basename=name.replace('dt_', ''))
         criteria = ConstantTimestepCriteria(cst=cst, Finf=Finf,
             parameter=parameter, name=name, pretty_name=pretty_name, **kwds)
         self._push_criteria(parameter.name, criteria)
diff --git a/hysop/operator/base/external_force.py b/hysop/operator/base/external_force.py
index f33c0d4b2..ee94138d9 100644
--- a/hysop/operator/base/external_force.py
+++ b/hysop/operator/base/external_force.py
@@ -131,6 +131,13 @@ class SpectralExternalForceOperatorBase(SpectralOperatorBase):
         check_instance(Fmax, (ScalarParameter,TensorParameter), allow_none=True)
         check_instance(Finf, (ScalarParameter,TensorParameter), allow_none=True)
         check_instance(variables, dict, keys=Field, values=CartesianTopologyDescriptors)
+        
+        if (Fmin is not None):
+            Fmin.value = npw.asarray((1e8,)*vorticity.nb_components, dtype=Fmin.dtype)
+        if (Fmax is not None):
+            Fmax.value = npw.asarray((1e8,)*vorticity.nb_components, dtype=Fmax.dtype)
+        if (Finf is not None):
+            Finf.value = npw.asarray((1e8,)*vorticity.nb_components, dtype=Finf.dtype)
 
         # check fields
         dim           = vorticity.dim
@@ -165,13 +172,10 @@ class SpectralExternalForceOperatorBase(SpectralOperatorBase):
         msg=msg.format(pshape)
         if isinstance(Fmin, TensorParameter):
             assert Fmin.shape==pshape, msg.format(Fmin.shape, 'Fmin')
-            Fmin = Fmin.view((0,)) if (dim==2) else Fmin
         if isinstance(Fmax, TensorParameter):
             assert Fmin.shape==pshape, msg.format(Fmax.shape, 'Fmax')
-            Fmax = Fmax.view((0,)) if (dim==2) else Fmax
         if isinstance(Finf, TensorParameter):
             assert Fmin.shape==pshape, msg.format(Finf.shape, 'Finf')
-            Finf = Finf.view((0,)) if (dim==2) else Finf
 
         compute_statistics  = (Fmin is not None)
         compute_statistics |= (Fmax is not None)
@@ -199,7 +203,7 @@ class SpectralExternalForceOperatorBase(SpectralOperatorBase):
         output_params = {p.name: p for p in output_params}
         
         # TODO share tmp buffers for the whole tensor
-        force = vorticity.tmp_like(name='Fext', ghosts=0)
+        force = vorticity.tmp_like(name='Fext', ghosts=0, mem_tag='tmp_fext')
         for (Fi, Wi) in zip(force.fields, vorticity.fields):
             input_fields[Fi]  = self.get_topo_descriptor(variables, Wi)
             output_fields[Fi] = self.get_topo_descriptor(variables, Wi)
@@ -237,7 +241,8 @@ class SpectralExternalForceOperatorBase(SpectralOperatorBase):
         if self.discretized:
             return
         super(SpectralExternalForceOperatorBase, self).discretize()
-        self.dW     = self.get_input_discrete_field(self.vorticity)
+        self.dW = self.get_input_discrete_field(self.vorticity)
+        self.dF = self.get_input_discrete_field(self.force)
         self.Fext.discretize(self)
 
     def get_work_properties(self):
-- 
GitLab