Skip to content
Snippets Groups Projects
custom.py 4.31 KiB
from hysop.tools.decorators import debug
from hysop.tools.types import check_instance
from hysop.fields.continuous_field import Field, VectorField
from hysop.parameters.parameter import Parameter
from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
from hysop.backend.host.host_operator import HostOperator
from hysop.core.graph.graph import op_apply


class PythonCustomOperator(HostOperator):
    """
    Host operator that wraps an arbitrary user-supplied Python callable.

    On every apply(), ``func`` is called with positional arguments in this
    order: discrete input fields, input Parameters, discrete output fields,
    output Parameters, then ``extra_args``. Afterwards, the ghost exchangers
    built for the output fields during discretize() are run.
    """

    @debug
    def __init__(self, func, invars=None, outvars=None,
                 extra_args=None, variables=None, ghosts=None, **kwds):
        """
        Parameters
        ----------
        func : callable
            User function invoked by apply() (see the class docstring for the
            positional argument order).
        invars, outvars : tuple or list of Field / Parameter, optional
            Input and output variables of ``func``. Items that are neither a
            Field nor a Parameter are silently ignored here.
        extra_args : tuple, optional
            Extra positional arguments appended after all variables.
        variables : dict {Field: CartesianTopologyDescriptors}, optional
            Topology descriptor of each field. It must contain an entry for
            every Field listed in ``invars``/``outvars``. The lookup is
            unguarded, so a missing entry (or ``variables=None``) fails there.
        ghosts : int, optional
            When given, get_field_requirements() raises every min_ghosts entry
            to at least this value and caps every max_ghosts entry at it.
        kwds : dict
            Forwarded unchanged to HostOperator.__init__.
        """
        check_instance(invars, (tuple, list), values=(Field, Parameter),
                       allow_none=True)
        check_instance(outvars, (tuple, list), values=(Field, Parameter),
                       allow_none=True)
        check_instance(extra_args, tuple, allow_none=True)
        check_instance(variables, dict, keys=Field,
                       values=CartesianTopologyDescriptors,
                       allow_none=True)
        check_instance(ghosts, int, allow_none=True)

        input_fields, input_params = self._split_variables(invars, variables)
        output_fields, output_params = self._split_variables(outvars, variables)

        self.invars, self.outvars = invars, outvars
        self.func = func
        self.extra_args = extra_args if extra_args is not None else ()
        self._ghosts = ghosts
        super(PythonCustomOperator, self).__init__(
            input_fields=input_fields, output_fields=output_fields,
            input_params=input_params, output_params=output_params,
            **kwds)

    @staticmethod
    def _split_variables(var_list, variables):
        """
        Sort ``var_list`` into ({Field: topology}, {name: Parameter}) dicts.

        ``var_list`` may be None (treated as empty). Fields are mapped to
        ``variables[field]``; Parameters are keyed by their ``name``.
        """
        fields, params = {}, {}
        for v in (var_list or ()):
            if isinstance(v, Field):
                fields[v] = variables[v]
            elif isinstance(v, Parameter):
                params[v.name] = v
        return fields, params

    @classmethod
    def supports_mpi(cls):
        """This operator declares MPI support (always returns True)."""
        return True

    @debug
    def get_field_requirements(self):
        """
        Return the base operator's field requirements.

        If ``ghosts`` was given at construction, every requirement is adjusted
        first: each min_ghosts entry becomes max(entry, ghosts) and each
        max_ghosts entry becomes min(entry, ghosts).
        """
        requirements = super(PythonCustomOperator, self).get_field_requirements()
        if self._ghosts is None:
            return requirements
        ghosts = self._ghosts
        for it in requirements.iter_requirements():
            if it[1] is None:
                continue
            _is_input, (_field, _td, req) = it
            # Build both bounds from the current values before assigning either.
            # Materialize them as tuples, not lazy generators, so the setters
            # receive real sequences that can be read more than once.
            new_min = tuple(max(g, ghosts) for g in req.min_ghosts.copy())
            new_max = tuple(min(g, ghosts) for g in req.max_ghosts.copy())
            req.min_ghosts = new_min
            req.max_ghosts = new_max
        return requirements

    @debug
    def discretize(self):
        """
        Discretize the operator and collect the objects passed to ``func``.

        Vector fields are split into their components. For each component,
        the items obtained by iterating its discrete field are appended to
        ``self.dinvar`` (inputs) or ``self.doutvar`` (outputs). Parameters go
        to ``self.dinparam`` / ``self.doutparam`` unchanged. For every output
        component, a ghost exchanger is built and kept in
        ``self.ghost_exchanger`` when it is not None. Does nothing if the
        operator is already discretized.
        """
        if self.discretized:
            return
        super(PythonCustomOperator, self).discretize()
        idf, odf = self.input_discrete_fields, self.output_discrete_fields
        dinvar, dinparam = [], []
        doutvar, doutparam = [], []
        self.ghost_exchanger = []

        for v in (self.invars or ()):
            if isinstance(v, Field):
                for component in (v if isinstance(v, VectorField) else (v, )):
                    dinvar.extend(idf[component])
            elif isinstance(v, Parameter):
                dinparam.append(v)

        for v in (self.outvars or ()):
            if isinstance(v, Field):
                for component in (v if isinstance(v, VectorField) else (v, )):
                    dfield = odf[component]
                    doutvar.extend(dfield)
                    exchanger = dfield.build_ghost_exchanger()
                    if exchanger is not None:
                        self.ghost_exchanger.append(exchanger)
            elif isinstance(v, Parameter):
                doutparam.append(v)

        self.dinvar, self.doutvar = tuple(dinvar), tuple(doutvar)
        self.dinparam, self.doutparam = tuple(dinparam), tuple(doutparam)

    @op_apply
    def apply(self, **kwds):
        """
        Call ``func`` on the collected variables, then exchange output ghosts.

        The positional arguments are dinvar + dinparam + doutvar + doutparam
        + extra_args (all tuples, concatenated in that order).
        """
        super(PythonCustomOperator, self).apply(**kwds)
        args = (self.dinvar + self.dinparam +
                self.doutvar + self.doutparam + self.extra_args)
        self.func(*args)
        for gh_exch in self.ghost_exchanger:
            gh_exch.exchange_ghosts()