From e69bf216fea46b4e2aadf7b09e3db775690cb6c2 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Mon, 8 Jun 2026 00:24:08 +0800 Subject: [PATCH] fix: merge streaming tool calls by logical index --- src/openai/lib/streaming/_assistants.py | 74 +++++++++++------ src/openai/lib/streaming/_deltas.py | 80 ++++++++++++------- src/openai/lib/streaming/chat/_completions.py | 2 +- tests/lib/chat/test_completions_streaming.py | 75 +++++++++++++++++ tests/lib/test_streaming_deltas.py | 51 ++++++++++++ 5 files changed, 226 insertions(+), 56 deletions(-) create mode 100644 tests/lib/test_streaming_deltas.py diff --git a/src/openai/lib/streaming/_assistants.py b/src/openai/lib/streaming/_assistants.py index 6efb3ca3f1..ed15fec8ab 100644 --- a/src/openai/lib/streaming/_assistants.py +++ b/src/openai/lib/streaming/_assistants.py @@ -980,12 +980,18 @@ def accumulate_event( def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: - acc[key] = delta_value + if is_list(delta_value): + acc[key] = _accumulate_list_delta([], delta_value) + else: + acc[key] = delta_value continue acc_value = acc[key] if acc_value is None: - acc[key] = delta_value + if is_list(delta_value): + acc[key] = _accumulate_list_delta([], delta_value) + else: + acc[key] = delta_value continue # the `index` property is used in arrays of objects so it should @@ -1005,34 +1011,50 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> elif is_dict(acc_value) and is_dict(delta_value): acc_value = accumulate_delta(acc_value, delta_value) elif is_list(acc_value) and is_list(delta_value): - # for lists of non-dictionary items we'll only ever get new entries - # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): - acc_value.extend(delta_value) - continue + acc_value = _accumulate_list_delta(acc_value, delta_value) - for delta_entry in delta_value: - if not is_dict(delta_entry): - raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") + acc[key] = acc_value - try: - index = delta_entry["index"] - except KeyError as exc: - raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc + return acc - if not isinstance(index, int): - raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") - try: - acc_entry = acc_value[index] - except IndexError: - acc_value.insert(index, delta_entry) - else: - if not is_dict(acc_entry): - raise TypeError("not handled yet") +def _accumulate_list_delta(acc_value: list[object], delta_value: list[object]) -> list[object]: + if all(isinstance(x, (str, int, float)) for x in [*acc_value, *delta_value]): + acc_value.extend(delta_value) + return acc_value - acc_value[index] = accumulate_delta(acc_entry, delta_entry) + for delta_entry in delta_value: + if not is_dict(delta_entry): + raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") - acc[key] = acc_value + try: + index = delta_entry["index"] + except KeyError as exc: + raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc - return acc + if not isinstance(index, int): + raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") + + for acc_index, acc_entry in enumerate(acc_value): + if not is_dict(acc_entry): + raise TypeError("not handled yet") + + if acc_entry.get("index") == index: + acc_value[acc_index] = accumulate_delta(acc_entry, delta_entry) + break + else: + acc_value.append(delta_entry) + + acc_value.sort(key=_list_delta_sort_key) + return acc_value + + +def _list_delta_sort_key(entry: object) -> int: + if not is_dict(entry): + return 0 + + index = entry.get("index") + if not isinstance(index, int): + return 0 + + return index diff --git a/src/openai/lib/streaming/_deltas.py b/src/openai/lib/streaming/_deltas.py index a5e1317612..ccd0c72963 100644 --- a/src/openai/lib/streaming/_deltas.py +++ b/src/openai/lib/streaming/_deltas.py @@ -3,15 +3,63 @@ from ..._utils import is_dict, is_list +def _accumulate_list_delta(acc_value: list[object], delta_value: list[object]) -> list[object]: + if all(isinstance(x, (str, int, float)) for x in [*acc_value, *delta_value]): + acc_value.extend(delta_value) + return acc_value + + for delta_entry in delta_value: + if not is_dict(delta_entry): + raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") + + try: + index = delta_entry["index"] + except KeyError as exc: + raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc + + if not isinstance(index, int): + raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") + + for acc_index, acc_entry in enumerate(acc_value): + if not is_dict(acc_entry): + raise TypeError("not handled yet") + + if acc_entry.get("index") == index: + acc_value[acc_index] = accumulate_delta(acc_entry, delta_entry) + break + else: + acc_value.append(delta_entry) + + acc_value.sort(key=_list_delta_sort_key) + return acc_value + + +def _list_delta_sort_key(entry: object) -> int: + if not is_dict(entry): + return 0 + + index = entry.get("index") + if not isinstance(index, int): + return 0 + + return index + + def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: - acc[key] = delta_value + if is_list(delta_value): + acc[key] = _accumulate_list_delta([], delta_value) + else: + acc[key] = delta_value continue acc_value = acc[key] if acc_value is None: - acc[key] = delta_value + if is_list(delta_value): + acc[key] = _accumulate_list_delta([], delta_value) + else: + acc[key] = delta_value continue # the `index` property is used in arrays of objects so it should @@ -31,33 +79,7 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> elif is_dict(acc_value) and is_dict(delta_value): acc_value = accumulate_delta(acc_value, delta_value) elif is_list(acc_value) and is_list(delta_value): - # for lists of non-dictionary items we'll only ever get new entries - # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): - acc_value.extend(delta_value) - continue - - for delta_entry in delta_value: - if not is_dict(delta_entry): - raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") - - try: - index = delta_entry["index"] - except KeyError as exc: - raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc - - if not isinstance(index, int): - raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") - - try: - acc_entry = acc_value[index] - except IndexError: - acc_value.insert(index, delta_entry) - else: - if not is_dict(acc_entry): - raise TypeError("not handled yet") - - acc_value[index] = accumulate_delta(acc_entry, delta_entry) + acc_value = _accumulate_list_delta(acc_value, delta_value) acc[key] = acc_value diff --git a/src/openai/lib/streaming/chat/_completions.py b/src/openai/lib/streaming/chat/_completions.py index 5f072cafbd..27e02e7753 100644 --- a/src/openai/lib/streaming/chat/_completions.py +++ b/src/openai/lib/streaming/chat/_completions.py @@ -744,7 +744,7 @@ def _convert_initial_chunk_into_snapshot(chunk: ChatCompletionChunk) -> ParsedCh for choice in chunk.choices: choices[choice.index] = { **choice.model_dump(exclude_unset=True, exclude={"delta"}), - "message": choice.delta.to_dict(), + "message": accumulate_delta({}, cast("dict[object, object]", choice.delta.to_dict())), } return cast( diff --git a/tests/lib/chat/test_completions_streaming.py b/tests/lib/chat/test_completions_streaming.py index 598a41ee2b..5c3ebacbef 100644 --- a/tests/lib/chat/test_completions_streaming.py +++ b/tests/lib/chat/test_completions_streaming.py @@ -20,6 +20,7 @@ from openai import OpenAI, AsyncOpenAI from openai._utils import consume_sync_iterator, assert_signatures_in_sync from openai._compat import model_copy +from openai._models import construct_type from openai.types.chat import ChatCompletionChunk from openai.lib.streaming.chat import ( ContentDoneEvent, @@ -834,6 +835,80 @@ class GetStockPrice(BaseModel): ) +def test_streaming_tool_calls_merge_duplicate_indexes_in_initial_chunk() -> None: + state = ChatCompletionStreamState() + + first_chunk = cast( + ChatCompletionChunk, + construct_type( + type_=ChatCompletionChunk, + value={ + "id": "chatcmpl-duplicate-tool-index", + "object": "chat.completion.chunk", + "created": 0, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "tool_calls": [ + { + "index": 0, + "id": "call_abc", + "type": "function", + "function": {"name": "list_files"}, + }, + { + "index": 0, + "function": {"arguments": ' {"'}, + }, + ], + }, + "finish_reason": None, + } + ], + }, + ), + ) + second_chunk = cast( + ChatCompletionChunk, + construct_type( + type_=ChatCompletionChunk, + value={ + "id": "chatcmpl-duplicate-tool-index", + "object": "chat.completion.chunk", + "created": 0, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 0, + "function": {"arguments": 'path": "."}'}, + }, + ], + }, + "finish_reason": "tool_calls", + } + ], + }, + ), + ) + + state.handle_chunk(first_chunk) + state.handle_chunk(second_chunk) + + tool_calls = state.get_final_completion().choices[0].message.tool_calls + + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "list_files" + assert tool_calls[0].function.arguments == ' {"path": "."}' + + @pytest.mark.respx(base_url=base_url) def test_parse_strict_tools(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None: listener = _make_stream_snapshot_request( diff --git a/tests/lib/test_streaming_deltas.py b/tests/lib/test_streaming_deltas.py new file mode 100644 index 0000000000..96caef5fcb --- /dev/null +++ b/tests/lib/test_streaming_deltas.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from openai.lib.streaming import _deltas, _assistants + + +def test_accumulate_delta_merges_duplicate_indexes_on_first_chunk() -> None: + acc: dict[object, object] = {} + + _deltas.accumulate_delta( + acc, + { + "tool_calls": [ + {"index": 0, "id": "call_abc", "type": "function", "function": {"name": "list_files"}}, + {"index": 0, "function": {"arguments": ' {"'}}, + ] + }, + ) + _deltas.accumulate_delta(acc, {"tool_calls": [{"index": 0, "function": {"arguments": 'path": "."}'}}]}) + + assert acc["tool_calls"] == [ + { + "index": 0, + "id": "call_abc", + "type": "function", + "function": {"name": "list_files", "arguments": ' {"path": "."}'}, + } + ] + + +def test_assistants_accumulate_delta_merges_duplicate_indexes_on_first_chunk() -> None: + acc: dict[object, object] = {} + + _assistants.accumulate_delta( + acc, + { + "tool_calls": [ + {"index": 0, "id": "call_abc", "type": "function", "function": {"name": "list_files"}}, + {"index": 0, "function": {"arguments": ' {"'}}, + ] + }, + ) + _assistants.accumulate_delta(acc, {"tool_calls": [{"index": 0, "function": {"arguments": 'path": "."}'}}]}) + + assert acc["tool_calls"] == [ + { + "index": 0, + "id": "call_abc", + "type": "function", + "function": {"name": "list_files", "arguments": ' {"path": "."}'}, + } + ]