diff --git a/src/openai/lib/streaming/_assistants.py b/src/openai/lib/streaming/_assistants.py index 6efb3ca3f1..0a9ae366ab 100644 --- a/src/openai/lib/streaming/_assistants.py +++ b/src/openai/lib/streaming/_assistants.py @@ -977,15 +977,70 @@ def accumulate_event( return current_message_snapshot, new_content +def _find_list_entry(acc_value: list[object], index: int) -> int | None: + for acc_index, acc_entry in enumerate(acc_value): + if is_dict(acc_entry) and acc_entry.get("index") == index: + return acc_index + + return None + + +def _has_indexed_entries(delta_value: list[object]) -> bool: + return any(is_dict(delta_entry) and "index" in delta_entry for delta_entry in delta_value) + + +def _accumulate_list(acc_value: list[object], delta_value: list[object]) -> list[object]: + # for lists of non-dictionary items we'll only ever get new entries + # in the array, existing entries will never be changed + if all(isinstance(x, (str, int, float)) for x in acc_value) and all( + isinstance(x, (str, int, float)) for x in delta_value + ): + acc_value.extend(delta_value) + return acc_value + + for delta_entry in delta_value: + if not is_dict(delta_entry): + raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") + + try: + index = delta_entry["index"] + except KeyError as exc: + raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc + + if not isinstance(index, int): + raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") + + acc_index = _find_list_entry(acc_value, index) + if acc_index is None: + acc_value.insert(min(index, len(acc_value)), delta_entry) + continue + + acc_entry = acc_value[acc_index] + if not is_dict(acc_entry): + raise TypeError("not handled yet") + + acc_value[acc_index] = accumulate_delta(acc_entry, delta_entry) + + return acc_value + + def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: - acc[key] = delta_value + acc[key] = ( + _accumulate_list([], delta_value) + if is_list(delta_value) and _has_indexed_entries(delta_value) + else delta_value + ) continue acc_value = acc[key] if acc_value is None: - acc[key] = delta_value + acc[key] = ( + _accumulate_list([], delta_value) + if is_list(delta_value) and _has_indexed_entries(delta_value) + else delta_value + ) continue # the `index` property is used in arrays of objects so it should @@ -1005,33 +1060,7 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> elif is_dict(acc_value) and is_dict(delta_value): acc_value = accumulate_delta(acc_value, delta_value) elif is_list(acc_value) and is_list(delta_value): - # for lists of non-dictionary items we'll only ever get new entries - # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): - acc_value.extend(delta_value) - continue - - for delta_entry in delta_value: - if not is_dict(delta_entry): - raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") - - try: - index = delta_entry["index"] - except KeyError as exc: - raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc - - if not isinstance(index, int): - raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") - - try: - acc_entry = acc_value[index] - except IndexError: - acc_value.insert(index, delta_entry) - else: - if not is_dict(acc_entry): - raise TypeError("not handled yet") - - acc_value[index] = accumulate_delta(acc_entry, delta_entry) + acc_value = _accumulate_list(acc_value, delta_value) acc[key] = acc_value diff --git a/src/openai/lib/streaming/_deltas.py b/src/openai/lib/streaming/_deltas.py index a5e1317612..de9ccfe252 100644 --- a/src/openai/lib/streaming/_deltas.py +++ b/src/openai/lib/streaming/_deltas.py @@ -3,15 +3,70 @@ from ..._utils import is_dict, is_list +def _find_list_entry(acc_value: list[object], index: int) -> int | None: + for acc_index, acc_entry in enumerate(acc_value): + if is_dict(acc_entry) and acc_entry.get("index") == index: + return acc_index + + return None + + +def _has_indexed_entries(delta_value: list[object]) -> bool: + return any(is_dict(delta_entry) and "index" in delta_entry for delta_entry in delta_value) + + +def _accumulate_list(acc_value: list[object], delta_value: list[object]) -> list[object]: + # for lists of non-dictionary items we'll only ever get new entries + # in the array, existing entries will never be changed + if all(isinstance(x, (str, int, float)) for x in acc_value) and all( + isinstance(x, (str, int, float)) for x in delta_value + ): + acc_value.extend(delta_value) + return acc_value + + for delta_entry in delta_value: + if not is_dict(delta_entry): + raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") + + try: + index = delta_entry["index"] + except KeyError as exc: + raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc + + if not isinstance(index, int): + raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") + + acc_index = _find_list_entry(acc_value, index) + if acc_index is None: + acc_value.insert(min(index, len(acc_value)), delta_entry) + continue + + acc_entry = acc_value[acc_index] + if not is_dict(acc_entry): + raise TypeError("not handled yet") + + acc_value[acc_index] = accumulate_delta(acc_entry, delta_entry) + + return acc_value + + def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: - acc[key] = delta_value + acc[key] = ( + _accumulate_list([], delta_value) + if is_list(delta_value) and _has_indexed_entries(delta_value) + else delta_value + ) continue acc_value = acc[key] if acc_value is None: - acc[key] = delta_value + acc[key] = ( + _accumulate_list([], delta_value) + if is_list(delta_value) and _has_indexed_entries(delta_value) + else delta_value + ) continue # the `index` property is used in arrays of objects so it should @@ -31,33 +86,7 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> elif is_dict(acc_value) and is_dict(delta_value): acc_value = accumulate_delta(acc_value, delta_value) elif is_list(acc_value) and is_list(delta_value): - # for lists of non-dictionary items we'll only ever get new entries - # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): - acc_value.extend(delta_value) - continue - - for delta_entry in delta_value: - if not is_dict(delta_entry): - raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") - - try: - index = delta_entry["index"] - except KeyError as exc: - raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc - - if not isinstance(index, int): - raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") - - try: - acc_entry = acc_value[index] - except IndexError: - acc_value.insert(index, delta_entry) - else: - if not is_dict(acc_entry): - raise TypeError("not handled yet") - - acc_value[index] = accumulate_delta(acc_entry, delta_entry) + acc_value = _accumulate_list(acc_value, delta_value) acc[key] = acc_value diff --git a/tests/lib/test_streaming_deltas.py b/tests/lib/test_streaming_deltas.py new file mode 100644 index 0000000000..543edcf4a0 --- /dev/null +++ b/tests/lib/test_streaming_deltas.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import json +from typing import Any, cast +from collections.abc import Callable + +import pytest + +from openai.lib.streaming._deltas import accumulate_delta as accumulate_chat_delta +from openai.lib.streaming._assistants import accumulate_delta as accumulate_assistant_delta + + +@pytest.mark.parametrize("accumulate", [accumulate_chat_delta, accumulate_assistant_delta]) +def test_accumulate_delta_merges_duplicate_index_entries_in_initial_list( + accumulate: Callable[[dict[object, object], dict[object, object]], dict[object, object]], +) -> None: + acc: dict[object, object] = {} + + accumulate( + acc, + { + "tool_calls": [ + { + "index": 0, + "id": "functions.list_files:0", + "function": {"name": "list_files"}, + "type": "function", + }, + {"index": 0, "function": {"arguments": ' {"path"'}}, + ], + }, + ) + accumulate( + acc, + { + "tool_calls": [ + {"index": 0, "function": {"arguments": ': "."}'}}, + ], + }, + ) + + tool_calls = acc["tool_calls"] + assert isinstance(tool_calls, list) + assert tool_calls == [ + { + "index": 0, + "id": "functions.list_files:0", + "function": {"name": "list_files", "arguments": ' {"path": "."}'}, + "type": "function", + } + ] + tool_call = cast(dict[str, Any], tool_calls[0]) + function = cast(dict[str, str], tool_call["function"]) + assert json.loads(function["arguments"]) == {"path": "."}