Skip to content
Snippets Groups Projects
Commit 129e08e5 authored by EXT Jean-Matthieu Etancelin's avatar EXT Jean-Matthieu Etancelin
Browse files

improve custom python operator that handle fields as function arguments...

improve custom python operator that handle fields as function arguments instead of data buffers (get access to topology informations)
parent 495baa39
No related branches found
No related tags found
1 merge request!16MPI operators
...@@ -31,7 +31,6 @@ class PythonCustomOperator(HostOperator): ...@@ -31,7 +31,6 @@ class PythonCustomOperator(HostOperator):
output_fields[v] = variables[v] output_fields[v] = variables[v]
elif isinstance(v, Parameter): elif isinstance(v, Parameter):
output_params[v.name] = v output_params[v.name] = v
self.invars, self.outvars = invars, outvars self.invars, self.outvars = invars, outvars
self.func = func self.func = func
...@@ -55,25 +54,26 @@ class PythonCustomOperator(HostOperator): ...@@ -55,25 +54,26 @@ class PythonCustomOperator(HostOperator):
if self.invars is not None: if self.invars is not None:
for v in self.invars: for v in self.invars:
if isinstance(v, Field): if isinstance(v, Field):
for vd in self.input_discrete_fields[v].buffers: for vd in self.input_discrete_fields[v]:
dinvar.append(vd) dinvar.append(vd)
elif isinstance(v, Parameter): elif isinstance(v, Parameter):
dinparam.append(v) dinparam.append(v)
if self.outvars is not None: if self.outvars is not None:
for v in self.outvars: for v in self.outvars:
if isinstance(v, Field): if isinstance(v, Field):
for vd in self.output_discrete_fields[v].buffers: for vd in self.output_discrete_fields[v]:
doutvar.append(vd) doutvar.append(vd)
self.ghost_exchanger.append( gh = self.output_discrete_fields[v].build_ghost_exchanger()
self.output_discrete_fields[v].build_ghost_exchanger()) if gh is not None:
self.ghost_exchanger.append(gh)
elif isinstance(v, Parameter): elif isinstance(v, Parameter):
doutparam.append(v) doutparam.append(v)
self.dinvar, self.doutvar = dinvar, doutvar self.dinvar, self.doutvar = tuple(dinvar), tuple(doutvar)
self.dinparam, self.doutparam = dinparam, doutparam self.dinparam, self.doutparam = tuple(dinparam), tuple(doutparam)
@op_apply @op_apply
def apply(self, **kwds): def apply(self, **kwds):
super(PythonCustomOperator, self).apply(**kwds) super(PythonCustomOperator, self).apply(**kwds)
self.doutvar = self.func(*(self.dinvar + self.dinparam + self.doutparam)) self.func(*(self.dinvar + self.dinparam + self.doutvar + self.doutparam))
for gh_exch in self.ghost_exchanger: for gh_exch in self.ghost_exchanger:
gh_exch.exchange_ghosts() gh_exch.exchange_ghosts()
...@@ -8,6 +8,14 @@ from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors ...@@ -8,6 +8,14 @@ from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
class CustomOperator(ComputationalGraphNodeFrontend): class CustomOperator(ComputationalGraphNodeFrontend):
"""
Function should take parameters in the following order:
1. all input fields
2. all input parameters
3. all output fields
4. all output parameters
Note that discrete fields are passed as arguments to the custom function.
"""
@classmethod @classmethod
def implementations(cls): def implementations(cls):
...@@ -28,6 +36,25 @@ class CustomOperator(ComputationalGraphNodeFrontend): ...@@ -28,6 +36,25 @@ class CustomOperator(ComputationalGraphNodeFrontend):
allow_none=True) allow_none=True)
check_instance(outvars, (tuple, list), values=(Field, Parameter), check_instance(outvars, (tuple, list), values=(Field, Parameter),
allow_none=True) allow_none=True)
from inspect import getargspec as signature # should be inspect.signature in python 3
nb_args = len(signature(func).args)
nb_in_f, nb_in_p, nb_out_f, nb_out_p = 0, 0, 0, 0
if invars is not None:
for v in invars:
if isinstance(v, Field):
nb_in_f += 1
elif isinstance(v, Parameter):
nb_in_p += 1
if outvars is not None:
for v in outvars:
if isinstance(v, Field):
nb_out_f += 1
elif isinstance(v, Parameter):
nb_out_p += 1
msg = "function arguments ({}) did not match given in/out ".format(signature(func))
msg += "fields and parameters ({} input fields, {} input params,".format(nb_in_f, nb_in_p)
msg += " {} output fields, {} output params).".format(nb_out_f, nb_out_p)
assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p, msg
super(CustomOperator, self).__init__( super(CustomOperator, self).__init__(
func=func, invars=invars, outvars=outvars, **kwds) func=func, invars=invars, outvars=outvars, **kwds)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment