diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index daed44f5014..96c7665ec7b 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -98,8 +98,12 @@ class BaseMessageChunk(BaseMessage): merged[k] = v elif merged[k] is None and v: merged[k] = v + elif v is None: + continue + elif merged[k] == v: + continue elif type(merged[k]) != type(v): - raise ValueError( + raise TypeError( f'additional_kwargs["{k}"] already exists in this message,' " but with a different type." ) @@ -107,8 +111,17 @@ class BaseMessageChunk(BaseMessage): merged[k] += v elif isinstance(merged[k], dict): merged[k] = self._merge_kwargs_dict(merged[k], v) + elif isinstance(merged[k], list): + merged[k] = merged[k].copy() + for i, e in enumerate(v): + if isinstance(e, dict) and isinstance(e.get("index"), int): + i = e["index"] + if i < len(merged[k]): + merged[k][i] = self._merge_kwargs_dict(merged[k][i], e) + else: + merged[k] = merged[k] + [e] else: - raise ValueError( + raise TypeError( f"Additional kwargs key {k} already exists in this message." ) return merged diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index bcb8bc88ce4..95d60a52f2b 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -1,4 +1,5 @@ import unittest +from typing import List import pytest @@ -203,3 +204,227 @@ def test_message_chunk_to_message() -> None: assert message_chunk_to_message( FunctionMessageChunk(name="hello", content="I am") ) == FunctionMessage(name="hello", content="I am") + + +def test_tool_calls_merge() -> None: + chunks: List[dict] = [ + dict(content=""), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": "call_CwGAsESnXehQEjiAIWzinlva", + "function": {"arguments": "", "name": "person"}, + "type": "function", + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": '{"na', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'me": ', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": '"jane"', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ', "a', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'ge": ', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": "2}", "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": "call_zXSIylHvc5x3JUAPcHZR5GZI", + "function": {"arguments": "", "name": "person"}, + "type": "function", + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": '{"na', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": 'me": ', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": '"bob",', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": ' "ag', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": 'e": 3', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": "}", "name": None}, + "type": None, + } + ] + }, + ), + dict(content=""), + ] + + final = None + + for chunk in chunks: + msg = AIMessageChunk(**chunk) + if final is None: + final = msg + else: + final = final + msg + + assert final == AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": "call_CwGAsESnXehQEjiAIWzinlva", + "function": { + "arguments": '{"name": "jane", "age": 2}', + "name": "person", + }, + "type": "function", + }, + { + "index": 1, + "id": "call_zXSIylHvc5x3JUAPcHZR5GZI", + "function": { + "arguments": '{"name": "bob", "age": 3}', + "name": "person", + }, + "type": "function", + }, + ] + }, + )