Skip to content

Commit 2a11d4b

Browse files
gpshead and claude committed
Refactor run_pipeline() to use multiplexed I/O
Add _communicate_streams() helper function that properly multiplexes read/write operations to prevent pipe buffer deadlocks. The helper uses selectors on POSIX and threads on Windows, similar to Popen.communicate(). This fixes potential deadlocks when large amounts of data flow through the pipeline and significantly improves performance. Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 4feb2a8 commit 2a11d4b

File tree

1 file changed

+269
-49
lines changed

1 file changed

+269
-49
lines changed

Lib/subprocess.py

Lines changed: 269 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,220 @@ def _cleanup():
320320
DEVNULL = -3
321321

322322

323+
# Helper function for multiplexed I/O, used by run_pipeline()
324+
def _remaining_time_helper(endtime):
325+
"""Calculate remaining time until deadline."""
326+
if endtime is None:
327+
return None
328+
return endtime - _time()
329+
330+
331+
def _communicate_streams(stdin=None, input_data=None, read_streams=None,
                         timeout=None, cmd_for_timeout=None):
    """
    Feed input_data to stdin while draining read_streams at the same time.

    Accepts both file objects and raw file descriptors as read streams.
    Everything is transferred as bytes; any text encoding or decoding is
    the caller's responsibility.

    Args:
        stdin: Writable file object to receive input_data, or None
        input_data: Bytes to send to stdin, or None
        read_streams: List of readable file objects or raw fds to drain;
            None is treated as an empty list
        timeout: Number of seconds to allow, or None to wait indefinitely
        cmd_for_timeout: Value reported as TimeoutExpired.cmd

    Returns:
        Dict mapping each item of read_streams to the bytes read from it

    Raises:
        TimeoutExpired: If the timeout elapses first (carries partial data)
    """
    # Turn the relative timeout into an absolute deadline exactly once so
    # the platform implementations can keep re-deriving the time left.
    endtime = None if timeout is None else _time() + timeout

    streams = read_streams if read_streams else []

    # Only the implementation matching the current platform is defined.
    if _mswindows:
        implementation = _communicate_streams_windows
    else:
        implementation = _communicate_streams_posix
    return implementation(
        stdin, input_data, streams, endtime, timeout, cmd_for_timeout)
