Skip to content
Snippets Groups Projects
Commit 12560753 authored by EXT Jean-Matthieu Etancelin's avatar EXT Jean-Matthieu Etancelin
Browse files

Add an extra argument parameter for custom operator

parent 3a978c77
No related branches found
No related tags found
1 merge request!16MPI operators
Pipeline #44971 passed
......@@ -9,11 +9,13 @@ from hysop.core.graph.graph import op_apply
class PythonCustomOperator(HostOperator):
@debug
def __init__(self, func, invars=None, outvars=None, variables=None, ghosts=None, **kwds):
def __init__(self, func, invars=None, outvars=None,
extra_args=None, variables=None, ghosts=None, **kwds):
check_instance(invars, (tuple, list), values=(Field, Parameter),
allow_none=True)
check_instance(outvars, (tuple, list), values=(Field, Parameter),
allow_none=True)
check_instance(extra_args, tuple, allow_none=True)
check_instance(variables, dict, keys=Field,
values=CartesianTopologyDescriptors,
allow_none=True)
......@@ -34,8 +36,10 @@ class PythonCustomOperator(HostOperator):
output_params[v.name] = v
self.invars, self.outvars = invars, outvars
self.func = func
self.extra_args = tuple()
if not extra_args is None:
self.extra_args = extra_args
self._ghosts = ghosts
super(PythonCustomOperator, self).__init__(
input_fields=input_fields, output_fields=output_fields,
input_params=input_params, output_params=output_params,
......@@ -90,6 +94,6 @@ class PythonCustomOperator(HostOperator):
@op_apply
def apply(self, **kwds):
super(PythonCustomOperator, self).apply(**kwds)
self.func(*(self.dinvar + self.dinparam + self.doutvar + self.doutparam))
self.func(*(self.dinvar + self.dinparam + self.doutvar + self.doutparam + self.extra_args))
for gh_exch in self.ghost_exchanger:
gh_exch.exchange_ghosts()
......@@ -31,15 +31,16 @@ class CustomOperator(ComputationalGraphNodeFrontend):
return Implementation.PYTHON
@debug
def __init__(self, func, invars=None, outvars=None, ghosts=None, **kwds):
def __init__(self, func, invars=None, outvars=None, extra_args=None, ghosts=None, **kwds):
check_instance(invars, (tuple, list), values=(Field, Parameter),
allow_none=True)
check_instance(outvars, (tuple, list), values=(Field, Parameter),
allow_none=True)
check_instance(extra_args, tuple, allow_none=True)
check_instance(ghosts, int, allow_none=True)
from inspect import getargspec as signature # should be inspect.signature in python 3
nb_args = len(signature(func).args)
nb_in_f, nb_in_p, nb_out_f, nb_out_p = 0, 0, 0, 0
nb_in_f, nb_in_p, nb_out_f, nb_out_p, nb_extra = 0, 0, 0, 0, 0
if invars is not None:
for v in invars:
if isinstance(v, Field):
......@@ -52,10 +53,12 @@ class CustomOperator(ComputationalGraphNodeFrontend):
nb_out_f += v.nb_components
elif isinstance(v, Parameter):
nb_out_p += 1
if not extra_args is None:
nb_extra = len(extra_args)
msg = "function arguments ({}) did not match given in/out ".format(signature(func))
msg += "fields and parameters ({} input fields, {} input params,".format(nb_in_f, nb_in_p)
msg += " {} output fields, {} output params).".format(nb_out_f, nb_out_p)
assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p, msg
assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p + nb_extra, msg
super(CustomOperator, self).__init__(
func=func, invars=invars, outvars=outvars, ghosts=ghosts, **kwds)
func=func, invars=invars, outvars=outvars, extra_args=extra_args, ghosts=ghosts, **kwds)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment