Skip to content

Commit 6c75041

Browse files
committed
BUG: avoid linked array values in distinct Atoms
Change array attributes of `Atom` instances inplace. Use numpy broadcast only when assigning array values.
1 parent d9cab02 commit 6c75041

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

src/diffpy/structure/tests/teststructure.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,6 @@ def test_element(self):
464464
return
465465

466466

467-
@unittest.expectedFailure
468467
def test_xyz(self):
469468
"""check Structure.xyz
470469
"""
@@ -477,7 +476,7 @@ def test_xyz(self):
477476
stru.xyz = 0
478477
stru[1].xyz[:] = 1
479478
self.assertTrue(numpy.array_equal([0, 0, 0], stru[0].xyz))
480-
self.assertTrue(numpy.array_equal([1, 1, 1], stru[0].xyz))
479+
self.assertTrue(numpy.array_equal([1, 1, 1], stru[1].xyz))
481480
return
482481

483482

@@ -532,7 +531,6 @@ def test_label(self):
532531
return
533532

534533

535-
@unittest.expectedFailure
536534
def test_occupancy(self):
537535
"""check Structure.occupancy
538536
"""

src/diffpy/structure/utils.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,29 @@ def _linkAtomAttribute(attrname, doc, toarray=numpy.array):
5454
5555
Return a property object.
5656
'''
57+
from itertools import repeat
58+
from operator import setitem
59+
_all = slice(None)
5760
def fget(self):
5861
va = toarray([getattr(a, attrname) for a in self])
5962
return va
6063
def fset(self, value):
61-
if len(self) == 0: return
62-
# dummy array va helps to broadcast the value to proper iterable
63-
va = numpy.asarray(len(self) * [getattr(self[0], attrname)])
64-
for a, v in zip(self, numpy.broadcast_arrays(va, value)[1]):
65-
setattr(a, attrname, v)
64+
n = len(self)
65+
if n == 0:
66+
return
67+
v0 = getattr(self[0], attrname)
68+
# replace scalar values, but change array attributes in place
69+
if numpy.isscalar(v0):
70+
setval = lambda a, v: setattr(a, attrname, v)
71+
else:
72+
setval = lambda a, v: setitem(getattr(a, attrname), _all, v)
73+
# avoid broadcasting if the new value is a scalar
74+
if numpy.isscalar(value):
75+
gvalues = repeat(value)
76+
else:
77+
gvalues = numpy.broadcast_to(value, (n,) + numpy.shape(v0))
78+
for a, v in zip(self, gvalues):
79+
setval(a, v)
6680
return
6781
rv = property(fget, fset, doc=doc)
6882
return rv

0 commit comments

Comments
 (0)