diff --git a/libs/core/langchain_core/language_models/v1/chat_models.py b/libs/core/langchain_core/language_models/v1/chat_models.py index 775d58dd7a0..e34ab65bea4 100644 --- a/libs/core/langchain_core/language_models/v1/chat_models.py +++ b/libs/core/langchain_core/language_models/v1/chat_models.py @@ -25,7 +25,7 @@ from pydantic import ( Field, field_validator, ) -from typing_extensions import TypeAlias, override +from typing_extensions import override from langchain_core.caches import BaseCache from langchain_core.callbacks import ( @@ -79,8 +79,8 @@ if TYPE_CHECKING: def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]: - if hasattr(error, "response"): - response = error.response + response = getattr(error, "response", None) + if response is not None: metadata: dict = {} if hasattr(response, "headers"): try: @@ -90,7 +90,7 @@ def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]: if hasattr(response, "status_code"): metadata["status_code"] = response.status_code if hasattr(error, "request_id"): - metadata["request_id"] = error.request_id + metadata["request_id"] = error.request_id # type: ignore[arg-type] # Permit response_metadata without model_name, model_provider fields generations = [AIMessageV1(content=[], response_metadata=metadata)] # type: ignore[arg-type] else: @@ -118,7 +118,7 @@ def _format_for_tracing(messages: Sequence[MessageV1]) -> list[MessageV1]: for idx, block in enumerate(message.content): # Update image content blocks to OpenAI # Chat Completions format. if ( - block["type"] == "image" + block.get("type") == "image" and is_data_content_block(block) # type: ignore[arg-type] # permit unnecessary runtime check and block.get("source_type") != "id" ): @@ -338,7 +338,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC @property @override - def InputType(self) -> TypeAlias: + def InputType(self) -> Any: """Get the input type for this runnable.""" from langchain_core.prompt_values import ( ChatPromptValueConcrete, @@ -458,7 +458,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC chunks: list[AIMessageChunkV1] = [] try: for msg in self._stream(input_messages, **kwargs): - run_manager.on_llm_new_token(msg.text or "") + run_manager.on_llm_new_token(msg.text) chunks.append(msg) except BaseException as e: run_manager.on_llm_error(e, response=_generate_response_from_error(e)) @@ -525,7 +525,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC chunks: list[AIMessageChunkV1] = [] try: async for msg in self._astream(input_messages, **kwargs): - await run_manager.on_llm_new_token(msg.text or "") + await run_manager.on_llm_new_token(msg.text) chunks.append(msg) except BaseException as e: await run_manager.on_llm_error( @@ -602,9 +602,12 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC # TODO: replace this with something for new messages input_messages = _normalize_messages_v1(messages) for msg in self._stream(input_messages, **kwargs): - run_manager.on_llm_new_token(msg.text or "") + run_manager.on_llm_new_token(msg.text) chunks.append(msg) yield msg + + if msg.chunk_position != "last": + yield (AIMessageChunkV1([], chunk_position="last")) except BaseException as e: run_manager.on_llm_error(e, response=_generate_response_from_error(e)) raise @@ -673,9 +676,11 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC input_messages, **kwargs, ): - await run_manager.on_llm_new_token(msg.text or "") + await run_manager.on_llm_new_token(msg.text) chunks.append(msg) yield msg + if msg.chunk_position != "last": + yield (AIMessageChunkV1([], chunk_position="last")) except BaseException as e: await run_manager.on_llm_error(e, response=_generate_response_from_error(e)) raise @@ -716,22 +721,23 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC ls_params["ls_stop"] = stop # model - if hasattr(self, "model") and isinstance(self.model, str): - ls_params["ls_model_name"] = self.model - elif hasattr(self, "model_name") and isinstance(self.model_name, str): - ls_params["ls_model_name"] = self.model_name + model = ( + kwargs.get("model") + or getattr(self, "model", None) + or getattr(self, "model_name", None) + ) + if isinstance(model, str): + ls_params["ls_model_name"] = model # temperature - if "temperature" in kwargs and isinstance(kwargs["temperature"], float): - ls_params["ls_temperature"] = kwargs["temperature"] - elif hasattr(self, "temperature") and isinstance(self.temperature, float): - ls_params["ls_temperature"] = self.temperature + temperature = kwargs.get("temperature") or getattr(self, "temperature", None) + if isinstance(temperature, (int, float)): + ls_params["ls_temperature"] = temperature # max_tokens - if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int): - ls_params["ls_max_tokens"] = kwargs["max_tokens"] - elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int): - ls_params["ls_max_tokens"] = self.max_tokens + max_tokens = kwargs.get("max_tokens") or getattr(self, "max_tokens", None) + if isinstance(max_tokens, int): + ls_params["ls_max_tokens"] = max_tokens return ls_params @@ -806,7 +812,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 ], *, - tool_choice: Optional[Union[str]] = None, + tool_choice: Optional[str] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessageV1]: """Bind tools to the model. diff --git a/libs/core/langchain_core/messages/content_blocks.py b/libs/core/langchain_core/messages/content_blocks.py index d5775ad7433..16a42b36d1e 100644 --- a/libs/core/langchain_core/messages/content_blocks.py +++ b/libs/core/langchain_core/messages/content_blocks.py @@ -103,7 +103,7 @@ The module defines several types of content blocks, including: """ # noqa: E501 import warnings -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Optional, TypeGuard, Union from uuid import uuid4 from typing_extensions import NotRequired, TypedDict, get_args, get_origin @@ -844,8 +844,6 @@ ContentBlock = Union[ TextContentBlock, ToolCall, ToolCallChunk, - Citation, - NonStandardAnnotation, InvalidToolCall, ReasoningContentBlock, NonStandardContentBlock, @@ -884,7 +882,24 @@ def _extract_typedict_type_values(union_type: Any) -> set[str]: return result -KNOWN_BLOCK_TYPES = _extract_typedict_type_values(ContentBlock) +KNOWN_BLOCK_TYPES = { + "text", + "text-plain", + "tool_call", + "invalid_tool_call", + "tool_call_chunk", + "reasoning", + "non_standard", + "image", + "audio", + "file", + "video", + "code_interpreter_call", + "code_interpreter_output", + "code_interpreter_result", + "web_search_call", + "web_search_result", +} def is_data_content_block(block: dict) -> bool: @@ -914,6 +929,28 @@ def is_data_content_block(block: dict) -> bool: ) +def is_tool_call_block(block: ContentBlock) -> TypeGuard[ToolCall]: + """Type guard to check if a content block is a tool call.""" + return block.get("type") == "tool_call" + + +def is_tool_call_chunk(block: ContentBlock) -> TypeGuard[ToolCallChunk]: + """Type guard to check if a content block is a tool call chunk.""" + return block.get("type") == "tool_call_chunk" + + +def is_text_block(block: ContentBlock) -> TypeGuard[TextContentBlock]: + """Type guard to check if a content block is a text block.""" + return block.get("type") == "text" + + +def is_invalid_tool_call_block( + block: ContentBlock, +) -> TypeGuard[InvalidToolCall]: + """Type guard to check if a content block is an invalid tool call.""" + return block.get("type") == "invalid_tool_call" + + def convert_to_openai_image_block(block: dict[str, Any]) -> dict: """Convert image content block to format expected by OpenAI Chat Completions API.""" if "url" in block: diff --git a/libs/core/langchain_core/messages/v1.py b/libs/core/langchain_core/messages/v1.py index 66063813846..0827688b9a1 100644 --- a/libs/core/langchain_core/messages/v1.py +++ b/libs/core/langchain_core/messages/v1.py @@ -4,10 +4,9 @@ Each message has content that may be comprised of content blocks, defined under ``langchain_core.messages.content_blocks``. """ -import json import uuid from dataclasses import dataclass, field -from typing import Any, Literal, Optional, TypeGuard, Union, cast, get_args +from typing import Any, Literal, Optional, Union, cast, get_args from pydantic import BaseModel from typing_extensions import TypedDict @@ -20,32 +19,12 @@ from langchain_core.messages.ai import ( add_usage, ) from langchain_core.messages.base import merge_content -from langchain_core.messages.content_blocks import create_text_block -from langchain_core.messages.tool import ToolCallChunk from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import tool_call as create_tool_call -from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk -from langchain_core.utils._merge import merge_dicts, merge_lists +from langchain_core.utils._merge import merge_dicts from langchain_core.utils.json import parse_partial_json -def is_tool_call_block(block: types.ContentBlock) -> TypeGuard[types.ToolCall]: - """Type guard to check if a content block is a tool call.""" - return block.get("type") == "tool_call" - - -def is_text_block(block: types.ContentBlock) -> TypeGuard[types.TextContentBlock]: - """Type guard to check if a content block is a text block.""" - return block.get("type") == "text" - - -def is_invalid_tool_call_block( - block: types.ContentBlock, -) -> TypeGuard[types.InvalidToolCall]: - """Type guard to check if a content block is an invalid tool call.""" - return block.get("type") == "invalid_tool_call" - - def _ensure_id(id_val: Optional[str]) -> str: """Ensure the ID is a valid string, generating a new UUID if not provided. @@ -169,7 +148,7 @@ class AIMessage: parsed: Optional auto-parsed message contents, if applicable. """ if isinstance(content, str): - self.content = [create_text_block(content)] + self.content = [types.create_text_block(content)] else: self.content = content @@ -188,33 +167,46 @@ class AIMessage: content_tool_calls = { block["id"] for block in self.content - if block.get("type") == "tool_call" and "id" in block + if types.is_tool_call_block(block) and "id" in block } for tool_call in tool_calls: if "id" in tool_call and tool_call["id"] in content_tool_calls: continue self.content.append(tool_call) + if invalid_tool_calls: + content_tool_calls = { + block["id"] + for block in self.content + if types.is_invalid_tool_call_block(block) and "id" in block + } + for invalid_tool_call in invalid_tool_calls: + if ( + "id" in invalid_tool_call + and invalid_tool_call["id"] in content_tool_calls + ): + continue + self.content.append(invalid_tool_call) self._tool_calls: list[types.ToolCall] = [ - block for block in self.content if is_tool_call_block(block) + block for block in self.content if types.is_tool_call_block(block) + ] + self._invalid_tool_calls = [ + block for block in self.content if types.is_invalid_tool_call_block(block) ] - self.invalid_tool_calls = invalid_tool_calls or [] @property - def text(self) -> Optional[str]: + def text(self) -> str: """Extract all text content from the AI message as a string.""" - text_blocks = [block for block in self.content if is_text_block(block)] - if text_blocks: - return "".join(block["text"] for block in text_blocks) - return None + return "".join( + block["text"] for block in self.content if types.is_text_block(block) + ) @property - def tool_calls(self) -> list[types.ToolCall]: # update once we fix branch + def tool_calls(self) -> list[types.ToolCall]: """Get the tool calls made by the AI.""" - if self._tool_calls: - return self._tool_calls - tool_calls = [block for block in self.content if is_tool_call_block(block)] - if tool_calls: - self._tool_calls = tool_calls + if not self._tool_calls: + self._tool_calls = [ + block for block in self.content if types.is_tool_call_block(block) + ] return self._tool_calls @tool_calls.setter @@ -222,6 +214,17 @@ class AIMessage: """Set the tool calls for the AI message.""" self._tool_calls = value + @property + def invalid_tool_calls(self) -> list[types.InvalidToolCall]: + """Get the invalid tool calls made by the AI.""" + if not self._invalid_tool_calls: + self._invalid_tool_calls = [ + block + for block in self.content + if types.is_invalid_tool_call_block(block) + ] + return self._invalid_tool_calls + @dataclass class AIMessageChunk(AIMessage): @@ -246,17 +249,10 @@ class AIMessageChunk(AIMessage): when deserializing messages. """ - tool_call_chunks: list[types.ToolCallChunk] = field(init=False) - """List of partial tool call data. - - Emitted by the model during streaming, this field contains - tool call chunks that may not yet be complete. It is used to reconstruct - tool calls from the streamed content. - """ - def __init__( self, content: Union[str, list[types.ContentBlock]], + *, id: Optional[str] = None, name: Optional[str] = None, lc_version: str = "v1", @@ -264,6 +260,7 @@ class AIMessageChunk(AIMessage): usage_metadata: Optional[UsageMetadata] = None, tool_call_chunks: Optional[list[types.ToolCallChunk]] = None, parsed: Optional[Union[dict[str, Any], BaseModel]] = None, + chunk_position: Optional[Literal["last"]] = None, ): """Initialize an AI message. @@ -276,6 +273,8 @@ class AIMessageChunk(AIMessage): usage_metadata: Optional metadata about token usage. tool_call_chunks: Optional list of partial tool call data. parsed: Optional auto-parsed message contents, if applicable. + chunk_position: Optional position of the chunk in the stream. If "last", + tool calls will be parsed when aggregated into a stream. """ if isinstance(content, str): self.content = [{"type": "text", "text": content, "index": 0}] @@ -287,115 +286,51 @@ class AIMessageChunk(AIMessage): self.lc_version = lc_version self.usage_metadata = usage_metadata self.parsed = parsed + self.chunk_position = chunk_position if response_metadata is None: self.response_metadata = {} else: self.response_metadata = response_metadata - if tool_call_chunks is None: - self.tool_call_chunks: list[types.ToolCallChunk] = [] - else: - self.tool_call_chunks = tool_call_chunks + + if tool_call_chunks: + content_tool_call_chunks = { + block["id"] + for block in self.content + if types.is_tool_call_chunk(block) and "id" in block + } + for chunk in tool_call_chunks: + if "id" in chunk and chunk["id"] in content_tool_call_chunks: + continue + self.content.append(chunk) + self._tool_call_chunks = [ + block for block in self.content if types.is_tool_call_chunk(block) + ] self._tool_calls: list[types.ToolCall] = [] - self.invalid_tool_calls: list[types.InvalidToolCall] = [] - self._init_tool_calls() - - def _init_tool_calls(self) -> None: - """Initialize tool calls from tool call chunks. - - Args: - values: The values to validate. - - Raises: - ValueError: If the tool call chunks are malformed. - """ - self._tool_calls = [] - self.invalid_tool_calls = [] - if not self.tool_call_chunks: - if self._tool_calls: - self.tool_call_chunks = [ - create_tool_call_chunk( - name=tc["name"], - args=json.dumps(tc["args"]), - id=tc["id"], - index=None, - ) - for tc in self._tool_calls - ] - if self.invalid_tool_calls: - tool_call_chunks = self.tool_call_chunks - tool_call_chunks.extend( - [ - create_tool_call_chunk( - name=tc["name"], args=tc["args"], id=tc["id"], index=None - ) - for tc in self.invalid_tool_calls - ] - ) - self.tool_call_chunks = tool_call_chunks - - tool_calls = [] - invalid_tool_calls = [] - - def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None: - invalid_tool_calls.append( - create_invalid_tool_call( - name=chunk.get("name", ""), - args=chunk.get("args", ""), - id=chunk.get("id", ""), - error=None, - ) - ) - - for chunk in self.tool_call_chunks: - try: - args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type] - if isinstance(args_, dict): - tool_calls.append( - create_tool_call( - name=chunk.get("name") or "", - args=args_, - id=chunk.get("id", ""), - ) - ) - else: - add_chunk_to_invalid_tool_calls(chunk) - except Exception: - add_chunk_to_invalid_tool_calls(chunk) - self._tool_calls = tool_calls - self.invalid_tool_calls = invalid_tool_calls + self._invalid_tool_calls: list[types.InvalidToolCall] = [] @property - def text(self) -> Optional[str]: - """Extract all text content from the AI message as a string.""" - text_blocks = [block for block in self.content if is_text_block(block)] - if text_blocks: - return "".join(block["text"] for block in text_blocks) - return None - - @property - def reasoning(self) -> Optional[str]: - """Extract all reasoning text from the AI message as a string.""" - text_blocks = [ - block - for block in self.content - if block.get("type") == "reasoning" and "reasoning" in block - ] - if text_blocks: - return "".join( - cast("types.ReasoningContentBlock", block).get("reasoning", "") - for block in text_blocks - ) - return None + def tool_call_chunks(self) -> list[types.ToolCallChunk]: + """Get the tool calls made by the AI.""" + if not self._tool_call_chunks: + self._tool_call_chunks = [ + block for block in self.content if types.is_tool_call_chunk(block) + ] + return cast("list[types.ToolCallChunk]", self._tool_call_chunks) @property def tool_calls(self) -> list[types.ToolCall]: """Get the tool calls made by the AI.""" - if self._tool_calls: - return self._tool_calls - tool_calls = [block for block in self.content if is_tool_call_block(block)] - if tool_calls: - self._tool_calls = tool_calls + if not self._tool_calls: + parsed_content = _init_tool_calls(self.content) + self._tool_calls = [ + block for block in parsed_content if types.is_tool_call_block(block) + ] + self._invalid_tool_calls = [ + block + for block in parsed_content + if types.is_invalid_tool_call_block(block) + ] return self._tool_calls @tool_calls.setter @@ -403,6 +338,21 @@ class AIMessageChunk(AIMessage): """Set the tool calls for the AI message.""" self._tool_calls = value + @property + def invalid_tool_calls(self) -> list[types.InvalidToolCall]: + """Get the invalid tool calls made by the AI.""" + if not self._invalid_tool_calls: + parsed_content = _init_tool_calls(self.content) + self._tool_calls = [ + block for block in parsed_content if types.is_tool_call_block(block) + ] + self._invalid_tool_calls = [ + block + for block in parsed_content + if types.is_invalid_tool_call_block(block) + ] + return self._invalid_tool_calls + def __add__(self, other: Any) -> "AIMessageChunk": """Add ``AIMessageChunk`` to this one.""" if isinstance(other, AIMessageChunk): @@ -417,49 +367,76 @@ class AIMessageChunk(AIMessage): def to_message(self) -> "AIMessage": """Convert this ``AIMessageChunk`` to an AIMessage.""" return AIMessage( - content=self.content, + content=_init_tool_calls(self.content), id=self.id, name=self.name, lc_version=self.lc_version, response_metadata=self.response_metadata, usage_metadata=self.usage_metadata, - tool_calls=self.tool_calls, - invalid_tool_calls=self.invalid_tool_calls, parsed=self.parsed, ) +def _init_tool_calls(content: list[types.ContentBlock]) -> list[types.ContentBlock]: + """Parse tool call chunks in content into tool calls.""" + new_content = [] + for block in content: + if not types.is_tool_call_chunk(block): + new_content.append(block) + continue + try: + args_ = ( + parse_partial_json(cast("str", block.get("args") or "")) + if block.get("args") + else {} + ) + if isinstance(args_, dict): + new_content.append( + create_tool_call( + name=cast("str", block.get("name") or ""), + args=args_, + id=cast("str", block.get("id", "")), + ) + ) + else: + new_content.append( + create_invalid_tool_call( + name=cast("str", block.get("name", "")), + args=cast("str", block.get("args", "")), + id=cast("str", block.get("id", "")), + error=None, + ) + ) + except Exception: + new_content.append( + create_invalid_tool_call( + name=cast("str", block.get("name", "")), + args=cast("str", block.get("args", "")), + id=cast("str", block.get("id", "")), + error=None, + ) + ) + return new_content + + def add_ai_message_chunks( left: AIMessageChunk, *others: AIMessageChunk ) -> AIMessageChunk: """Add multiple ``AIMessageChunks`` together.""" if not others: return left - content = merge_content( - cast("list[str | dict[Any, Any]]", left.content), - *(cast("list[str | dict[Any, Any]]", o.content) for o in others), + content = cast( + "list[types.ContentBlock]", + merge_content( + cast("list[str | dict[Any, Any]]", left.content), + *(cast("list[str | dict[Any, Any]]", o.content) for o in others), + ), ) response_metadata = merge_dicts( cast("dict", left.response_metadata), *(cast("dict", o.response_metadata) for o in others), ) - # Merge tool call chunks - if raw_tool_calls := merge_lists( - left.tool_call_chunks, *(o.tool_call_chunks for o in others) - ): - tool_call_chunks = [ - create_tool_call_chunk( - name=rtc.get("name"), - args=rtc.get("args"), - index=rtc.get("index"), - id=rtc.get("id"), - ) - for rtc in raw_tool_calls - ] - else: - tool_call_chunks = [] - # Token usage if left.usage_metadata or any(o.usage_metadata is not None for o in others): usage_metadata: Optional[UsageMetadata] = left.usage_metadata @@ -501,13 +478,19 @@ def add_ai_message_chunks( chunk_id = id_ break + chunk_position: Optional[Literal["last"]] = ( + "last" if any(x.chunk_position == "last" for x in [left, *others]) else None + ) + if chunk_position == "last": + content = _init_tool_calls(content) + return left.__class__( - content=cast("list[types.ContentBlock]", content), - tool_call_chunks=tool_call_chunks, + content=content, response_metadata=cast("ResponseMetadata", response_metadata), usage_metadata=usage_metadata, parsed=parsed, id=chunk_id, + chunk_position=chunk_position, ) @@ -579,9 +562,7 @@ class HumanMessage: Concatenated string of all text blocks in the message. """ return "".join( - cast("types.TextContentBlock", block)["text"] - for block in self.content - if block.get("type") == "text" + block["text"] for block in self.content if types.is_text_block(block) ) @@ -660,9 +641,7 @@ class SystemMessage: def text(self) -> str: """Extract all text content from the system message.""" return "".join( - cast("types.TextContentBlock", block)["text"] - for block in self.content - if block.get("type") == "text" + block["text"] for block in self.content if types.is_text_block(block) ) @@ -754,9 +733,7 @@ class ToolMessage: def text(self) -> str: """Extract all text content from the tool message.""" return "".join( - cast("types.TextContentBlock", block)["text"] - for block in self.content - if block.get("type") == "text" + block["text"] for block in self.content if types.is_text_block(block) ) def __post_init__(self) -> None: diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 6ce45f94590..b91ac475842 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -283,7 +283,7 @@ class BaseOutputParser( Structured output. """ if isinstance(result, AIMessage): - return self.parse(result.text or "") + return self.parse(result.text) return self.parse(result[0].text) @abstractmethod diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index c9465d690a1..1384774afff 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -73,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): Raises: OutputParserException: If the output is not valid JSON. """ - text = result.text or "" if isinstance(result, AIMessage) else result[0].text + text = result.text if isinstance(result, AIMessage) else result[0].text text = text.strip() if partial: try: diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index a6ca2511c85..ce825ee2b7d 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -83,7 +83,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]): continue buffer += chunk_content elif isinstance(chunk, AIMessage): - buffer += chunk.text or "" + buffer += chunk.text else: # add current chunk to buffer buffer += chunk @@ -119,7 +119,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]): continue buffer += chunk_content elif isinstance(chunk, AIMessage): - buffer += chunk.text or "" + buffer += chunk.text else: # add current chunk to buffer buffer += chunk diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 37b05ed8255..e653d35fda8 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -14,7 +14,10 @@ from langchain_core.language_models import ( ParrotFakeChatModel, ) from langchain_core.language_models._utils import _normalize_messages -from langchain_core.language_models.fake_chat_models import FakeListChatModelError +from langchain_core.language_models.fake_chat_models import ( + FakeListChatModelError, + GenericFakeChatModelV1, +) from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -22,6 +25,7 @@ from langchain_core.messages import ( HumanMessage, SystemMessage, ) +from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs.llm_result import LLMResult from langchain_core.tracers import LogStreamCallbackHandler @@ -654,3 +658,93 @@ def test_normalize_messages_edge_cases() -> None: ) ] assert messages == _normalize_messages(messages) + + +def test_streaming_v1() -> None: + chunks = [ + AIMessageChunkV1( + [ + { + "type": "reasoning", + "reasoning": "Let's call a tool.", + "index": 0, + } + ] + ), + AIMessageChunkV1( + [], + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "args": "", + "name": "tool_name", + "id": "call_123", + "index": 1, + }, + ], + ), + AIMessageChunkV1( + [], + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "args": '{"a', + "name": "", + "id": "", + "index": 1, + }, + ], + ), + AIMessageChunkV1( + [], + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "args": '": 1}', + "name": "", + "id": "", + "index": 1, + }, + ], + ), + ] + full: Optional[AIMessageChunkV1] = None + for chunk in chunks: + full = chunk if full is None else full + chunk + + assert isinstance(full, AIMessageChunkV1) + assert full.content == [ + { + "type": "reasoning", + "reasoning": "Let's call a tool.", + "index": 0, + }, + { + "type": "tool_call_chunk", + "args": '{"a": 1}', + "name": "tool_name", + "id": "call_123", + "index": 1, + }, + ] + + llm = GenericFakeChatModelV1(message_chunks=chunks) + + full = None + for chunk in llm.stream("anything"): + full = chunk if full is None else full + chunk + + assert isinstance(full, AIMessageChunkV1) + assert full.content == [ + { + "type": "reasoning", + "reasoning": "Let's call a tool.", + "index": 0, + }, + { + "type": "tool_call", + "args": {"a": 1}, + "name": "tool_name", + "id": "call_123", + }, + ] diff --git a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py index fc687629f01..8794d3d2f65 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py +++ b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py @@ -37,7 +37,7 @@ def test_base_generation_parser() -> None: that support streaming """ if isinstance(result, AIMessageV1): - content = result.text or "" + content = result.text else: if len(result) != 1: msg = ( @@ -89,7 +89,7 @@ def test_base_transform_output_parser() -> None: that support streaming """ if isinstance(result, AIMessageV1): - content = result.text or "" + content = result.text else: if len(result) != 1: msg = ( @@ -116,4 +116,4 @@ def test_base_transform_output_parser() -> None: model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"]) chain_v1 = model_v1 | StrInvertCase() chunks = list(chain_v1.stream("")) - assert chunks == ["HELLO", " ", "WORLD"] + assert chunks == ["HELLO", " ", "WORLD", ""] diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index ee37908b390..c71663e3223 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -2543,12 +2543,6 @@ dict({ '$ref': '#/$defs/ToolCallChunk', }), - dict({ - '$ref': '#/$defs/Citation', - }), - dict({ - '$ref': '#/$defs/NonStandardAnnotation', - }), dict({ '$ref': '#/$defs/InvalidToolCall', }), @@ -2675,12 +2669,6 @@ dict({ '$ref': '#/$defs/ToolCallChunk', }), - dict({ - '$ref': '#/$defs/Citation', - }), - dict({ - '$ref': '#/$defs/NonStandardAnnotation', - }), dict({ '$ref': '#/$defs/InvalidToolCall', }), @@ -2807,12 +2795,6 @@ dict({ '$ref': '#/$defs/ToolCallChunk', }), - dict({ - '$ref': '#/$defs/Citation', - }), - dict({ - '$ref': '#/$defs/NonStandardAnnotation', - }), dict({ '$ref': '#/$defs/InvalidToolCall', }), @@ -2901,12 +2883,6 @@ dict({ '$ref': '#/$defs/ToolCallChunk', }), - dict({ - '$ref': '#/$defs/Citation', - }), - dict({ - '$ref': '#/$defs/NonStandardAnnotation', - }), dict({ '$ref': '#/$defs/InvalidToolCall', }), @@ -3018,12 +2994,6 @@ dict({ '$ref': '#/$defs/ToolCallChunk', }), - dict({ - '$ref': '#/$defs/Citation', - }), - dict({ - '$ref': '#/$defs/NonStandardAnnotation', - }), dict({ '$ref': '#/$defs/InvalidToolCall', }), diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index ea62b19f550..fdc3ca0c813 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -11535,6 +11535,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), @@ -11657,6 +11660,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), @@ -11779,6 +11785,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), @@ -11863,6 +11872,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), @@ -11970,6 +11982,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 13d86eefa25..5da1ad20f5d 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -3,6 +3,7 @@ import uuid from typing import Optional, Union import pytest +from typing_extensions import get_args from langchain_core.documents import Document from langchain_core.load import dumpd, load @@ -30,7 +31,7 @@ from langchain_core.messages import ( messages_from_dict, messages_to_dict, ) -from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES +from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES, ContentBlock from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk @@ -1363,23 +1364,19 @@ def test_convert_to_openai_image_block() -> None: def test_known_block_types() -> None: - assert { - "audio", - "citation", - "code_interpreter_call", - "code_interpreter_output", - "code_interpreter_result", - "file", - "image", - "invalid_tool_call", - "non_standard", - "non_standard_annotation", - "reasoning", - "text", - "text-plain", - "tool_call", - "tool_call_chunk", - "video", - "web_search_call", - "web_search_result", - } == KNOWN_BLOCK_TYPES + expected = { + bt + for bt in get_args(ContentBlock) + for bt in get_args(bt.__annotations__["type"]) + } + # Normalize any Literal[...] types in block types to their string values. + # This ensures all entries are plain strings, not Literal objects. + expected = { + t + if isinstance(t, str) + else t.__args__[0] + if hasattr(t, "__args__") and len(t.__args__) == 1 + else t + for t in expected + } + assert expected == KNOWN_BLOCK_TYPES