diff --git a/hysop/backend/host/python/operator/custom.py b/hysop/backend/host/python/operator/custom.py index 74641b19f396cf1d76365f0cf8ed4a149f46aebd..25f8f48fbdeaccf6e8bc276d714aa7bc93076cf8 100644 --- a/hysop/backend/host/python/operator/custom.py +++ b/hysop/backend/host/python/operator/custom.py @@ -9,11 +9,13 @@ from hysop.core.graph.graph import op_apply class PythonCustomOperator(HostOperator): @debug - def __init__(self, func, invars=None, outvars=None, variables=None, ghosts=None, **kwds): + def __init__(self, func, invars=None, outvars=None, + extra_args=None, variables=None, ghosts=None, **kwds): check_instance(invars, (tuple, list), values=(Field, Parameter), allow_none=True) check_instance(outvars, (tuple, list), values=(Field, Parameter), allow_none=True) + check_instance(extra_args, tuple, allow_none=True) check_instance(variables, dict, keys=Field, values=CartesianTopologyDescriptors, allow_none=True) @@ -34,8 +36,10 @@ class PythonCustomOperator(HostOperator): output_params[v.name] = v self.invars, self.outvars = invars, outvars self.func = func + self.extra_args = tuple() + if not extra_args is None: + self.extra_args = extra_args self._ghosts = ghosts - super(PythonCustomOperator, self).__init__( input_fields=input_fields, output_fields=output_fields, input_params=input_params, output_params=output_params, @@ -90,6 +94,6 @@ class PythonCustomOperator(HostOperator): @op_apply def apply(self, **kwds): super(PythonCustomOperator, self).apply(**kwds) - self.func(*(self.dinvar + self.dinparam + self.doutvar + self.doutparam)) + self.func(*(self.dinvar + self.dinparam + self.doutvar + self.doutparam + self.extra_args)) for gh_exch in self.ghost_exchanger: gh_exch.exchange_ghosts() diff --git a/hysop/operator/custom.py b/hysop/operator/custom.py index 3a00efaa01f2ee0bdb52d8481e4b2ebd17ca1fd8..e4aa9bf824fb76ff37d3da544b1c5c6d1fffb190 100755 --- a/hysop/operator/custom.py +++ b/hysop/operator/custom.py @@ -31,15 +31,16 @@ class CustomOperator(ComputationalGraphNodeFrontend): return Implementation.PYTHON @debug - def __init__(self, func, invars=None, outvars=None, ghosts=None, **kwds): + def __init__(self, func, invars=None, outvars=None, extra_args=None, ghosts=None, **kwds): check_instance(invars, (tuple, list), values=(Field, Parameter), allow_none=True) check_instance(outvars, (tuple, list), values=(Field, Parameter), allow_none=True) + check_instance(extra_args, tuple, allow_none=True) check_instance(ghosts, int, allow_none=True) from inspect import getargspec as signature # should be inspect.signature in python 3 nb_args = len(signature(func).args) - nb_in_f, nb_in_p, nb_out_f, nb_out_p = 0, 0, 0, 0 + nb_in_f, nb_in_p, nb_out_f, nb_out_p, nb_extra = 0, 0, 0, 0, 0 if invars is not None: for v in invars: if isinstance(v, Field): @@ -52,10 +53,12 @@ class CustomOperator(ComputationalGraphNodeFrontend): nb_out_f += v.nb_components elif isinstance(v, Parameter): nb_out_p += 1 + if not extra_args is None: + nb_extra = len(extra_args) msg = "function arguments ({}) did not match given in/out ".format(signature(func)) msg += "fields and parameters ({} input fields, {} input params,".format(nb_in_f, nb_in_p) msg += " {} output fields, {} output params).".format(nb_out_f, nb_out_p) - assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p, msg + assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p + nb_extra, msg super(CustomOperator, self).__init__( - func=func, invars=invars, outvars=outvars, ghosts=ghosts, **kwds) + func=func, invars=invars, outvars=outvars, extra_args=extra_args, ghosts=ghosts, **kwds)