diff --git a/hysop/parameters/tensor_parameter.py b/hysop/parameters/tensor_parameter.py index 7bc6edfe58c63e5e01b525a373e098e18cb4f60d..8ba0d474666b6fdd57f4ae6a6096289388c3b49f 100644 --- a/hysop/parameters/tensor_parameter.py +++ b/hysop/parameters/tensor_parameter.py @@ -160,12 +160,21 @@ class TensorParameter(Parameter): _pretty_name = self.pretty_name + subscripts(ids=idx, sep='').encode('utf-8') name = first_not_None(name, _name) pretty_name = first_not_None(pretty_name, _pretty_name) - return TensorParameter(name=name, pretty_name=pretty_name, - initial_value=initial_value, dtype=self.dtype, shape=initial_value.shape, - min_value=self.min_value, max_value=self.max_value, - ignore_nans=self.ignore_nans, - const=self.const, quiet=self.quiet, - is_view=True, **kwds) + if initial_value.size == 1: + from scalar_parameter import ScalarParameter + return ScalarParameter(name=name, pretty_name=pretty_name, + initial_value=initial_value, dtype=self.dtype, + min_value=self.min_value, max_value=self.max_value, + ignore_nans=self.ignore_nans, + const=self.const, quiet=self.quiet, + is_view=True, **kwds) + else: + return TensorParameter(name=name, pretty_name=pretty_name, + initial_value=initial_value, dtype=self.dtype, shape=initial_value.shape, + min_value=self.min_value, max_value=self.max_value, + ignore_nans=self.ignore_nans, + const=self.const, quiet=self.quiet, + is_view=True, **kwds) def iterviews(self): """Iterate over all parameters views to yield scalarparameters."""