365+
366+
367+
if _mswindows:
368+
def _reader_thread_func(fh, buffer):
369+
"""Thread function to read from a file handle into a buffer list."""
370+
try:
371+
buffer.append(fh.read())
372+
except OSError:
373+
buffer.append(b'')
374+
375+
def _communicate_streams_windows(stdin, input_data, read_streams,
376+
endtime, orig_timeout, cmd_for_timeout):
377+
"""Windows implementation using threads."""
378+
threads = []
379+
buffers = {}
380+
fds_to_close = []
381+
382+
# Start reader threads
383+
for stream in read_streams:
384+
buf = []
385+
buffers[stream] = buf
386+
# Wrap raw fds in file objects
387+
if isinstance(stream, int):
388+
fobj = os.fdopen(os.dup(stream), 'rb')
389+
fds_to_close.append(stream)
390+
else:
391+
fobj = stream
392+
t = threading.Thread(target=_reader_thread_func, args=(fobj, buf))
393+
t.daemon = True
394+
t.start()
395+
threads.append((stream, t, fobj))
396+
397+
# Write stdin
398+
if stdin and input_data:
399+
try:
400+
stdin.write(input_data)
401+
except BrokenPipeError:
402+
pass
403+
except OSError as exc:
404+
if exc.errno != errno.EINVAL:
405+
raise
406+
if stdin:
407+
try:
408+
stdin.close()
409+
except BrokenPipeError:
410+
pass
411+
except OSError as exc:
412+
if exc.errno != errno.EINVAL:
413+
raise
414+
415+
# Join threads with timeout
416+
for stream, t, fobj in threads:
417+
remaining = _remaining_time_helper(endtime)
418+
if remaining is not None and remaining < 0:
419+
remaining = 0
420+
t.join(remaining)
421+
if t.is_alive():
422+
# Collect partial results
423+
results = {s: (b[0] if b else b'') for s, b in buffers.items()}
424+
raise TimeoutExpired(
425+
cmd_for_timeout, orig_timeout,
426+
output=results.get(read_streams[0]) if read_streams else None)
427+
428+
# Close any raw fds we duped
429+
for fd in fds_to_close:
430+
try:
431+
os.close(fd)
432+
except OSError:
433+
pass
434+
435+
# Collect results
436+
return {stream: (buf[0] if buf else b'') for stream, buf in buffers.items()}
437+
438+
else:
439+
def _communicate_streams_posix(stdin, input_data, read_streams,
440+
endtime, orig_timeout, cmd_for_timeout):
441+
"""POSIX implementation using selectors."""
442+
# Normalize read_streams: build mapping of fd -> (original_key, chunks)
443+
fd_info = {} # fd -> (original_stream, chunks_list)
444+
for stream in read_streams:
445+
if isinstance(stream, int):
446+
fd = stream
447+
else:
448+
fd = stream.fileno()
449+
fd_info[fd] = (stream, [])
450+
451+
# Prepare stdin
452+
stdin_fd = None
453+
if stdin:
454+
try:
455+
stdin.flush()
456+
except BrokenPipeError:
457+
pass
458+
if input_data:
459+
stdin_fd = stdin.fileno()
460+
else:
461+
try:
462+
stdin.close()
463+
except BrokenPipeError:
464+
pass
465+
466+
# Prepare input data
467+
input_offset = 0
468+
input_view = memoryview(input_data) if input_data else None
469+
470+
with _PopenSelector() as selector:
471+
if stdin_fd is not None and input_data:
472+
selector.register(stdin_fd, selectors.EVENT_WRITE)
473+
for fd in fd_info:
474+
selector.register(fd, selectors.EVENT_READ)
475+
476+
while selector.get_map():
477+
remaining = _remaining_time_helper(endtime)
478+
if remaining is not None and remaining < 0:
479+
# Timed out - collect partial results
480+
results = {orig: b''.join(chunks)
481+
for fd, (orig, chunks) in fd_info.items()}
482+
raise TimeoutExpired(
483+
cmd_for_timeout, orig_timeout,
484+
output=results.get(read_streams[0]) if read_streams else None)
485+
486+
ready = selector.select(remaining)
487+
488+
# Check timeout after select
489+
if endtime is not None and _time() > endtime:
490+
results = {orig: b''.join(chunks)
491+
for fd, (orig, chunks) in fd_info.items()}
492+
raise TimeoutExpired(
493+
cmd_for_timeout, orig_timeout,
494+
output=results.get(read_streams[0]) if read_streams else None)
495+
496+
for key, events in ready:
497+
if key.fd == stdin_fd:
498+
# Write chunk to stdin
499+
chunk = input_view[input_offset:input_offset + _PIPE_BUF]
500+
try:
501+
input_offset += os.write(key.fd, chunk)
502+
except BrokenPipeError:
503+
selector.unregister(key.fd)
504+
try:
505+
stdin.close()
506+
except BrokenPipeError:
507+
pass
508+
else:
509+
if input_offset >= len(input_data):
510+
selector.unregister(key.fd)
511+
try:
512+
stdin.close()
513+
except BrokenPipeError:
514+
pass
515+
elif key.fd in fd_info:
516+
# Read chunk from output stream
517+
data = os.read(key.fd, 32768)
518+
if not data:
519+
selector.unregister(key.fd)
520+
else:
521+
fd_info[key.fd][1].append(data)
522+
523+
# Build results: map original stream keys to joined data
524+
results = {}
525+
for fd, (orig_stream, chunks) in fd_info.items():
526+
results[orig_stream] = b''.join(chunks)
527+
# Close file objects (but not raw fds - caller manages those)
528+
if not isinstance(orig_stream, int):
529+
try:
530+
orig_stream.close()
531+
except OSError:
532+
pass
533+
534+
return results
535+
536+
323537
# XXX This function is only used by multiprocessing and the test suite,
324538
# but it's here so that it can be imported when Python is compiled without
325539
# threads.
@@ -781,54 +995,70 @@ def run_pipeline(*commands, input=None, capture_output=False, timeout=None,
781995
first_proc = processes[0]
782996
last_proc = processes[-1]
783997

784-
# Handle communication with timeout
785-
start_time = _time() if timeout is not None else None
786-
787-
# Write input to first process if provided
788-
if input is not None and first_proc.stdin is not None:
789-
try:
790-
first_proc.stdin.write(input)
791-
except BrokenPipeError:
792-
pass # First process may have exited early
793-
finally:
794-
first_proc.stdin.close()
998+
# Calculate deadline for timeout (used throughout)
999+
if timeout is not None:
1000+
endtime = _time() + timeout
1001+
else:
1002+
endtime = None
7951003

7961004
# Determine if we're in text mode
7971005
text_mode = kwargs.get('text') or kwargs.get('encoding') or kwargs.get('errors')
1006+
encoding = kwargs.get('encoding')
1007+
errors_param = kwargs.get('errors', 'strict')
1008+
if text_mode and encoding is None:
1009+
encoding = locale.getencoding()
1010+
1011+
# Encode input if in text mode
1012+
input_data = input
1013+
if input_data is not None and text_mode:
1014+
input_data = input_data.encode(encoding, errors_param)
1015+
1016+
# Build list of streams to read from
1017+
read_streams = []
1018+
if last_proc.stdout is not None:
1019+
read_streams.append(last_proc.stdout)
1020+
if stderr_read_fd is not None:
1021+
read_streams.append(stderr_read_fd)
7981022

799-
# Read output from the last process
800-
stdout = None
801-
stderr = None
1023+
# Use multiplexed I/O to handle stdin/stdout/stderr concurrently
1024+
# This avoids deadlocks from pipe buffer limits
1025+
stdin_stream = first_proc.stdin if input is not None else None
8021026

803-
# Read stdout if we created a pipe for it (capture_output or stdout=PIPE)
804-
if last_proc.stdout is not None:
805-
stdout = last_proc.stdout.read()
1027+
try:
1028+
results = _communicate_streams(
1029+
stdin=stdin_stream,
1030+
input_data=input_data,
1031+
read_streams=read_streams,
1032+
timeout=_remaining_time_helper(endtime),
1033+
cmd_for_timeout=commands,
1034+
)
1035+
except TimeoutExpired:
1036+
# Kill all processes on timeout
1037+
for p in processes:
1038+
if p.poll() is None:
1039+
p.kill()
1040+
for p in processes:
1041+
p.wait()
1042+
raise
8061043

807-
# Read stderr from the shared pipe
808-
if stderr_read_fd is not None:
809-
stderr = os.read(stderr_read_fd, 1024 * 1024 * 10) # Up to 10MB
810-
# Keep reading until EOF
811-
while True:
812-
chunk = os.read(stderr_read_fd, 65536)
813-
if not chunk:
814-
break
815-
stderr += chunk
816-
817-
# Calculate remaining timeout
818-
def remaining_timeout():
819-
if timeout is None:
820-
return None
821-
elapsed = _time() - start_time
822-
remaining = timeout - elapsed
823-
if remaining <= 0:
824-
raise TimeoutExpired(commands, timeout, stdout, stderr)
825-
return remaining
1044+
# Extract results
1045+
stdout = results.get(last_proc.stdout)
1046+
stderr = results.get(stderr_read_fd)
8261047

827-
# Wait for all processes to complete
1048+
# Decode stdout if in text mode (Popen text mode only applies to
1049+
# streams it creates, but we read via _communicate_streams which
1050+
# always returns bytes)
1051+
if text_mode and stdout is not None:
1052+
stdout = stdout.decode(encoding, errors_param)
1053+
if text_mode and stderr is not None:
1054+
stderr = stderr.decode(encoding, errors_param)
1055+
1056+
# Wait for all processes to complete (use remaining time from deadline)
8281057
returncodes = []
8291058
for proc in processes:
8301059
try:
831-
proc.wait(timeout=remaining_timeout())
1060+
remaining = _remaining_time_helper(endtime)
1061+
proc.wait(timeout=remaining)
8321062
except TimeoutExpired:
8331063
# Kill all processes on timeout
8341064
for p in processes:
@@ -839,16 +1069,6 @@ def remaining_timeout():
8391069
raise TimeoutExpired(commands, timeout, stdout, stderr)
8401070
returncodes.append(proc.returncode)
8411071

842-
# Handle text mode conversion for stderr (stdout is already handled
843-
# by Popen when text=True). stderr is always read as bytes since
844-
# we use os.pipe() directly.
845-
if text_mode and stderr is not None:
846-
encoding = kwargs.get('encoding')
847-
errors = kwargs.get('errors', 'strict')
848-
if encoding is None:
849-
encoding = locale.getencoding()
850-
stderr = stderr.decode(encoding, errors)
851-
8521072
result = PipelineResult(commands, returncodes, stdout, stderr)
8531073

8541074
if check and any(rc != 0 for rc in returncodes):
@@ -867,7 +1087,7 @@ def remaining_timeout():
8671087
proc.stdin.close()
8681088
if proc.stdout and not proc.stdout.closed:
8691089
proc.stdout.close()
870-
# Close stderr pipe file descriptors
1090+
# Close stderr pipe file descriptor
8711091
if stderr_read_fd is not None:
8721092
try:
8731093
os.close(stderr_read_fd)

0 commit comments

Comments
 (0)