Skip to content
Snippets Groups Projects
Commit d38c694b authored by EXT Jean-Matthieu Etancelin's avatar EXT Jean-Matthieu Etancelin
Browse files

add to custom operator a ghosts parameter if needed

parent 5f6a20a9
No related branches found
No related tags found
1 merge request!16MPI operators
Pipeline #41077 failed
...@@ -9,7 +9,7 @@ from hysop.core.graph.graph import op_apply ...@@ -9,7 +9,7 @@ from hysop.core.graph.graph import op_apply
class PythonCustomOperator(HostOperator): class PythonCustomOperator(HostOperator):
@debug @debug
def __init__(self, func, invars=None, outvars=None, variables=None, **kwds): def __init__(self, func, invars=None, outvars=None, variables=None, ghosts=None, **kwds):
check_instance(invars, (tuple, list), values=(Field, Parameter), check_instance(invars, (tuple, list), values=(Field, Parameter),
allow_none=True) allow_none=True)
check_instance(outvars, (tuple, list), values=(Field, Parameter), check_instance(outvars, (tuple, list), values=(Field, Parameter),
...@@ -17,6 +17,7 @@ class PythonCustomOperator(HostOperator): ...@@ -17,6 +17,7 @@ class PythonCustomOperator(HostOperator):
check_instance(variables, dict, keys=Field, check_instance(variables, dict, keys=Field,
values=CartesianTopologyDescriptors, values=CartesianTopologyDescriptors,
allow_none=True) allow_none=True)
check_instance(ghosts, int, allow_none=True)
input_fields, output_fields = {}, {} input_fields, output_fields = {}, {}
input_params, output_params = {}, {} input_params, output_params = {}, {}
if invars is not None: if invars is not None:
...@@ -33,6 +34,7 @@ class PythonCustomOperator(HostOperator): ...@@ -33,6 +34,7 @@ class PythonCustomOperator(HostOperator):
output_params[v.name] = v output_params[v.name] = v
self.invars, self.outvars = invars, outvars self.invars, self.outvars = invars, outvars
self.func = func self.func = func
self._ghosts = ghosts
super(PythonCustomOperator, self).__init__( super(PythonCustomOperator, self).__init__(
input_fields=input_fields, output_fields=output_fields, input_fields=input_fields, output_fields=output_fields,
...@@ -43,6 +45,17 @@ class PythonCustomOperator(HostOperator): ...@@ -43,6 +45,17 @@ class PythonCustomOperator(HostOperator):
def supports_mpi(cls): def supports_mpi(cls):
return True return True
@debug
def get_field_requirements(self):
requirements = super(PythonCustomOperator, self).get_field_requirements()
if not self._ghosts is None:
for it in requirements.iter_requirements():
if not it[1] is None:
is_input, (field, td, req) = it
min_ghosts = (max(g, self._ghosts) for g in req.min_ghosts.copy())
req.min_ghosts = min_ghosts
return requirements
@debug @debug
def discretize(self): def discretize(self):
if self.discretized: if self.discretized:
......
...@@ -3,7 +3,7 @@ import graph_tool as gt ...@@ -3,7 +3,7 @@ import graph_tool as gt
from graph_tool import Graph, GraphView from graph_tool import Graph, GraphView
from graph_tool import topology, stats, search from graph_tool import topology, stats, search
from hysop.tools.decorators import not_implemented, debug, wraps, profile from hysop.tools.decorators import not_implemented, debug, wraps, profile
from hysop import vprint from hysop import dprint
class ComputationalGraphNodeData(object): class ComputationalGraphNodeData(object):
...@@ -171,7 +171,7 @@ def op_apply(f): ...@@ -171,7 +171,7 @@ def op_apply(f):
if not op.to_be_skipped(): if not op.to_be_skipped():
return f(*args, **kwds) return f(*args, **kwds)
else: else:
vprint("Skip {}".format(op.name)) dprint("Skip {}".format(op.name))
return return
return ret return ret
return apply return apply
...@@ -31,11 +31,12 @@ class CustomOperator(ComputationalGraphNodeFrontend): ...@@ -31,11 +31,12 @@ class CustomOperator(ComputationalGraphNodeFrontend):
return Implementation.PYTHON return Implementation.PYTHON
@debug @debug
def __init__(self, func, invars=None, outvars=None, **kwds): def __init__(self, func, invars=None, outvars=None, ghosts=None, **kwds):
check_instance(invars, (tuple, list), values=(Field, Parameter), check_instance(invars, (tuple, list), values=(Field, Parameter),
allow_none=True) allow_none=True)
check_instance(outvars, (tuple, list), values=(Field, Parameter), check_instance(outvars, (tuple, list), values=(Field, Parameter),
allow_none=True) allow_none=True)
check_instance(ghosts, int, allow_none=True)
from inspect import getargspec as signature # should be inspect.signature in python 3 from inspect import getargspec as signature # should be inspect.signature in python 3
nb_args = len(signature(func).args) nb_args = len(signature(func).args)
nb_in_f, nb_in_p, nb_out_f, nb_out_p = 0, 0, 0, 0 nb_in_f, nb_in_p, nb_out_f, nb_out_p = 0, 0, 0, 0
...@@ -57,4 +58,4 @@ class CustomOperator(ComputationalGraphNodeFrontend): ...@@ -57,4 +58,4 @@ class CustomOperator(ComputationalGraphNodeFrontend):
assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p, msg assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p, msg
super(CustomOperator, self).__init__( super(CustomOperator, self).__init__(
func=func, invars=invars, outvars=outvars, **kwds) func=func, invars=invars, outvars=outvars, ghosts=ghosts, **kwds)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment