diff --git a/hysop/backend/host/python/operator/custom.py b/hysop/backend/host/python/operator/custom.py index 28ddf36dc2863d30c046c8fe6ff19a8e0dc6797e..f63d0647bd275f5607ccb83a6a0955851f1edfc8 100644 --- a/hysop/backend/host/python/operator/custom.py +++ b/hysop/backend/host/python/operator/custom.py @@ -26,12 +26,13 @@ from hysop.operator.base.custom import CustomOperatorBase class PythonCustomOperator(CustomOperatorBase, HostOperator): @debug - def __new__(cls, **kwds): + def __new__(cls, forward_apply_kwds=None, **kwds): return super().__new__(cls, **kwds) @debug - def __init__(self, func, invars=None, outvars=None, extra_args=None, **kwds): - super().__init__(func, invars=invars, outvars=outvars, extra_args=extra_args, **kwds) + def __init__(self, func, invars=None, outvars=None, extra_args=None, forward_apply_kwds=False, **kwds): + super().__init__(func, invars=invars, outvars=outvars, extra_args=extra_args, + **kwds) from inspect import signature nb_args = len(signature(func).parameters) @@ -53,11 +54,15 @@ class PythonCustomOperator(CustomOperatorBase, HostOperator): msg = f"function arguments ({signature(func)}) did not match given in/out " msg += f"fields and parameters ({nb_in_f} input fields, {nb_in_p} input params," msg += f" {nb_out_f} output fields, {nb_out_p} output params)." - assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p + nb_extra, msg + assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p + nb_extra + forward_apply_kwds, msg + self._forward_apply_kwds = forward_apply_kwds @op_apply def apply(self, **kwds): super().apply(**kwds) - self.func(*(self.dinvar + self.dinparam + self.doutvar + self.doutparam + self.extra_args)) + if self._forward_apply_kwds: + self.func(*(self.dinvar + self.dinparam + self.doutvar + self.doutparam + self.extra_args), **kwds) + else: + self.func(*(self.dinvar + self.dinparam + self.doutvar + self.doutparam + self.extra_args)) for gh_exch in self.ghost_exchanger: gh_exch.exchange_ghosts()