Skip to content
Snippets Groups Projects
Commit 8c5aad5b authored by Jean-Matthieu Etancelin's avatar Jean-Matthieu Etancelin
Browse files

Fix bug in accessing VectorField elements.

parent f8a2969a
No related branches found
No related tags found
No related merge requests found
......@@ -164,15 +164,17 @@ class VectorField(object):
Usage (3D): \n
A = VectorField(...), We have 3 components len(a.data) == 3. \n
Following instructions access to index 2,1,1 of y component:
@li A[2,1,1,1]
@li A[1,2,1,1]
@li A[1][2,1,1]
@li Access to whole vector of index 2,1,1: A[2,1,1]
@note Access to all datas as an array : A[:,:,:] (resulting shape is (dim, Nx, Ny, Nz))
@note Access to all datas in y component as an array : A[1][:,:,:] (resulting shape is (Nx, Ny, Nz))
"""
try:
if len(i) == len(self.data):
return [data.__getitem__(i) for data in self.data]
return np.asarray([data.__getitem__(i) for data in self.data])
else:
return self.data[i[-1]].__getitem__(tuple(i[0:-1]))
return self.data[i[0]].__getitem__(tuple(i[1:]))
except (TypeError):
return self.data[i]
......@@ -188,14 +190,15 @@ class VectorField(object):
A[2,1,1] = 12.0 # Calls A.data[d][2,1,1] = 12.0 for all components d.\n
A[2,1,1] = [12.0, 13.0, 14.0] # Calls A.data[0][2,1,1] = 12.0, A.data[1][2,1,1] = 13.0 and A.data[2][2,1,1] = 14.0\n
A[1][2,1,1] = 13.0 # Calls A.data[1][2,1,1] = 12.0
A[1,2,1,1] = 13.0 # Calls A.data[1][2,1,1] = 12.0
"""
if len(i) == len(self.data):
try:
[data.__setitem__(i, v) for data, v in zip(self.data, value)]
except (TypeError):
[data.__setitem__(i, value) for data in self.data]
else:
self.data[i[-1]].__getitem__(tuple(i[0:-1]))
elif len(i) > len(self.data):
self.data[i[0]].__setitem__(tuple(i[1:]), value)
def initialize(self, formula=None):
"""
......
......@@ -85,18 +85,21 @@ class VectorFieldTestCase(unittest.TestCase):
true_res_x = np.zeros_like(self.dv.data[0])
true_res_y = np.zeros_like(self.dv.data[1])
true_res_z = np.zeros_like(self.dv.data[2])
self.assertEqual(self.dv[:, :, :].shape, (3, 10, 10, 10))
self.assertEqual(self.dv[0][:, :, :].shape, (10, 10, 10))
self.assertEqual(self.dv[0,0,0].shape, (3,))
for i in np.arange(10):
for j in np.arange(10):
for k in np.arange(10):
true_res_x[i, j, k], true_res_y[i, j, k], true_res_z[i, j, k] = self.formula(i / 10., j / 10., k / 10.)
self.assertAlmostEqual(self.dv[0][i, j, k], true_res_x[i, j, k])
self.assertAlmostEqual(self.dv[i, j, k, 1], true_res_y[i, j, k])
self.assertAlmostEqual(self.dv[1, i, j, k], true_res_y[i, j, k])
np.testing.assert_array_almost_equal(self.dv[i, j, k], [true_res_x[i, j, k], true_res_y[i, j, k], true_res_z[i, j, k]])
np.testing.assert_array_almost_equal(self.dv[0], true_res_x)
np.testing.assert_array_almost_equal(self.dv[1], true_res_y)
np.testing.assert_array_almost_equal(self.dv[2], true_res_z)
np.testing.assert_array_almost_equal(self.dv[0][:,1,2], true_res_x[:,1,2])
np.testing.assert_array_almost_equal(self.dv[:,1,2,0], true_res_x[:,1,2])
np.testing.assert_array_almost_equal(self.dv[0,:,1,2], true_res_x[:,1,2])
self.dv[1,2,3] = [1., 2., 3.]
np.testing.assert_array_almost_equal( self.dv[1,2,3], [1., 2., 3.])
self.dv[0,1,:] = 5.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment