Skip to content

Commit a3d7b28

Browse files
committed
Keep the trailing null bytes in ShareableList
1 parent b625601 commit a3d7b28

File tree

4 files changed

+54
-40
lines changed

4 files changed

+54
-40
lines changed

Doc/library/multiprocessing.shared_memory.rst

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -311,34 +311,9 @@ finishes execution.
311311
existing :class:`!ShareableList`, specify its shared memory block's unique
312312
name while leaving *sequence* set to ``None``.
313313

314-
.. note::
315-
316-
A known issue exists for :class:`bytes` and :class:`str` values.
317-
If they end with ``\x00`` nul bytes or characters, those may be
318-
*silently stripped* when fetching them by index from the
319-
:class:`!ShareableList`. This ``.rstrip(b'\x00')`` behavior is
320-
considered a bug and may go away in the future. See :gh:`106939`.
321-
322-
For applications where rstripping of trailing nulls is a problem,
323-
work around it by always unconditionally appending an extra non-0
324-
byte to the end of such values when storing and unconditionally
325-
removing it when fetching:
326-
327-
.. doctest::
328-
329-
>>> from multiprocessing import shared_memory
330-
>>> nul_bug_demo = shared_memory.ShareableList(['?\x00', b'\x03\x02\x01\x00\x00\x00'])
331-
>>> nul_bug_demo[0]
332-
'?'
333-
>>> nul_bug_demo[1]
334-
b'\x03\x02\x01'
335-
>>> nul_bug_demo.shm.unlink()
336-
>>> padded = shared_memory.ShareableList(['?\x00\x07', b'\x03\x02\x01\x00\x00\x00\x07'])
337-
>>> padded[0][:-1]
338-
'?\x00'
339-
>>> padded[1][:-1]
340-
b'\x03\x02\x01\x00\x00\x00'
341-
>>> padded.shm.unlink()
314+
.. versionchanged:: next
315+
Trailing null bytes are now preserved for :class:`bytes` and :class:`str`
316+
values. Previously they were silently stripped. See :gh:`106939`.
342317

343318
.. method:: count(value)
344319

Lib/multiprocessing/shared_memory.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ class ShareableList:
286286
_alignment = 8
287287
_back_transforms_mapping = {
288288
0: lambda value: value, # int, float, bool
289-
1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str
290-
2: lambda value: value.rstrip(b'\x00'), # bytes
289+
1: lambda value: value.decode(_encoding), # str
290+
2: lambda value: value, # bytes
291291
3: lambda _value: None, # None
292292
}
293293

@@ -326,6 +326,15 @@ def __init__(self, sequence=None, *, name=None):
326326
for fmt in _formats:
327327
offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1])
328328
self._allocated_offsets.append(offset)
329+
_stored_formats = []
330+
for item, fmt in zip(sequence, _formats):
331+
if isinstance(item, (str, bytes)):
332+
encoded = (item.encode(_encoding)
333+
if isinstance(item, str) else item)
334+
_stored_formats.append("%ds" % len(encoded))
335+
else:
336+
_stored_formats.append(fmt)
337+
329338
_recreation_codes = [
330339
self._extract_recreation_code(item) for item in sequence
331340
]
@@ -359,7 +368,7 @@ def __init__(self, sequence=None, *, name=None):
359368
self._format_packing_metainfo,
360369
self.shm.buf,
361370
self._offset_packing_formats,
362-
*(v.encode(_enc) for v in _formats)
371+
*(v.encode(_enc) for v in _stored_formats)
363372
)
364373
struct.pack_into(
365374
self._format_back_transform_codes,
@@ -459,6 +468,7 @@ def __setitem__(self, position, value):
459468

460469
if not isinstance(value, (str, bytes)):
461470
new_format = self._types_mapping[type(value)]
471+
pack_format = new_format
462472
encoded_value = value
463473
else:
464474
allocated_length = self._allocated_offsets[position + 1] - item_offset
@@ -467,19 +477,17 @@ def __setitem__(self, position, value):
467477
if isinstance(value, str) else value)
468478
if len(encoded_value) > allocated_length:
469479
raise ValueError("bytes/str item exceeds available storage")
470-
if current_format[-1] == "s":
471-
new_format = current_format
472-
else:
473-
new_format = self._types_mapping[str] % (
474-
allocated_length,
475-
)
480+
# Allocated-length format.
481+
pack_format = "%ds" % allocated_length
482+
# Actual-length format.
483+
new_format = "%ds" % len(encoded_value)
476484

