Skip to content
Snippets Groups Projects
Commit c7cfceb8 authored by EXT Jean-Matthieu Etancelin's avatar EXT Jean-Matthieu Etancelin
Browse files

Tensor parameter view as (1,) shape is converted to ScalarParameter

parent 99cdd23b
No related branches found
No related tags found
1 merge request!16MPI operators
......@@ -160,12 +160,21 @@ class TensorParameter(Parameter):
_pretty_name = self.pretty_name + subscripts(ids=idx, sep='').encode('utf-8')
name = first_not_None(name, _name)
pretty_name = first_not_None(pretty_name, _pretty_name)
return TensorParameter(name=name, pretty_name=pretty_name,
initial_value=initial_value, dtype=self.dtype, shape=initial_value.shape,
min_value=self.min_value, max_value=self.max_value,
ignore_nans=self.ignore_nans,
const=self.const, quiet=self.quiet,
is_view=True, **kwds)
if initial_value.size == 1:
from scalar_parameter import ScalarParameter
return ScalarParameter(name=name, pretty_name=pretty_name,
initial_value=initial_value, dtype=self.dtype,
min_value=self.min_value, max_value=self.max_value,
ignore_nans=self.ignore_nans,
const=self.const, quiet=self.quiet,
is_view=True, **kwds)
else:
return TensorParameter(name=name, pretty_name=pretty_name,
initial_value=initial_value, dtype=self.dtype, shape=initial_value.shape,
min_value=self.min_value, max_value=self.max_value,
ignore_nans=self.ignore_nans,
const=self.const, quiet=self.quiet,
is_view=True, **kwds)
def iterviews(self):
"""Iterate over all parameters views to yield scalarparameters."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment