From 992884b410ad63d93cc3919d53cc211698980db5 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 29 Aug 2024 17:18:41 -0700 Subject: [PATCH] fmt --- libs/core/langchain_core/messages/utils.py | 367 +++++++---- .../tests/unit_tests/messages/test_utils.py | 581 +++++++++++++++++- libs/langchain/langchain/chat_models/base.py | 4 +- 3 files changed, 825 insertions(+), 127 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index c7fb622bd25..e39db43286b 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -20,6 +20,7 @@ from typing import ( Callable, Dict, Iterable, + Iterator, List, Literal, Optional, @@ -28,7 +29,7 @@ from typing import ( Type, Union, cast, - overload, Iterator, + overload, ) from langchain_core.messages.ai import AIMessage, AIMessageChunk @@ -41,7 +42,9 @@ from langchain_core.messages.system import SystemMessage, SystemMessageChunk from langchain_core.messages.tool import ToolMessage, ToolMessageChunk from langchain_core.messages.tool import ( tool_call as create_tool_call, - tool_call_chunk as create_tool_call_chunk +) +from langchain_core.messages.tool import ( + tool_call_chunk as create_tool_call_chunk, ) if TYPE_CHECKING: @@ -335,50 +338,6 @@ def convert_to_messages( return [_convert_to_message(m, copy=copy) for m in messages] -def _runnable_generator(func: Callable) -> Callable: - @overload - def wrapped( - messages: Literal[None] = None, **kwargs: Any - ) -> Runnable[ - Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], - Union[BaseMessage, List[BaseMessage]], - ]: ... - - @overload - def wrapped( - messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any - ) -> List[BaseMessage]: ... - - @overload - def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ... - - def wrapped( - messages: Union[ - MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None - ] = None, - **kwargs: Any, - ) -> Union[ - BaseMessage, - List[BaseMessage], - Runnable[ - Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], - Union[BaseMessage, List[BaseMessage]], - ], - ]: - from langchain_core.runnables.base import RunnableGenerator - - if messages is not None: - return func(messages, **kwargs) - else: - def transform(input_: Iterator, **kwargs: Any) -> Iterator: - for x in input_: - yield func(x, **kwargs) - return RunnableGenerator(partial(transform, **kwargs), name=func.__name__) - - wrapped.__doc__ = func.__doc__ - return wrapped - - def _runnable_support(func: Callable) -> Callable: @overload def wrapped( @@ -903,13 +862,84 @@ def trim_messages( ) +def _runnable_generator(func: Callable) -> Callable: + @overload + def wrapped( + messages: Literal[None] = None, **kwargs: Any + ) -> Runnable[ + Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], + Union[BaseMessage, List[BaseMessage]], + ]: ... + + @overload + def wrapped( + messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any + ) -> List[BaseMessage]: ... + + @overload + def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ... 
+
+    def wrapped(
+        messages: Union[
+            MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None
+        ] = None,
+        **kwargs: Any,
+    ) -> Union[
+        BaseMessage,
+        List[BaseMessage],
+        Runnable[
+            Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
+            Union[BaseMessage, List[BaseMessage]],
+        ],
+    ]:
+        from langchain_core.runnables.base import RunnableGenerator
+
+        if messages is not None:
+            return func(messages, **kwargs)
+        else:
+
+            def transform(input_: Iterator, **kwargs: Any) -> Iterator:
+                block_indexes = set()
+                for x in input_:
+                    msg = func(x, **kwargs)
+                    # Special handling for transforming an OpenAI stream to an
+                    # Anthropic stream.
+                    if isinstance(msg, AIMessageChunk) and isinstance(
+                        msg.content, list
+                    ):
+                        tool_use_ct = 0
+                        for block in msg.content:
+                            if not isinstance(block, dict):
+                                continue
+                            if "index" in block:
+                                block_indexes.add(block["index"])
+                            elif block.get("type") == "tool_use":
+                                block["index"] = max(
+                                    [len(block_indexes), *block_indexes]
+                                )
+                                msg.tool_call_chunks[tool_use_ct]["index"] = block[
+                                    "index"
+                                ]
+                            else:
+                                pass
+
+                            if block.get("type") == "tool_use":
+                                tool_use_ct += 1
+                    else:
+                        block_indexes.add(0)
+                    yield msg
+
+            return RunnableGenerator(partial(transform, **kwargs), name=func.__name__)
+
+    wrapped.__doc__ = func.__doc__
+    return wrapped
+
+
 @_runnable_generator
 def format_messages_as(
-    messages: Union[MessageLikeRepresentation, Iterator[MessageLikeRepresentation]],
+    messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
     *,
     format: Literal["openai", "anthropic"],
     text: Literal["string", "block"],
-) -> Union[BaseMessage, Iterator[BaseMessage]]:
+) -> Union[BaseMessage, List[BaseMessage]]:
     """Convert message contents into a standard format.
 
     .. versionadded:: 0.2.36
 
@@ -975,61 +1005,63 @@ def format_messages_as(
         from langchain_core.messages import format_messages_as
         from langchain.chat_models import init_chat_model
 
-        formatter = format_messages_as(format="openai", text="block")
+        formatter = format_messages_as(format="openai", text="string")
         llm = init_chat_model() | formatter
 
         llm.invoke(
             [{"role": "user", "content": "how are you"}],
             config={"model": "gpt-4o"},
         )
-        # -> AIMessage([{"type": "text", "text": ""}], ...)
+        # -> AIMessage("I am good...", ...)
 
         llm.invoke(
             [{"role": "user", "content": "whats your name"}],
-            config={"model": "claude-3-5-sonnet-20240620"})
-        # -> AIMessage([{"type": "text", "text": ""}], ...)
+            config={"model": "claude-3-5-sonnet-20240620"}
+        )
+        # -> AIMessage("My name is...", ...)
 
-    .. note:: Doesn't support streaming
-        This util does not support formatting streamed chunks on the fly (i.e.
-        "transforming" chunks). This means if you pipe the outputs of a model to this
-        formatter in a chain, the chain will not have token-level streaming when
-        using ``chain.stream()/.astream()``. You'll still see the
-        token stream when using ``chat.astream_events()`` but the message chunks will
-        not yet be formatted.
+        def multiply(a: int, b: int) -> int:
+            '''Return product of a and b.'''
+            return a * b
 
-        .. code-block:: python
+        llm_with_tools = init_chat_model().bind_tools([multiply]) | formatter
 
-            from langchain_core.messages import format_messages_as
-            from langchain.chat_models import init_chat_model
-
-            formatter = format_messages_as(format="openai", text="block")
-            llm = init_chat_model() | formatter
-
-            # Will contain a single, completed chunk.
- list(llm.stream( - [{"role": "user", "content": "how are you"}], - config={"model": "gpt-4o"}, - )) - - # Will include token-level events, but the streamed chunks will not yet be - # formatted. - async for chunk in llm.astream_events( - [{"role": "user", "content": "how are you"}], - config={"model": "gpt-4o"}, - version="v2", + for chunk in llm_with_tools.stream( + "what's 5 times 2", config={"model": "claude-3-5-sonnet-20240620"} ): - ... - + print(chunk) + # -> AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', usage_metadata={'input_tokens': 370, 'output_tokens': 0, 'total_tokens': 370}), + # AIMessageChunk(content='Certainly', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content='! To', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' calculate', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' 5 times ', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content='2, we can use', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' the "', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content='multiply" function that', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content="'s", id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' available to', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' us.', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' Let', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content="'s use", id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' this tool', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' to', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' get', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content=' the result.', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'), + # AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': 'multiply', 'args': {}, 'id': 'toolu_01PW8o6BkATCecjsJX8QgG6z', 'type': 'tool_call'}], tool_call_chunks=[{'name': 'multiply', 'args': '', 'id': 'toolu_01PW8o6BkATCecjsJX8QgG6z', 'index': 1, 'type': 'tool_call_chunk'}]), + # AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': '', 'args': {}, 'id': None, 'type': 'tool_call'}], tool_call_chunks=[{'name': None, 'args': '', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]), + # AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': '', 'args': {'a': 5}, 'id': None, 'type': 'tool_call'}], tool_call_chunks=[{'name': None, 'args': '{"a": 5', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]), + # AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', invalid_tool_calls=[{'name': None, 'args': ', "b": 2}', 'id': None, 'error': None, 'type': 'invalid_tool_call'}], tool_call_chunks=[{'name': None, 'args': ', "b": 2}', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]), + # AIMessageChunk(content='', response_metadata={'stop_reason': 'tool_use', 'stop_sequence': None}, id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', usage_metadata={'input_tokens': 0, 'output_tokens': 104, 'total_tokens': 104}) """ # noqa: E501 if is_single := isinstance(messages, (BaseMessage, dict)): messages = [messages] messages = convert_to_messages(messages, copy=True) if format.lower() == 
"openai": - formatted = _format_contents_as_openai(messages, text=text) + formatted = _format_messages_as_openai(messages, text=text) elif format.lower() == "anthropic": - formatted = _format_contents_as_anthropic(messages, text=text) + formatted = _format_messages_as_anthropic(messages, text=text) else: raise ValueError( f"Unrecognized {format=}. Expected one of ('openai', 'anthropic')." @@ -1040,7 +1072,7 @@ def format_messages_as( return formatted -def _format_contents_as_openai( +def _format_messages_as_openai( messages: Sequence[BaseMessage], *, text: Literal["string", "block"] ) -> List[BaseMessage]: """Mutates messages so their contents match OpenAI messages API.""" @@ -1146,14 +1178,14 @@ def _format_contents_as_openai( f"content block:\n\n{block}" ) elif block.get("type") == "tool_use": - if not isinstance(message, BaseMessageChunk): + if not isinstance(message, AIMessageChunk): if missing := [ k for k in ("id", "name", "input") if k not in block ]: raise ValueError( f"Unrecognized content block at " - f"messages[{i}].content[{j}] has 'type': 'tool_use', " - f"but is missing expected key(s) " + f"messages[{i}].content[{j}] has 'type': " + f"'tool_use', but is missing expected key(s) " f"{missing}. Full content block:\n\n{block}" ) if not any( @@ -1169,7 +1201,14 @@ def _format_contents_as_openai( ) else: if not message.tool_call_chunks: - message.tool_call_chunks = [create_tool_call_chunk(id=block.get("id"), index=block.get("index"), args=block.get("partial_json"), name=block.get("name"))] + message.tool_call_chunks = [ + create_tool_call_chunk( + id=block.get("id"), + index=block.get("index"), + args=block.get("partial_json"), + name=block.get("name"), + ) + ] elif block.get("type") == "tool_result": if missing := [ k for k in ("content", "tool_use_id") if k not in block @@ -1187,7 +1226,7 @@ def _format_contents_as_openai( ) # Recurse to make sure tool message contents are OpenAI format. tool_messages.extend( - _format_contents_as_openai([tool_message], text=text) + _format_messages_as_openai([tool_message], text=text) ) elif (block.get("type") == "json") or "json" in block: if "json" not in block: @@ -1253,15 +1292,20 @@ def _format_contents_as_openai( f"Anthropic, Bedrock Converse, or VertexAI format. 
Full " f"content block:\n\n{block}" ) - message.content = content # type: ignore[assignment] + if text == "string" and not any( + block["type"] != "text" for block in content + ): + message.content = "\n".join(block["text"] for block in content) + else: + message.content = content # type: ignore[assignment] updated_messages.extend([message, *tool_messages]) return updated_messages -_OPTIONAL_ANTHROPIC_KEYS = ("cache_control", "is_error") +_OPTIONAL_ANTHROPIC_KEYS = ("cache_control", "is_error", "index") -def _format_contents_as_anthropic( +def _format_messages_as_anthropic( messages: Sequence[BaseMessage], *, text: Literal["string", "block"] ) -> List[BaseMessage]: """Mutates messages so their contents match Anthropic messages API.""" @@ -1289,7 +1333,10 @@ def _format_contents_as_anthropic( if text == "string": pass else: - message.content = [{"type": "text", "text": message.content}] + text_block: dict = {"type": "text", "text": message.content} + if isinstance(message, AIMessageChunk): + text_block["index"] = 0 + message.content = [text_block] else: if text == "string" and all( isinstance(block, str) @@ -1303,13 +1350,20 @@ def _format_contents_as_anthropic( else: content = [] for j, block in enumerate(message.content): - # OpenAI format - if isinstance(block, str): - content.append({"type": "text", "text": block}) - elif block.get("type") == "text": + if isinstance(block, dict): block_extra = { k: block[k] for k in _OPTIONAL_ANTHROPIC_KEYS if k in block } + else: + block_extra = {} + + # OpenAI format + if isinstance(block, str): + text_block = {"type": "text", "text": block} + if isinstance(message, AIMessageChunk): + text_block["index"] = 0 + content.append(text_block) + elif block.get("type") == "text": if missing := [k for k in ("text",) if k not in block]: raise ValueError( f"Unrecognized content block at " @@ -1380,35 +1434,72 @@ def _format_contents_as_anthropic( f"content block:\n\n{block}" ) elif block.get("type") == "tool_use": - if missing := [ - k for k in ("id", "name", "input") if k not in block - ]: - raise ValueError( - f"Unrecognized content block at " - f"messages[{i}].content[{j}] has 'type': 'tool_use', " - f"but is missing expected key(s) " - f"{missing}. Full content block:\n\n{block}" - ) - content.append( - { - "type": "tool_use", - "name": block["name"], - "id": block["id"], - "input": block["input"], - **block_extra, - } - ) - if not any( - tool_call["id"] == block["id"] - for tool_call in cast(AIMessage, message).tool_calls - ): - cast(AIMessage, message).tool_calls.append( - create_tool_call( - name=block["name"], - id=block["id"], - args=block["input"], + if not isinstance(message, AIMessageChunk): + if missing := [ + k for k in ("id", "name", "input") if k not in block + ]: + raise ValueError( + f"Unrecognized content block at " + f"messages[{i}].content[{j}] has 'type': " + f"'tool_use', " + f"but is missing expected key(s) " + f"{missing}. 
Full content block:\n\n{block}"
+                                )
+                            content.append(
+                                {
+                                    "type": "tool_use",
+                                    "name": block["name"],
+                                    "id": block["id"],
+                                    "input": block["input"],
+                                    **block_extra,
+                                }
+                            )
+                            if not any(
+                                tool_call["id"] == block["id"]
+                                for tool_call in cast(AIMessage, message).tool_calls
+                            ):
+                                cast(AIMessage, message).tool_calls.append(
+                                    create_tool_call(
+                                        name=block["name"],
+                                        id=block["id"],
+                                        args=block["input"],
+                                    )
+                                )
+                        else:
+                            if (
+                                not any(k in block for k in ("input", "partial_json"))
+                                or "index" not in block
+                            ):
+                                raise ValueError(
+                                    f"Unrecognized content block at "
+                                    f"message_chunks[{i}].content[{j}] has "
+                                    f"'type': 'tool_use', "
+                                    f"but does not have either an 'input' or "
+                                    f"'partial_json' key and an 'index' key. "
+                                    f"Full content block:\n\n{block}"
+                                )
+                            content.append(
+                                {
+                                    "type": "tool_use",
+                                    "index": block["index"],
+                                    **{
+                                        k: block[k]
+                                        for k in ("name", "input", "partial_json", "id")
+                                        if k in block
+                                    },
+                                }
+                            )
+                            if not message.tool_call_chunks:
+                                message.tool_call_chunks = [
+                                    create_tool_call_chunk(
+                                        name=block.get("name"),
+                                        id=block.get("id"),
+                                        index=block["index"],
+                                        args=block["partial_json"]
+                                        if "partial_json" in block
+                                        else block["input"],
+                                    )
+                                ]
                     elif block.get("type") == "tool_result":
                         if missing := [
                             k for k in ("content", "tool_use_id") if k not in block
@@ -1489,7 +1580,35 @@
                 )
         message.content = content  # type: ignore[assignment]
-        if isinstance(message, AIMessage) and message.tool_calls:
+        if isinstance(message, AIMessageChunk) and message.tool_call_chunks:
+            if isinstance(message.content, str):
+                if message.content:
+                    message.content = [
+                        {"type": "text", "text": message.content, "index": 0}
+                    ]
+                else:
+                    message.content = []
+            if not any(
+                cast(dict, block).get("type") == "tool_use" for block in message.content
+            ):
+                tool_use_blocks = [
+                    # Note: we intentionally omit index so that it can be set by the
+                    # stream handler, which can count how many blocks
+                    # have been seen in preceding chunks
+                    {
+                        "type": "tool_use",
+                        "partial_json": tc_chunk["args"],
+                        "id": tc_chunk["id"],
+                        "name": tc_chunk["name"],
+                    }
+                    for tc_chunk in message.tool_call_chunks
+                ]
+                tool_use_blocks = [
+                    {k: v for k, v in tu_block.items() if v is not None}
+                    for tu_block in tool_use_blocks
+                ]
+                message.content.extend(tool_use_blocks)
+        elif isinstance(message, AIMessage) and message.tool_calls:
             if isinstance(message.content, str):
                 message.content = [{"type": "text", "text": message.content}]
             for tool_call in message.tool_calls:
diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py
index 9a1d34bc3ed..b4047481b58 100644
--- a/libs/core/tests/unit_tests/messages/test_utils.py
+++ b/libs/core/tests/unit_tests/messages/test_utils.py
@@ -1,11 +1,12 @@
 import json
-from typing import Dict, List, Type
+from typing import Any, Callable, Dict, Iterator, List, Type
 
 import pytest
 
 from langchain_core.language_models.fake_chat_models import FakeChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
@@ -20,6 +21,7 @@ from langchain_core.messages.utils import (
     merge_message_runs,
     trim_messages,
 )
+from langchain_core.runnables import RunnableLambda
 
 
 @pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@@ -810,3 +812,580 @@ def test_format_messages_as_declarative() -> None:
     result = formatter.invoke(messages)
     assert result[0].content[1]["type"] == "image_url"
     assert 
result[0].content[1]["image_url"]["url"] == base64_image + + +def _stream_oai(input_: Any) -> Iterator: + chunks = [ + AIMessageChunk(content=""), + AIMessageChunk(content="Certainly"), + AIMessageChunk(content="!"), + AIMessageChunk( + content="", + tool_calls=[ + { + "name": "multiply", + "args": {}, + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "type": "tool_call", + } + ], + tool_call_chunks=[ + { + "name": "multiply", + "args": "", + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "index": 0, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content="", + tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}], + tool_call_chunks=[ + { + "name": None, + "args": '{"', + "id": None, + "index": 0, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content="", + invalid_tool_calls=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "index": 0, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content="", + invalid_tool_calls=[ + { + "name": None, + "args": "}", + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": "}", + "id": None, + "index": 0, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content="", + ), + ] + yield from chunks + + +def _stream_anthropic(input_: Any) -> Iterator: + chunks = [ + AIMessageChunk(content=[]), + AIMessageChunk(content=[{"text": "Certainly", "type": "text", "index": 0}]), + AIMessageChunk(content=[{"text": "!", "type": "text", "index": 0}]), + AIMessageChunk( + content=[ + { + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "input": {}, + "name": "multiply", + "type": "tool_use", + "index": 1, + } + ], + tool_calls=[ + { + "name": "multiply", + "args": {}, + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "type": "tool_call", + } + ], + tool_call_chunks=[ + { + "name": "multiply", + "args": "", + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[{"partial_json": '{"', "type": "tool_use", "index": 1}], + tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}], + tool_call_chunks=[ + { + "name": None, + "args": '{"', + "id": None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[{"partial_json": 'a": 5, "b": 2', "type": "tool_use", "index": 1}], + invalid_tool_calls=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[{"partial_json": "}", "type": "tool_use", "index": 1}], + invalid_tool_calls=[ + { + "name": None, + "args": "}", + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": "}", + "id": None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk(content=""), + ] + yield from chunks + + +@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic]) +def test_format_messages_openai_string_stream(stream: Callable) -> None: + formatter = format_messages_as(format="openai", text="string") + + chain = RunnableLambda(stream) | formatter + tool_call_idx = 1 if stream == _stream_anthropic else 0 + expected = [ + AIMessageChunk(content=""), + 
AIMessageChunk(content="Certainly"), + AIMessageChunk(content="!"), + AIMessageChunk( + content="", + tool_calls=[ + { + "name": "multiply", + "args": {}, + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "type": "tool_call", + } + ], + tool_call_chunks=[ + { + "name": "multiply", + "args": "", + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "index": tool_call_idx, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content="", + tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}], + tool_call_chunks=[ + { + "name": None, + "args": '{"', + "id": None, + "index": tool_call_idx, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content="", + invalid_tool_calls=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "index": tool_call_idx, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content="", + invalid_tool_calls=[ + { + "name": None, + "args": "}", + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": "}", + "id": None, + "index": tool_call_idx, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content="", + ), + ] + + actual = list(chain.stream({})) + assert expected == actual + + +@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic]) +def test_format_messages_openai_block_stream(stream: Callable) -> None: + formatter = format_messages_as(format="openai", text="block") + + chain = RunnableLambda(stream) | formatter + tool_call_idx = 1 if stream == _stream_anthropic else 0 + expected = [ + AIMessageChunk(content=[]), + AIMessageChunk(content=[{"type": "text", "text": "Certainly"}]), + AIMessageChunk(content=[{"type": "text", "text": "!"}]), + AIMessageChunk( + content=[], + tool_calls=[ + { + "name": "multiply", + "args": {}, + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "type": "tool_call", + } + ], + tool_call_chunks=[ + { + "name": "multiply", + "args": "", + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "index": tool_call_idx, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[], + tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}], + tool_call_chunks=[ + { + "name": None, + "args": '{"', + "id": None, + "index": tool_call_idx, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[], + invalid_tool_calls=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "index": tool_call_idx, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[], + invalid_tool_calls=[ + { + "name": None, + "args": "}", + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": "}", + "id": None, + "index": tool_call_idx, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[], + ), + ] + actual = list(chain.stream({})) + assert expected == actual + + +@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic]) +def test_format_messages_anthropic_block_stream(stream: Callable) -> None: + formatter = format_messages_as(format="anthropic", text="block") + + chain = RunnableLambda(stream) | formatter + expected = [ + AIMessageChunk(content=[]), + AIMessageChunk(content=[{"text": "Certainly", "type": "text", "index": 0}]), + 
AIMessageChunk(content=[{"text": "!", "type": "text", "index": 0}]), + AIMessageChunk( + content=[ + { + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "name": "multiply", + "type": "tool_use", + "index": 1, + **( + {"input": {}} + if stream == _stream_anthropic + else {"partial_json": ""} + ), + } + ], + tool_calls=[ + { + "name": "multiply", + "args": {}, + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "type": "tool_call", + } + ], + tool_call_chunks=[ + { + "name": "multiply", + "args": "", + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[{"partial_json": '{"', "type": "tool_use", "index": 1}], + tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}], + tool_call_chunks=[ + { + "name": None, + "args": '{"', + "id": None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[{"partial_json": 'a": 5, "b": 2', "type": "tool_use", "index": 1}], + invalid_tool_calls=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[{"partial_json": "}", "type": "tool_use", "index": 1}], + invalid_tool_calls=[ + { + "name": None, + "args": "}", + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": "}", + "id": None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk(content=[]), + ] + actual = list(chain.stream({})) + assert expected == actual + + +@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic]) +def test_format_messages_anthropic_string_stream(stream: Callable) -> None: + formatter = format_messages_as(format="anthropic", text="string") + + chain = RunnableLambda(stream) | formatter + expected = [ + AIMessageChunk(content=""), + AIMessageChunk(content="Certainly"), + AIMessageChunk(content="!"), + AIMessageChunk( + content=[ + { + "type": "tool_use", + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "name": "multiply", + "index": 1, + **( + {"input": {}} + if stream == _stream_anthropic + else {"partial_json": ""} + ), + }, + ], + tool_calls=[ + { + "name": "multiply", + "args": {}, + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "type": "tool_call", + } + ], + tool_call_chunks=[ + { + "name": "multiply", + "args": "", + "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8", + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[ + {"type": "tool_use", "partial_json": '{"', "index": 1}, + ], + tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}], + tool_call_chunks=[ + { + "name": None, + "args": '{"', + "id": None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[ + {"type": "tool_use", "partial_json": 'a": 5, "b": 2', "index": 1}, + ], + invalid_tool_calls=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": 'a": 5, "b": 2', + "id": None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk( + content=[ + {"type": "tool_use", "partial_json": "}", "index": 1}, + ], + invalid_tool_calls=[ + { + "name": None, + "args": "}", + "id": None, + "error": None, + "type": "invalid_tool_call", + } + ], + tool_call_chunks=[ + { + "name": None, + "args": "}", + "id": 
None, + "index": 1, + "type": "tool_call_chunk", + } + ], + ), + AIMessageChunk(content=""), + ] + actual = list(chain.stream({})) + assert expected == actual diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index ce50f7b7b09..99ca42b8d6d 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -31,7 +31,7 @@ from langchain_core.language_models.chat_models import ( ) from langchain_core.messages import AnyMessage, BaseMessage from langchain_core.pydantic_v1 import BaseModel -from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import BaseTool from langchain_core.tracers import RunLog, RunLogPatch @@ -512,7 +512,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): return model def _model_params(self, config: Optional[RunnableConfig]) -> dict: - config = config or {} + config = ensure_config(config) model_params = { _remove_prefix(k, self._config_prefix): v for k, v in config.get("configurable", {}).items()
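A minimal end-to-end sketch of the streaming behavior this patch enables, mirroring the unit tests above. `fake_anthropic_stream` is an illustrative stand-in for a chat model's chunk stream (not part of the diff), and the snippet assumes `langchain-core` with this change applied:

```python
from typing import Any, Iterator

from langchain_core.messages import AIMessageChunk, format_messages_as
from langchain_core.runnables import RunnableLambda


def fake_anthropic_stream(_: Any) -> Iterator[AIMessageChunk]:
    # Stand-in for a model: Anthropic-style chunks whose content blocks
    # carry an explicit "index" key.
    yield AIMessageChunk(content=[{"type": "text", "text": "Hi", "index": 0}])
    yield AIMessageChunk(content=[{"type": "text", "text": "!", "index": 0}])


# Called without messages, format_messages_as returns a Runnable, so it can
# be piped after a model and now formats each chunk as it streams by.
formatter = format_messages_as(format="openai", text="string")
chain = RunnableLambda(fake_anthropic_stream) | formatter

for chunk in chain.stream({}):
    print(repr(chunk.content))  # 'Hi' then '!': OpenAI-style string content
```

Because the declarative form is now wrapped in a `RunnableGenerator` (see `_runnable_generator` above), `chain.stream()`/`.astream()` keep token-level streaming instead of yielding a single completed chunk, which is why the old "Doesn't support streaming" docstring note is removed.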