477485
self._set_packing_format_and_transform(
478486
position,
479487
new_format,
480488
value
481489
)
482-
struct.pack_into(new_format, self.shm.buf, offset, encoded_value)
490+
struct.pack_into(pack_format, self.shm.buf, offset, encoded_value)
483491

484492
def __reduce__(self):
485493
return partial(self.__class__, name=self.shm.name), ()

Lib/test/_test_multiprocessing.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4757,7 +4757,7 @@ def test_shared_memory_ShareableList_basics(self):
47574757
self.assertEqual(current_format, sl._get_packing_format(0))
47584758

47594759
# Verify attributes are readable.
4760-
self.assertEqual(sl.format, '8s8sdqxxxxxx?xxxxxxxx?q')
4760+
self.assertEqual(sl.format, '5s5sdqxxxxxx?xxxxxxxx?q')
47614761

47624762
# Exercise len().
47634763
self.assertEqual(len(sl), 7)
@@ -4785,7 +4785,7 @@ def test_shared_memory_ShareableList_basics(self):
47854785
self.assertEqual(sl[3], 42)
47864786
sl[4] = 'some' # Change type at a given position.
47874787
self.assertEqual(sl[4], 'some')
4788-
self.assertEqual(sl.format, '8s8sdq8sxxxxxxx?q')
4788+
self.assertEqual(sl.format, '5s5sdq4sxxxxxxx?q')
47894789
with self.assertRaisesRegex(ValueError,
47904790
"exceeds available storage"):
47914791
sl[4] = 'far too many'
@@ -4887,6 +4887,34 @@ def test_shared_memory_ShareableList_pickling_dead_object(self):
48874887
with self.assertRaises(FileNotFoundError):
48884888
pickle.loads(serialized_sl)
48894889

4890+
def test_shared_memory_ShareableList_trailing_nulls(self):
4891+
# gh-106939: ShareableList should preserve trailing null bytes
4892+
# in bytes and str values.
4893+
sl = shared_memory.ShareableList([
4894+
b'\x03\x02\x01\x00\x00\x00',
4895+
'?\x00',
4896+
b'\x00\x00\x00',
4897+
b'',
4898+
b'no nulls',
4899+
])
4900+
self.addCleanup(sl.shm.unlink)
4901+
4902+
self.assertEqual(sl[0], b'\x03\x02\x01\x00\x00\x00')
4903+
self.assertEqual(sl[1], '?\x00')
4904+
self.assertEqual(sl[2], b'\x00\x00\x00')
4905+
self.assertEqual(sl[3], b'')
4906+
self.assertEqual(sl[4], b'no nulls')
4907+
4908+
sl2 = shared_memory.ShareableList(name=sl.shm.name)
4909+
self.assertEqual(sl2[0], b'\x03\x02\x01\x00\x00\x00')
4910+
self.assertEqual(sl2[1], '?\x00')
4911+
self.assertEqual(sl2[2], b'\x00\x00\x00')
4912+
self.assertEqual(sl2[3], b'')
4913+
self.assertEqual(sl2[4], b'no nulls')
4914+
sl2.shm.close()
4915+
4916+
sl.shm.close()
4917+
48904918
def test_shared_memory_cleaned_after_process_termination(self):
48914919
cmd = '''if 1:
48924920
import os, time, sys
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
:class:`~multiprocessing.shared_memory.ShareableList` now preserves trailing
2+
null bytes in :class:`bytes` and :class:`str` values. Previously they were
3+
silently stripped.

0 commit comments

Comments
 (0)