# custom.py
# Authored by Jean-Matthieu Etancelin.
from hysop.tools.decorators import debug
from hysop.tools.types import check_instance
from hysop.fields.continuous_field import Field, VectorField
from hysop.parameters.parameter import Parameter
from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
from hysop.backend.host.host_operator import HostOperator
from hysop.core.graph.graph import op_apply
def _split_variables(variables, topologies):
    """Partition operator variables into fields and parameters.

    :param variables: tuple or list of Field and Parameter instances, or None
        (treated as empty).
    :param topologies: dict mapping each Field to its topology descriptor.
        May be None only when ``variables`` contains no Field.
    :return: pair ``(fields, params)``: ``fields`` maps each Field to
        ``topologies[field]``; ``params`` maps each Parameter name to the
        Parameter. Entries that are neither are ignored.
    :raises TypeError: if ``variables`` contains a Field but ``topologies``
        is None.
    :raises KeyError: if a Field has no entry in ``topologies``.
    """
    fields, params = {}, {}
    for v in variables or ():
        if isinstance(v, Field):
            if topologies is None:
                # Without this check the failure is an opaque
                # "'NoneType' object is not subscriptable".
                raise TypeError(
                    "Field {} requires a topology descriptor: pass it "
                    "through the 'variables' argument.".format(v))
            fields[v] = topologies[v]
        elif isinstance(v, Parameter):
            params[v.name] = v
    return fields, params


def _scalar_components(field):
    """Return an iterable over the scalar components of ``field``.

    A VectorField is iterated component by component; any other Field is
    wrapped in a one-element tuple.
    """
    return field if isinstance(field, VectorField) else (field,)


class PythonCustomOperator(HostOperator):
    """Host operator that runs an arbitrary user-supplied Python callable.

    On each ``apply`` the callable ``func`` receives, as positional
    arguments and in this order:

    1. the discrete fields of every input Field (VectorFields expanded
       component by component),
    2. the input Parameters,
    3. the discrete fields of every output Field (same expansion),
    4. the output Parameters,
    5. the contents of ``extra_args``.

    After ``func`` returns, the ghost exchangers built from the output
    discrete fields are run.
    """

    @debug
    def __init__(self, func, invars=None, outvars=None,
                 extra_args=None, variables=None, ghosts=None, **kwds):
        """Build the operator.

        :param func: callable invoked at each apply (see class docstring
            for its argument order).
        :param invars: tuple or list of Field / Parameter read by ``func``.
        :param outvars: tuple or list of Field / Parameter written by
            ``func``.
        :param extra_args: tuple of extra positional arguments appended
            after all variables when calling ``func``.
        :param variables: dict mapping each Field of ``invars`` and
            ``outvars`` to a CartesianTopologyDescriptors; required as soon
            as a Field is given.
        :param ghosts: if not None, every field requirement gets its minimum
            ghost count raised to at least ``ghosts`` and its maximum
            ghost count capped at ``ghosts``.
        :param kwds: forwarded to ``HostOperator.__init__``.
        :raises TypeError: if a Field is given while ``variables`` is None.
        """
        check_instance(invars, (tuple, list), values=(Field, Parameter),
                       allow_none=True)
        check_instance(outvars, (tuple, list), values=(Field, Parameter),
                       allow_none=True)
        check_instance(extra_args, tuple, allow_none=True)
        check_instance(variables, dict, keys=Field,
                       values=CartesianTopologyDescriptors,
                       allow_none=True)
        check_instance(ghosts, int, allow_none=True)

        input_fields, input_params = _split_variables(invars, variables)
        output_fields, output_params = _split_variables(outvars, variables)

        self.invars, self.outvars = invars, outvars
        self.func = func
        self.extra_args = () if extra_args is None else extra_args
        self._ghosts = ghosts
        super(PythonCustomOperator, self).__init__(
            input_fields=input_fields, output_fields=output_fields,
            input_params=input_params, output_params=output_params,
            **kwds)

    @classmethod
    def supports_mpi(cls):
        """Return True: this operator declares MPI support."""
        return True

    @debug
    def get_field_requirements(self):
        """Return the parent requirements, clamped to ``ghosts`` if set.

        For each non-None requirement, each minimum ghost count becomes
        ``max(g, ghosts)`` and each maximum ghost count becomes
        ``min(g, ghosts)``.
        """
        requirements = super(PythonCustomOperator, self).get_field_requirements()
        if self._ghosts is None:
            return requirements
        ghosts = self._ghosts
        for it in requirements.iter_requirements():
            if it[1] is None:
                continue
            _, (_, _, req) = it
            # Materialise as tuples: a generator is single-use and would be
            # exhausted after its first read. Both values are read before
            # either is written, as in the original code.
            new_min = tuple(max(g, ghosts) for g in req.min_ghosts)
            new_max = tuple(min(g, ghosts) for g in req.max_ghosts)
            req.min_ghosts = new_min
            req.max_ghosts = new_max
        return requirements

    @debug
    def discretize(self):
        """Collect discrete variables and ghost exchangers (once only).

        Sets ``self.dinvar`` / ``self.doutvar`` (tuples of discrete fields),
        ``self.dinparam`` / ``self.doutparam`` (tuples of Parameters), and
        ``self.ghost_exchanger`` (list of exchangers for output fields).
        """
        if self.discretized:
            return
        super(PythonCustomOperator, self).discretize()
        idf, odf = self.input_discrete_fields, self.output_discrete_fields
        self.ghost_exchanger = []

        dinvar, dinparam = [], []
        for v in self.invars or ():
            if isinstance(v, Field):
                for component in _scalar_components(v):
                    dinvar.extend(idf[component])
            elif isinstance(v, Parameter):
                dinparam.append(v)

        doutvar, doutparam = [], []
        for v in self.outvars or ():
            if isinstance(v, Field):
                for component in _scalar_components(v):
                    dfield = odf[component]
                    doutvar.extend(dfield)
                    # One exchanger per output component, so that ghosts
                    # written by func can be refreshed after each apply.
                    gh = dfield.build_ghost_exchanger()
                    if gh is not None:
                        self.ghost_exchanger.append(gh)
            elif isinstance(v, Parameter):
                doutparam.append(v)

        self.dinvar, self.doutvar = tuple(dinvar), tuple(doutvar)
        self.dinparam, self.doutparam = tuple(dinparam), tuple(doutparam)

    @op_apply
    def apply(self, **kwds):
        """Call ``func`` on the discrete variables, then exchange ghosts."""
        super(PythonCustomOperator, self).apply(**kwds)
        args = (self.dinvar + self.dinparam
                + self.doutvar + self.doutparam
                + self.extra_args)
        self.func(*args)
        for gh_exch in self.ghost_exchanger:
            gh_exch.exchange_ghosts()