Skip to content

Commit b260d60

Browse files
committed
Merge branch 'fix-atom-structure-pickle'
* fix inconsistent (Atom, Structure) pickle * add `isiterable` utility function Resolve diffpy/diffpy.srfit#56.
2 parents 708a754 + eab12a2 commit b260d60

File tree

3 files changed

+64
-22
lines changed

3 files changed

+64
-22
lines changed

src/diffpy/structure/structure.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616
"""This module defines class Structure.
1717
"""
1818

19-
import collections
20-
import copy
19+
import copy as copymod
2120
import numpy
2221
import codecs
2322
import six
2423

2524
from diffpy.structure.lattice import Lattice
2625
from diffpy.structure.atom import Atom
2726
from diffpy.structure.utils import _linkAtomAttribute, atomBareSymbol
27+
from diffpy.structure.utils import isiterable
2828

2929
# ----------------------------------------------------------------------------
3030

@@ -95,9 +95,9 @@ def __init__(self, atoms=None, lattice=None, title=None,
9595

9696

9797
def copy(self):
98-
'''Return a deep copy of this Structure object.
98+
'''Return a copy of this Structure object.
9999
'''
100-
return copy.copy(self)
100+
return copymod.copy(self)
101101

102102

103103
def __copy__(self, target=None):
@@ -116,7 +116,7 @@ def __copy__(self, target=None):
116116
# copy attributes as appropriate:
117117
target.title = self.title
118118
target.lattice = Lattice(self.lattice)
119-
target.pdffit = copy.deepcopy(self.pdffit)
119+
target.pdffit = copymod.deepcopy(self.pdffit)
120120
# copy all atoms to the target
121121
target[:] = self
122122
return target
@@ -332,25 +332,43 @@ def insert(self, idx, a, copy=True):
332332
333333
No return value.
334334
"""
335-
adup = copy and Atom(a) or a
335+
adup = copy and copymod.copy(a) or a
336336
adup.lattice = self.lattice
337337
super(Structure, self).insert(idx, adup)
338338
return
339339

340340

341-
def extend(self, atoms, copy=True):
342-
"""Extend Structure by appending copies from a list of atoms.
341+
def extend(self, atoms, copy=None):
342+
"""Extend Structure with an iterable of atoms.
343343
344-
atoms -- list of Atom instances
345-
copy -- flag for extending with copies of Atom instances.
346-
When False extend with atoms and update their lattice
347-
attributes.
344+
Update the `lattice` attribute of all added atoms.
348345
349-
No return value.
346+
Parameters
347+
----------
348+
atoms : iterable
349+
The `Atom` objects to be appended to this Structure.
350+
copy : bool, optional
351+
Flag for adding copies of Atom objects.
352+
Make copies when `True`, append `atoms` unchanged when ``False``.
353+
The default behavior is to make copies when `atoms` are of
354+
`Structure` type or if new atoms introduce repeated objects.
350355
"""
351-
adups = map(Atom, atoms) if copy else atoms
356+
adups = (copymod.copy(a) for a in atoms)
357+
if copy is None:
358+
if isinstance(atoms, Structure):
359+
newatoms = adups
360+
else:
361+
memo = set(id(a) for a in self)
362+
nextatom = lambda a: (a if id(a) not in memo
363+
else copymod.copy(a))
364+
mark = lambda a: (memo.add(id(a)), a)[-1]
365+
newatoms = (mark(nextatom(a)) for a in atoms)
366+
elif copy:
367+
newatoms = adups
368+
else:
369+
newatoms = atoms
352370
setlat = lambda a: (setattr(a, 'lattice', self.lattice), a)[-1]
353-
super(Structure, self).extend(setlat(a) for a in adups)
371+
super(Structure, self).extend(setlat(a) for a in newatoms)
354372
return
355373

356374

@@ -388,7 +406,7 @@ def __getitem__(self, idx):
388406
# check if there is any string label that should be resolved
389407
scalarstringlabel = isinstance(idx, six.string_types)
390408
hasstringlabel = scalarstringlabel or (
391-
isinstance(idx, collections.Iterable) and
409+
isiterable(idx) and
392410
any(isinstance(ii, six.string_types) for ii in idx))
393411
# if not, use numpy indexing to resolve idx
394412
if not hasstringlabel:
@@ -464,7 +482,7 @@ def __add__(self, other):
464482
465483
Return new Structure with a copy of Atom instances.
466484
'''
467-
rv = copy.copy(self)
485+
rv = copymod.copy(self)
468486
rv += other
469487
return rv
470488

@@ -476,7 +494,7 @@ def __iadd__(self, other):
476494
477495
Return self.
478496
'''
479-
self.extend(other)
497+
self.extend(other, copy=True)
480498
return self
481499

482500

@@ -489,7 +507,7 @@ def __sub__(self, other):
489507
'''
490508
otherset = set(other)
491509
keepindices = [i for i, a in enumerate(self) if not a in otherset]
492-
rv = copy.copy(self[keepindices])
510+
rv = copymod.copy(self[keepindices])
493511
return rv
494512

495513

@@ -513,7 +531,7 @@ def __mul__(self, n):
513531
514532
Return new Structure.
515533
'''
516-
rv = copy.copy(self[:0])
534+
rv = copymod.copy(self[:0])
517535
rv += n * self.tolist()
518536
return rv
519537

@@ -533,7 +551,7 @@ def __imul__(self, n):
533551
if n <= 0:
534552
self[:] = []
535553
else:
536-
self.extend((n - 1) * self.tolist())
554+
self.extend((n - 1) * self.tolist(), copy=True)
537555
return self
538556

539557
# Properties -------------------------------------------------------------

src/diffpy/structure/tests/teststructure.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
import copy
21+
import pickle
2122
import unittest
2223
import numpy
2324

@@ -240,7 +241,7 @@ def test_extend(self):
240241
self.assertEqual(6, len(stru))
241242
self.assertTrue(all(a.lattice is stru.lattice for a in stru))
242243
self.assertEqual(lst, stru.tolist()[:2])
243-
self.assertNotEqual(stru[-1], cdse[-1])
244+
self.assertFalse(stru[-1] is cdse[-1])
244245
return
245246

246247

@@ -656,6 +657,17 @@ def test_Bij(self):
656657
self.assertFalse(numpy.any(stru.U != 0.0))
657658
return
658659

660+
661+
def test_pickling(self):
662+
"""Make sure Atom in Structure can be consistently pickled.
663+
"""
664+
stru = self.stru
665+
a = stru[0]
666+
self.assertTrue(a is stru[0])
667+
a1, stru1 = pickle.loads(pickle.dumps((a, stru)))
668+
self.assertTrue(a1 is stru1[0])
669+
return
670+
659671
# End of class TestStructure
660672

661673
# ----------------------------------------------------------------------------

src/diffpy/structure/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,20 @@
1616
"""Small shared functions.
1717
"""
1818

19+
import six
1920
import numpy
2021

22+
if six.PY2:
23+
from collections import Iterable as _Iterable
24+
else:
25+
from collections.abc import Iterable as _Iterable
26+
27+
28+
def isiterable(obj):
29+
"""True if argument is iterable."""
30+
rv = isinstance(obj, _Iterable)
31+
return rv
32+
2133

2234
def isfloat(s):
2335
"""True if argument can be converted to float"""

0 commit comments

Comments
 (0)