diff --git a/hysop/backend/host/python/operator/analytic.py b/hysop/backend/host/python/operator/analytic.py index 10eb14efd0aa4561702d698a981676eb33df962e..fe7df078130187c3c3facf68f8dbee608696a50f 100644 --- a/hysop/backend/host/python/operator/analytic.py +++ b/hysop/backend/host/python/operator/analytic.py @@ -1,12 +1,12 @@ - from hysop.tools.types import check_instance, first_not_None from hysop.tools.decorators import debug from hysop.backend.host.host_operator import HostOperator from hysop.core.graph.graph import op_apply -from hysop.fields.continuous_field import Field, ScalarField +from hysop.fields.continuous_field import Field, ScalarField, VectorField from hysop.parameters.parameter import Parameter from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors + class PythonAnalyticField(HostOperator): """ Applies an analytic formula, given by user, on its field. @@ -15,15 +15,15 @@ class PythonAnalyticField(HostOperator): @debug def __new__(cls, field, formula, variables, - extra_input_kwds=None, **kwds): + extra_input_kwds=None, **kwds): return super(PythonAnalyticField, cls).__new__(cls, - input_fields=None, - output_fields=None, - input_params=None, **kwds) + input_fields=None, + output_fields=None, + input_params=None, **kwds) @debug def __init__(self, field, formula, variables, - extra_input_kwds=None, **kwds): + extra_input_kwds=None, **kwds): """ Initialize a Analytic operator on the python backend. @@ -52,18 +52,18 @@ class PythonAnalyticField(HostOperator): """ extra_input_kwds = first_not_None(extra_input_kwds, {}) - check_instance(field, ScalarField) + check_instance(field, (ScalarField, VectorField)) assert callable(formula), type(formula) check_instance(variables, dict, keys=Field, values=CartesianTopologyDescriptors) check_instance(extra_input_kwds, dict, keys=str) - input_fields = {} - output_fields = { field: self.get_topo_descriptor(variables, field) } - input_params = {} + input_fields = {} + output_fields = {field: self.get_topo_descriptor(variables, field)} + input_params = {} extra_kwds = {} map_fields = {} - for (k,v) in extra_input_kwds.items(): + for (k, v) in extra_input_kwds.items(): if isinstance(v, Field): input_fields[v] = self.get_topo_descriptor(variables, v) map_fields[v] = k @@ -74,8 +74,8 @@ class PythonAnalyticField(HostOperator): extra_kwds[k] = v super(PythonAnalyticField, self).__init__(input_fields=input_fields, - output_fields=output_fields, - input_params=input_params, **kwds) + output_fields=output_fields, + input_params=input_params, **kwds) self.field = field self.formula = formula @@ -90,9 +90,11 @@ class PythonAnalyticField(HostOperator): dfield = self.get_output_discrete_field(self.field) extra_kwds = self.extra_kwds map_fields = self.map_fields - assert 'data' not in extra_kwds + assert 'data' not in extra_kwds assert 'coords' not in extra_kwds - extra_kwds['data'] = dfield.compute_data[0] + extra_kwds['data'] = dfield.compute_data[0] + if len(dfield.compute_data) > 1: + extra_kwds['data'] = dfield.compute_data extra_kwds['coords'] = dfield.compute_mesh_coords for (field, dfield) in self.input_discrete_fields.items(): assert field.name not in extra_kwds, field.name @@ -109,4 +111,3 @@ class PythonAnalyticField(HostOperator): @classmethod def supports_mpi(cls): return True -