Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-14 07:07:34 +00:00

Commit cc56b8dbd3: Merge branch 'standard_outputs_copy' into mdrxy/ollama_v1 + updates
@@ -25,7 +25,7 @@ from pydantic import (
     Field,
     field_validator,
 )
-from typing_extensions import TypeAlias, override
+from typing_extensions import override

 from langchain_core.caches import BaseCache
 from langchain_core.callbacks import (
@@ -79,8 +79,8 @@ if TYPE_CHECKING:


 def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
-    if hasattr(error, "response"):
-        response = error.response
+    response = getattr(error, "response", None)
+    if response is not None:
         metadata: dict = {}
         if hasattr(response, "headers"):
             try:
@@ -90,7 +90,7 @@ def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
         if hasattr(response, "status_code"):
             metadata["status_code"] = response.status_code
         if hasattr(error, "request_id"):
-            metadata["request_id"] = error.request_id
+            metadata["request_id"] = error.request_id  # type: ignore[arg-type]
         # Permit response_metadata without model_name, model_provider fields
         generations = [AIMessageV1(content=[], response_metadata=metadata)]  # type: ignore[arg-type]
     else:
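The switch from a `hasattr` check plus attribute access to a single `getattr(..., None)` collapses two lookups into one and also handles an `error.response` that exists but is `None`. A minimal sketch with hypothetical provider objects (both classes below are illustrative stand-ins, not real SDK types):

class _FakeResponse:
    status_code = 429
    headers = {"retry-after": "1"}

class _FakeProviderError(Exception):
    def __init__(self, response):
        super().__init__("rate limited")
        self.response = response

error = _FakeProviderError(_FakeResponse())
metadata: dict = {}
response = getattr(error, "response", None)  # one guarded lookup, no AttributeError risk
if response is not None:
    if hasattr(response, "status_code"):
        metadata["status_code"] = response.status_code
print(metadata)  # {'status_code': 429}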
@@ -118,7 +118,7 @@ def _format_for_tracing(messages: Sequence[MessageV1]) -> list[MessageV1]:
         for idx, block in enumerate(message.content):
             # Update image content blocks to OpenAI Chat Completions format.
             if (
-                block["type"] == "image"
+                block.get("type") == "image"
                 and is_data_content_block(block)  # type: ignore[arg-type]  # permit unnecessary runtime check
                 and block.get("source_type") != "id"
             ):
@@ -338,7 +338,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC

     @property
     @override
-    def InputType(self) -> TypeAlias:
+    def InputType(self) -> Any:
         """Get the input type for this runnable."""
         from langchain_core.prompt_values import (
             ChatPromptValueConcrete,
@@ -458,7 +458,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
         chunks: list[AIMessageChunkV1] = []
         try:
             for msg in self._stream(input_messages, **kwargs):
-                run_manager.on_llm_new_token(msg.text or "")
+                run_manager.on_llm_new_token(msg.text)
                 chunks.append(msg)
         except BaseException as e:
             run_manager.on_llm_error(e, response=_generate_response_from_error(e))
@@ -525,7 +525,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
         chunks: list[AIMessageChunkV1] = []
         try:
             async for msg in self._astream(input_messages, **kwargs):
-                await run_manager.on_llm_new_token(msg.text or "")
+                await run_manager.on_llm_new_token(msg.text)
                 chunks.append(msg)
         except BaseException as e:
             await run_manager.on_llm_error(
@@ -602,9 +602,12 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
             # TODO: replace this with something for new messages
             input_messages = _normalize_messages_v1(messages)
             for msg in self._stream(input_messages, **kwargs):
-                run_manager.on_llm_new_token(msg.text or "")
+                run_manager.on_llm_new_token(msg.text)
                 chunks.append(msg)
                 yield msg
+
+            if msg.chunk_position != "last":
+                yield (AIMessageChunkV1([], chunk_position="last"))
         except BaseException as e:
             run_manager.on_llm_error(e, response=_generate_response_from_error(e))
             raise
@@ -673,9 +676,11 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
                 input_messages,
                 **kwargs,
             ):
-                await run_manager.on_llm_new_token(msg.text or "")
+                await run_manager.on_llm_new_token(msg.text)
                 chunks.append(msg)
                 yield msg
+            if msg.chunk_position != "last":
+                yield (AIMessageChunkV1([], chunk_position="last"))
         except BaseException as e:
             await run_manager.on_llm_error(e, response=_generate_response_from_error(e))
             raise
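Both streaming paths now close with the same guarantee: if the provider's final chunk was not marked `chunk_position="last"`, an empty terminal chunk is emitted so that aggregation can finalize. A consumer-side sketch of the same pattern, assuming only the chunk behavior shown in this diff (the wrapper name is illustrative):

from collections.abc import Iterator

from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1

def _with_terminal_chunk(
    stream: Iterator[AIMessageChunkV1],
) -> Iterator[AIMessageChunkV1]:
    msg = None
    for msg in stream:
        yield msg
    # Mirror the base-class behavior: append a sentinel if none was seen.
    if msg is not None and msg.chunk_position != "last":
        yield AIMessageChunkV1([], chunk_position="last")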
@@ -716,22 +721,23 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
             ls_params["ls_stop"] = stop

         # model
-        if hasattr(self, "model") and isinstance(self.model, str):
-            ls_params["ls_model_name"] = self.model
-        elif hasattr(self, "model_name") and isinstance(self.model_name, str):
-            ls_params["ls_model_name"] = self.model_name
+        model = (
+            kwargs.get("model")
+            or getattr(self, "model", None)
+            or getattr(self, "model_name", None)
+        )
+        if isinstance(model, str):
+            ls_params["ls_model_name"] = model

         # temperature
-        if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
-            ls_params["ls_temperature"] = kwargs["temperature"]
-        elif hasattr(self, "temperature") and isinstance(self.temperature, float):
-            ls_params["ls_temperature"] = self.temperature
+        temperature = kwargs.get("temperature") or getattr(self, "temperature", None)
+        if isinstance(temperature, (int, float)):
+            ls_params["ls_temperature"] = temperature

         # max_tokens
-        if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
-            ls_params["ls_max_tokens"] = kwargs["max_tokens"]
-        elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
-            ls_params["ls_max_tokens"] = self.max_tokens
+        max_tokens = kwargs.get("max_tokens") or getattr(self, "max_tokens", None)
+        if isinstance(max_tokens, int):
+            ls_params["ls_max_tokens"] = max_tokens

         return ls_params

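Each rewritten lookup follows one precedence rule: call-time kwargs win, then the instance attribute(s), with a single isinstance check on the merged result. A standalone sketch of that chain (class and helper names are illustrative); note that because the chain uses `or`, a falsy explicit value such as `temperature=0.0` falls through to the instance attribute:

from typing import Any, Optional

class _FakeModel:
    model = "my-model"
    temperature = 0.2

def _infer_model_name(self: Any, **kwargs: Any) -> Optional[str]:
    # kwargs first, then instance attributes, one type check at the end.
    model = (
        kwargs.get("model")
        or getattr(self, "model", None)
        or getattr(self, "model_name", None)
    )
    return model if isinstance(model, str) else None

assert _infer_model_name(_FakeModel()) == "my-model"
assert _infer_model_name(_FakeModel(), model="override") == "override"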
@@ -806,7 +812,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
             Union[typing.Dict[str, Any], type, Callable, BaseTool]  # noqa: UP006
         ],
         *,
-        tool_choice: Optional[Union[str]] = None,
+        tool_choice: Optional[str] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, AIMessageV1]:
         """Bind tools to the model.
@@ -103,7 +103,7 @@ The module defines several types of content blocks, including:
 """  # noqa: E501

 import warnings
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, TypeGuard, Union
 from uuid import uuid4

 from typing_extensions import NotRequired, TypedDict, get_args, get_origin
@@ -844,8 +844,6 @@ ContentBlock = Union[
     TextContentBlock,
     ToolCall,
     ToolCallChunk,
-    Citation,
-    NonStandardAnnotation,
     InvalidToolCall,
     ReasoningContentBlock,
     NonStandardContentBlock,
@@ -884,7 +882,24 @@ def _extract_typedict_type_values(union_type: Any) -> set[str]:
     return result


-KNOWN_BLOCK_TYPES = _extract_typedict_type_values(ContentBlock)
+KNOWN_BLOCK_TYPES = {
+    "text",
+    "text-plain",
+    "tool_call",
+    "invalid_tool_call",
+    "tool_call_chunk",
+    "reasoning",
+    "non_standard",
+    "image",
+    "audio",
+    "file",
+    "video",
+    "code_interpreter_call",
+    "code_interpreter_output",
+    "code_interpreter_result",
+    "web_search_call",
+    "web_search_result",
+}


 def is_data_content_block(block: dict) -> bool:
@@ -914,6 +929,28 @@ def is_data_content_block(block: dict) -> bool:
     )


+def is_tool_call_block(block: ContentBlock) -> TypeGuard[ToolCall]:
+    """Type guard to check if a content block is a tool call."""
+    return block.get("type") == "tool_call"
+
+
+def is_tool_call_chunk(block: ContentBlock) -> TypeGuard[ToolCallChunk]:
+    """Type guard to check if a content block is a tool call chunk."""
+    return block.get("type") == "tool_call_chunk"
+
+
+def is_text_block(block: ContentBlock) -> TypeGuard[TextContentBlock]:
+    """Type guard to check if a content block is a text block."""
+    return block.get("type") == "text"
+
+
+def is_invalid_tool_call_block(
+    block: ContentBlock,
+) -> TypeGuard[InvalidToolCall]:
+    """Type guard to check if a content block is an invalid tool call."""
+    return block.get("type") == "invalid_tool_call"
+
+
 def convert_to_openai_image_block(block: dict[str, Any]) -> dict:
     """Convert image content block to format expected by OpenAI Chat Completions API."""
     if "url" in block:
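The payoff of `TypeGuard` is static narrowing at call sites: once a guard returns `True`, type checkers treat the block as the specific TypedDict, so its required keys type-check without casts. A small usage sketch; the `types` alias import mirrors how the v1 message module appears to reference this module, and is an assumption:

from langchain_core.messages import content_blocks as types

block: types.ContentBlock = {
    "type": "tool_call",
    "name": "add",
    "args": {"a": 1, "b": 2},
    "id": "call_1",
}
if types.is_tool_call_block(block):
    # Narrowed to ToolCall here: "name" and "args" are known keys.
    print(block["name"], block["args"])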
@@ -4,10 +4,9 @@ Each message has content that may be comprised of content blocks, defined under
 ``langchain_core.messages.content_blocks``.
 """

-import json
 import uuid
 from dataclasses import dataclass, field
-from typing import Any, Literal, Optional, TypeGuard, Union, cast, get_args
+from typing import Any, Literal, Optional, Union, cast, get_args

 from pydantic import BaseModel
 from typing_extensions import TypedDict
@@ -20,32 +19,12 @@ from langchain_core.messages.ai import (
     add_usage,
 )
 from langchain_core.messages.base import merge_content
-from langchain_core.messages.content_blocks import create_text_block
 from langchain_core.messages.tool import ToolCallChunk
 from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
 from langchain_core.messages.tool import tool_call as create_tool_call
 from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
-from langchain_core.utils._merge import merge_dicts, merge_lists
+from langchain_core.utils._merge import merge_dicts
 from langchain_core.utils.json import parse_partial_json


-def is_tool_call_block(block: types.ContentBlock) -> TypeGuard[types.ToolCall]:
-    """Type guard to check if a content block is a tool call."""
-    return block.get("type") == "tool_call"
-
-
-def is_text_block(block: types.ContentBlock) -> TypeGuard[types.TextContentBlock]:
-    """Type guard to check if a content block is a text block."""
-    return block.get("type") == "text"
-
-
-def is_invalid_tool_call_block(
-    block: types.ContentBlock,
-) -> TypeGuard[types.InvalidToolCall]:
-    """Type guard to check if a content block is an invalid tool call."""
-    return block.get("type") == "invalid_tool_call"
-
-
 def _ensure_id(id_val: Optional[str]) -> str:
     """Ensure the ID is a valid string, generating a new UUID if not provided.
@@ -169,7 +148,7 @@ class AIMessage:
             parsed: Optional auto-parsed message contents, if applicable.
         """
         if isinstance(content, str):
-            self.content = [create_text_block(content)]
+            self.content = [types.create_text_block(content)]
         else:
             self.content = content
@@ -188,33 +167,46 @@ class AIMessage:
             content_tool_calls = {
                 block["id"]
                 for block in self.content
-                if block.get("type") == "tool_call" and "id" in block
+                if types.is_tool_call_block(block) and "id" in block
             }
             for tool_call in tool_calls:
                 if "id" in tool_call and tool_call["id"] in content_tool_calls:
                     continue
                 self.content.append(tool_call)
+        if invalid_tool_calls:
+            content_tool_calls = {
+                block["id"]
+                for block in self.content
+                if types.is_invalid_tool_call_block(block) and "id" in block
+            }
+            for invalid_tool_call in invalid_tool_calls:
+                if (
+                    "id" in invalid_tool_call
+                    and invalid_tool_call["id"] in content_tool_calls
+                ):
+                    continue
+                self.content.append(invalid_tool_call)
         self._tool_calls: list[types.ToolCall] = [
-            block for block in self.content if is_tool_call_block(block)
+            block for block in self.content if types.is_tool_call_block(block)
         ]
-        self.invalid_tool_calls = invalid_tool_calls or []
+        self._invalid_tool_calls = [
+            block for block in self.content if types.is_invalid_tool_call_block(block)
+        ]

     @property
-    def text(self) -> Optional[str]:
+    def text(self) -> str:
         """Extract all text content from the AI message as a string."""
-        text_blocks = [block for block in self.content if is_text_block(block)]
-        if text_blocks:
-            return "".join(block["text"] for block in text_blocks)
-        return None
+        return "".join(
+            block["text"] for block in self.content if types.is_text_block(block)
+        )

     @property
-    def tool_calls(self) -> list[types.ToolCall]:  # update once we fix branch
+    def tool_calls(self) -> list[types.ToolCall]:
         """Get the tool calls made by the AI."""
-        if self._tool_calls:
-            return self._tool_calls
-        tool_calls = [block for block in self.content if is_tool_call_block(block)]
-        if tool_calls:
-            self._tool_calls = tool_calls
+        if not self._tool_calls:
+            self._tool_calls = [
+                block for block in self.content if types.is_tool_call_block(block)
+            ]
         return self._tool_calls

     @tool_calls.setter
@@ -222,6 +214,17 @@ class AIMessage:
         """Set the tool calls for the AI message."""
         self._tool_calls = value

+    @property
+    def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
+        """Get the invalid tool calls made by the AI."""
+        if not self._invalid_tool_calls:
+            self._invalid_tool_calls = [
+                block
+                for block in self.content
+                if types.is_invalid_tool_call_block(block)
+            ]
+        return self._invalid_tool_calls


 @dataclass
 class AIMessageChunk(AIMessage):
@@ -246,17 +249,10 @@ class AIMessageChunk(AIMessage):
         when deserializing messages.
     """

-    tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
-    """List of partial tool call data.
-
-    Emitted by the model during streaming, this field contains
-    tool call chunks that may not yet be complete. It is used to reconstruct
-    tool calls from the streamed content.
-    """
-
     def __init__(
         self,
         content: Union[str, list[types.ContentBlock]],
         *,
         id: Optional[str] = None,
         name: Optional[str] = None,
         lc_version: str = "v1",
@@ -264,6 +260,7 @@ class AIMessageChunk(AIMessage):
         usage_metadata: Optional[UsageMetadata] = None,
         tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
         parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
+        chunk_position: Optional[Literal["last"]] = None,
     ):
         """Initialize an AI message.
@@ -276,6 +273,8 @@ class AIMessageChunk(AIMessage):
             usage_metadata: Optional metadata about token usage.
             tool_call_chunks: Optional list of partial tool call data.
             parsed: Optional auto-parsed message contents, if applicable.
+            chunk_position: Optional position of the chunk in the stream. If "last",
+                tool calls will be parsed when aggregated into a stream.
         """
         if isinstance(content, str):
             self.content = [{"type": "text", "text": content, "index": 0}]
@@ -287,115 +286,51 @@ class AIMessageChunk(AIMessage):
         self.lc_version = lc_version
         self.usage_metadata = usage_metadata
         self.parsed = parsed
+        self.chunk_position = chunk_position
         if response_metadata is None:
             self.response_metadata = {}
         else:
             self.response_metadata = response_metadata
-        if tool_call_chunks is None:
-            self.tool_call_chunks: list[types.ToolCallChunk] = []
-        else:
-            self.tool_call_chunks = tool_call_chunks
-
-        self._tool_calls: list[types.ToolCall] = []
-        self.invalid_tool_calls: list[types.InvalidToolCall] = []
-        self._init_tool_calls()
-
-    def _init_tool_calls(self) -> None:
-        """Initialize tool calls from tool call chunks.
-
-        Args:
-            values: The values to validate.
-
-        Raises:
-            ValueError: If the tool call chunks are malformed.
-        """
-        self._tool_calls = []
-        self.invalid_tool_calls = []
-        if not self.tool_call_chunks:
-            if self._tool_calls:
-                self.tool_call_chunks = [
-                    create_tool_call_chunk(
-                        name=tc["name"],
-                        args=json.dumps(tc["args"]),
-                        id=tc["id"],
-                        index=None,
-                    )
-                    for tc in self._tool_calls
-                ]
-            if self.invalid_tool_calls:
-                tool_call_chunks = self.tool_call_chunks
-                tool_call_chunks.extend(
-                    [
-                        create_tool_call_chunk(
-                            name=tc["name"], args=tc["args"], id=tc["id"], index=None
-                        )
-                        for tc in self.invalid_tool_calls
-                    ]
-                )
-                self.tool_call_chunks = tool_call_chunks
-
-        tool_calls = []
-        invalid_tool_calls = []
-
-        def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
-            invalid_tool_calls.append(
-                create_invalid_tool_call(
-                    name=chunk.get("name", ""),
-                    args=chunk.get("args", ""),
-                    id=chunk.get("id", ""),
-                    error=None,
-                )
-            )
-
-        for chunk in self.tool_call_chunks:
-            try:
-                args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}  # type: ignore[arg-type]
-                if isinstance(args_, dict):
-                    tool_calls.append(
-                        create_tool_call(
-                            name=chunk.get("name") or "",
-                            args=args_,
-                            id=chunk.get("id", ""),
-                        )
-                    )
-                else:
-                    add_chunk_to_invalid_tool_calls(chunk)
-            except Exception:
-                add_chunk_to_invalid_tool_calls(chunk)
-        self._tool_calls = tool_calls
-        self.invalid_tool_calls = invalid_tool_calls
-
-    @property
-    def text(self) -> Optional[str]:
-        """Extract all text content from the AI message as a string."""
-        text_blocks = [block for block in self.content if is_text_block(block)]
-        if text_blocks:
-            return "".join(block["text"] for block in text_blocks)
-        return None
-
-    @property
-    def reasoning(self) -> Optional[str]:
-        """Extract all reasoning text from the AI message as a string."""
-        text_blocks = [
-            block
-            for block in self.content
-            if block.get("type") == "reasoning" and "reasoning" in block
-        ]
-        if text_blocks:
-            return "".join(
-                cast("types.ReasoningContentBlock", block).get("reasoning", "")
-                for block in text_blocks
-            )
-        return None
+        if tool_call_chunks:
+            content_tool_call_chunks = {
+                block["id"]
+                for block in self.content
+                if types.is_tool_call_chunk(block) and "id" in block
+            }
+            for chunk in tool_call_chunks:
+                if "id" in chunk and chunk["id"] in content_tool_call_chunks:
+                    continue
+                self.content.append(chunk)
+        self._tool_call_chunks = [
+            block for block in self.content if types.is_tool_call_chunk(block)
+        ]
+
+        self._invalid_tool_calls: list[types.InvalidToolCall] = []
+
+    @property
+    def tool_call_chunks(self) -> list[types.ToolCallChunk]:
+        """Get the tool call chunks made by the AI."""
+        if not self._tool_call_chunks:
+            self._tool_call_chunks = [
+                block for block in self.content if types.is_tool_call_chunk(block)
+            ]
+        return cast("list[types.ToolCallChunk]", self._tool_call_chunks)

     @property
     def tool_calls(self) -> list[types.ToolCall]:
         """Get the tool calls made by the AI."""
-        if self._tool_calls:
-            return self._tool_calls
-        tool_calls = [block for block in self.content if is_tool_call_block(block)]
-        if tool_calls:
-            self._tool_calls = tool_calls
+        if not self._tool_calls:
+            parsed_content = _init_tool_calls(self.content)
+            self._tool_calls = [
+                block for block in parsed_content if types.is_tool_call_block(block)
+            ]
+            self._invalid_tool_calls = [
+                block
+                for block in parsed_content
+                if types.is_invalid_tool_call_block(block)
+            ]
         return self._tool_calls

     @tool_calls.setter
@@ -403,6 +338,21 @@ class AIMessageChunk(AIMessage):
         """Set the tool calls for the AI message."""
         self._tool_calls = value

+    @property
+    def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
+        """Get the invalid tool calls made by the AI."""
+        if not self._invalid_tool_calls:
+            parsed_content = _init_tool_calls(self.content)
+            self._tool_calls = [
+                block for block in parsed_content if types.is_tool_call_block(block)
+            ]
+            self._invalid_tool_calls = [
+                block
+                for block in parsed_content
+                if types.is_invalid_tool_call_block(block)
+            ]
+        return self._invalid_tool_calls

     def __add__(self, other: Any) -> "AIMessageChunk":
         """Add ``AIMessageChunk`` to this one."""
         if isinstance(other, AIMessageChunk):
@@ -417,49 +367,76 @@ class AIMessageChunk(AIMessage):
     def to_message(self) -> "AIMessage":
         """Convert this ``AIMessageChunk`` to an AIMessage."""
         return AIMessage(
-            content=self.content,
+            content=_init_tool_calls(self.content),
             id=self.id,
             name=self.name,
             lc_version=self.lc_version,
             response_metadata=self.response_metadata,
             usage_metadata=self.usage_metadata,
             tool_calls=self.tool_calls,
             invalid_tool_calls=self.invalid_tool_calls,
             parsed=self.parsed,
         )


+def _init_tool_calls(content: list[types.ContentBlock]) -> list[types.ContentBlock]:
+    """Parse tool call chunks in content into tool calls."""
+    new_content = []
+    for block in content:
+        if not types.is_tool_call_chunk(block):
+            new_content.append(block)
+            continue
+        try:
+            args_ = (
+                parse_partial_json(cast("str", block.get("args") or ""))
+                if block.get("args")
+                else {}
+            )
+            if isinstance(args_, dict):
+                new_content.append(
+                    create_tool_call(
+                        name=cast("str", block.get("name") or ""),
+                        args=args_,
+                        id=cast("str", block.get("id", "")),
+                    )
+                )
+            else:
+                new_content.append(
+                    create_invalid_tool_call(
+                        name=cast("str", block.get("name", "")),
+                        args=cast("str", block.get("args", "")),
+                        id=cast("str", block.get("id", "")),
+                        error=None,
+                    )
+                )
+        except Exception:
+            new_content.append(
+                create_invalid_tool_call(
+                    name=cast("str", block.get("name", "")),
+                    args=cast("str", block.get("args", "")),
+                    id=cast("str", block.get("id", "")),
+                    error=None,
+                )
+            )
+    return new_content
+
+
 def add_ai_message_chunks(
     left: AIMessageChunk, *others: AIMessageChunk
 ) -> AIMessageChunk:
     """Add multiple ``AIMessageChunks`` together."""
     if not others:
         return left
-    content = merge_content(
-        cast("list[str | dict[Any, Any]]", left.content),
-        *(cast("list[str | dict[Any, Any]]", o.content) for o in others),
-    )
+    content = cast(
+        "list[types.ContentBlock]",
+        merge_content(
+            cast("list[str | dict[Any, Any]]", left.content),
+            *(cast("list[str | dict[Any, Any]]", o.content) for o in others),
+        ),
+    )
     response_metadata = merge_dicts(
         cast("dict", left.response_metadata),
         *(cast("dict", o.response_metadata) for o in others),
     )

-    # Merge tool call chunks
-    if raw_tool_calls := merge_lists(
-        left.tool_call_chunks, *(o.tool_call_chunks for o in others)
-    ):
-        tool_call_chunks = [
-            create_tool_call_chunk(
-                name=rtc.get("name"),
-                args=rtc.get("args"),
-                index=rtc.get("index"),
-                id=rtc.get("id"),
-            )
-            for rtc in raw_tool_calls
-        ]
-    else:
-        tool_call_chunks = []
-
     # Token usage
     if left.usage_metadata or any(o.usage_metadata is not None for o in others):
         usage_metadata: Optional[UsageMetadata] = left.usage_metadata
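The new module-level `_init_tool_calls` defers chunk parsing to a pure function over content. Its tolerance for truncated arguments comes from `parse_partial_json`, which completes a syntactically incomplete JSON prefix; anything that still fails to parse to a dict becomes an invalid tool call instead of raising. A quick sketch of that behavior:

from langchain_core.utils.json import parse_partial_json

# A complete object parses as usual.
assert parse_partial_json('{"a": 1}') == {"a": 1}
# A truncated prefix is repaired to the data seen so far, which is what
# lets mid-stream tool_call_chunk args aggregate into a valid tool call.
assert parse_partial_json('{"a": 1') == {"a": 1}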
@@ -501,13 +478,19 @@ def add_ai_message_chunks(
             chunk_id = id_
             break

+    chunk_position: Optional[Literal["last"]] = (
+        "last" if any(x.chunk_position == "last" for x in [left, *others]) else None
+    )
+    if chunk_position == "last":
+        content = _init_tool_calls(content)
+
     return left.__class__(
-        content=cast("list[types.ContentBlock]", content),
-        tool_call_chunks=tool_call_chunks,
+        content=content,
         response_metadata=cast("ResponseMetadata", response_metadata),
         usage_metadata=usage_metadata,
         parsed=parsed,
         id=chunk_id,
+        chunk_position=chunk_position,
     )

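With `chunk_position` merged during addition, aggregation finalizes exactly once: accumulated tool-call chunks are parsed into tool calls only when some summand was marked "last". A usage sketch consistent with the streaming test further down (the two-chunk stream here is illustrative):

from typing import Optional

from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1

chunks = [
    AIMessageChunkV1(
        [],
        tool_call_chunks=[
            {
                "type": "tool_call_chunk",
                "name": "add",
                "args": '{"a": 1}',
                "id": "c1",
                "index": 0,
            },
        ],
    ),
    AIMessageChunkV1([], chunk_position="last"),  # terminal sentinel
]
full: Optional[AIMessageChunkV1] = None
for chunk in chunks:
    full = chunk if full is None else full + chunk
# The terminal summand triggers _init_tool_calls, so the aggregated
# content carries a parsed tool_call block instead of a raw chunk.
assert full is not None and full.chunk_position == "last"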
@@ -579,9 +562,7 @@ class HumanMessage:
             Concatenated string of all text blocks in the message.
         """
         return "".join(
-            cast("types.TextContentBlock", block)["text"]
-            for block in self.content
-            if block.get("type") == "text"
+            block["text"] for block in self.content if types.is_text_block(block)
         )

@@ -660,9 +641,7 @@ class SystemMessage:
     def text(self) -> str:
         """Extract all text content from the system message."""
         return "".join(
-            cast("types.TextContentBlock", block)["text"]
-            for block in self.content
-            if block.get("type") == "text"
+            block["text"] for block in self.content if types.is_text_block(block)
         )

@@ -754,9 +733,7 @@ class ToolMessage:
     def text(self) -> str:
         """Extract all text content from the tool message."""
         return "".join(
-            cast("types.TextContentBlock", block)["text"]
-            for block in self.content
-            if block.get("type") == "text"
+            block["text"] for block in self.content if types.is_text_block(block)
         )

     def __post_init__(self) -> None:
@@ -283,7 +283,7 @@ class BaseOutputParser(
             Structured output.
         """
         if isinstance(result, AIMessage):
-            return self.parse(result.text or "")
+            return self.parse(result.text)
         return self.parse(result[0].text)

     @abstractmethod
@@ -73,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
         Raises:
             OutputParserException: If the output is not valid JSON.
         """
-        text = result.text or "" if isinstance(result, AIMessage) else result[0].text
+        text = result.text if isinstance(result, AIMessage) else result[0].text
         text = text.strip()
         if partial:
             try:
@@ -83,7 +83,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
                     continue
                 buffer += chunk_content
             elif isinstance(chunk, AIMessage):
-                buffer += chunk.text or ""
+                buffer += chunk.text
             else:
                 # add current chunk to buffer
                 buffer += chunk
@@ -119,7 +119,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
                     continue
                 buffer += chunk_content
             elif isinstance(chunk, AIMessage):
-                buffer += chunk.text or ""
+                buffer += chunk.text
             else:
                 # add current chunk to buffer
                 buffer += chunk
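All three parser edits follow from the property change earlier in this diff: v1 `AIMessage.text` now returns `str` (empty when there are no text blocks) rather than `Optional[str]`, so every `or ""` guard is dead code. A minimal sketch, assuming the v1 module path used elsewhere in this commit:

from langchain_core.messages.v1 import AIMessage as AIMessageV1

msg = AIMessageV1([])    # content with no text blocks at all
assert msg.text == ""    # empty string, never None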
@@ -14,7 +14,10 @@ from langchain_core.language_models import (
     ParrotFakeChatModel,
 )
 from langchain_core.language_models._utils import _normalize_messages
-from langchain_core.language_models.fake_chat_models import FakeListChatModelError
+from langchain_core.language_models.fake_chat_models import (
+    FakeListChatModelError,
+    GenericFakeChatModelV1,
+)
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -22,6 +25,7 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers import LogStreamCallbackHandler
@@ -654,3 +658,93 @@ def test_normalize_messages_edge_cases() -> None:
         )
     ]
     assert messages == _normalize_messages(messages)
+
+
+def test_streaming_v1() -> None:
+    chunks = [
+        AIMessageChunkV1(
+            [
+                {
+                    "type": "reasoning",
+                    "reasoning": "Let's call a tool.",
+                    "index": 0,
+                }
+            ]
+        ),
+        AIMessageChunkV1(
+            [],
+            tool_call_chunks=[
+                {
+                    "type": "tool_call_chunk",
+                    "args": "",
+                    "name": "tool_name",
+                    "id": "call_123",
+                    "index": 1,
+                },
+            ],
+        ),
+        AIMessageChunkV1(
+            [],
+            tool_call_chunks=[
+                {
+                    "type": "tool_call_chunk",
+                    "args": '{"a',
+                    "name": "",
+                    "id": "",
+                    "index": 1,
+                },
+            ],
+        ),
+        AIMessageChunkV1(
+            [],
+            tool_call_chunks=[
+                {
+                    "type": "tool_call_chunk",
+                    "args": '": 1}',
+                    "name": "",
+                    "id": "",
+                    "index": 1,
+                },
+            ],
+        ),
+    ]
+    full: Optional[AIMessageChunkV1] = None
+    for chunk in chunks:
+        full = chunk if full is None else full + chunk
+
+    assert isinstance(full, AIMessageChunkV1)
+    assert full.content == [
+        {
+            "type": "reasoning",
+            "reasoning": "Let's call a tool.",
+            "index": 0,
+        },
+        {
+            "type": "tool_call_chunk",
+            "args": '{"a": 1}',
+            "name": "tool_name",
+            "id": "call_123",
+            "index": 1,
+        },
+    ]
+
+    llm = GenericFakeChatModelV1(message_chunks=chunks)
+
+    full = None
+    for chunk in llm.stream("anything"):
+        full = chunk if full is None else full + chunk
+
+    assert isinstance(full, AIMessageChunkV1)
+    assert full.content == [
+        {
+            "type": "reasoning",
+            "reasoning": "Let's call a tool.",
+            "index": 0,
+        },
+        {
+            "type": "tool_call",
+            "args": {"a": 1},
+            "name": "tool_name",
+            "id": "call_123",
+        },
+    ]
@@ -37,7 +37,7 @@ def test_base_generation_parser() -> None:
             that support streaming
         """
         if isinstance(result, AIMessageV1):
-            content = result.text or ""
+            content = result.text
         else:
             if len(result) != 1:
                 msg = (
@@ -89,7 +89,7 @@ def test_base_transform_output_parser() -> None:
             that support streaming
         """
         if isinstance(result, AIMessageV1):
-            content = result.text or ""
+            content = result.text
         else:
             if len(result) != 1:
                 msg = (
@@ -116,4 +116,4 @@ def test_base_transform_output_parser() -> None:
     model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
     chain_v1 = model_v1 | StrInvertCase()
     chunks = list(chain_v1.stream(""))
-    assert chunks == ["HELLO", " ", "WORLD"]
+    assert chunks == ["HELLO", " ", "WORLD", ""]
@@ -2543,12 +2543,6 @@
         dict({
           '$ref': '#/$defs/ToolCallChunk',
         }),
-        dict({
-          '$ref': '#/$defs/Citation',
-        }),
-        dict({
-          '$ref': '#/$defs/NonStandardAnnotation',
-        }),
         dict({
           '$ref': '#/$defs/InvalidToolCall',
         }),
@@ -2675,12 +2669,6 @@
         dict({
           '$ref': '#/$defs/ToolCallChunk',
         }),
-        dict({
-          '$ref': '#/$defs/Citation',
-        }),
-        dict({
-          '$ref': '#/$defs/NonStandardAnnotation',
-        }),
         dict({
           '$ref': '#/$defs/InvalidToolCall',
         }),
@@ -2807,12 +2795,6 @@
         dict({
           '$ref': '#/$defs/ToolCallChunk',
         }),
-        dict({
-          '$ref': '#/$defs/Citation',
-        }),
-        dict({
-          '$ref': '#/$defs/NonStandardAnnotation',
-        }),
         dict({
           '$ref': '#/$defs/InvalidToolCall',
         }),
@@ -2901,12 +2883,6 @@
         dict({
           '$ref': '#/$defs/ToolCallChunk',
         }),
-        dict({
-          '$ref': '#/$defs/Citation',
-        }),
-        dict({
-          '$ref': '#/$defs/NonStandardAnnotation',
-        }),
         dict({
           '$ref': '#/$defs/InvalidToolCall',
         }),
@@ -3018,12 +2994,6 @@
         dict({
           '$ref': '#/$defs/ToolCallChunk',
         }),
-        dict({
-          '$ref': '#/$defs/Citation',
-        }),
-        dict({
-          '$ref': '#/$defs/NonStandardAnnotation',
-        }),
         dict({
           '$ref': '#/$defs/InvalidToolCall',
         }),
@@ -11535,6 +11535,9 @@
         dict({
           '$ref': '#/definitions/InvalidToolCall',
         }),
+        dict({
+          '$ref': '#/definitions/ToolCallChunk',
+        }),
         dict({
           '$ref': '#/definitions/ReasoningContentBlock',
         }),
@@ -11657,6 +11660,9 @@
         dict({
           '$ref': '#/definitions/InvalidToolCall',
         }),
+        dict({
+          '$ref': '#/definitions/ToolCallChunk',
+        }),
         dict({
           '$ref': '#/definitions/ReasoningContentBlock',
         }),
@@ -11779,6 +11785,9 @@
         dict({
           '$ref': '#/definitions/InvalidToolCall',
         }),
+        dict({
+          '$ref': '#/definitions/ToolCallChunk',
+        }),
         dict({
           '$ref': '#/definitions/ReasoningContentBlock',
         }),
@@ -11863,6 +11872,9 @@
         dict({
           '$ref': '#/definitions/InvalidToolCall',
         }),
+        dict({
+          '$ref': '#/definitions/ToolCallChunk',
+        }),
         dict({
           '$ref': '#/definitions/ReasoningContentBlock',
         }),
@@ -11970,6 +11982,9 @@
         dict({
           '$ref': '#/definitions/InvalidToolCall',
         }),
+        dict({
+          '$ref': '#/definitions/ToolCallChunk',
+        }),
         dict({
           '$ref': '#/definitions/ReasoningContentBlock',
         }),
@@ -3,6 +3,7 @@ import uuid
 from typing import Optional, Union

 import pytest
+from typing_extensions import get_args

 from langchain_core.documents import Document
 from langchain_core.load import dumpd, load
@@ -30,7 +31,7 @@ from langchain_core.messages import (
     messages_from_dict,
     messages_to_dict,
 )
-from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
+from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES, ContentBlock
 from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
 from langchain_core.messages.tool import tool_call as create_tool_call
 from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
@@ -1363,23 +1364,19 @@ def test_convert_to_openai_image_block() -> None:


 def test_known_block_types() -> None:
-    assert {
-        "audio",
-        "citation",
-        "code_interpreter_call",
-        "code_interpreter_output",
-        "code_interpreter_result",
-        "file",
-        "image",
-        "invalid_tool_call",
-        "non_standard",
-        "non_standard_annotation",
-        "reasoning",
-        "text",
-        "text-plain",
-        "tool_call",
-        "tool_call_chunk",
-        "video",
-        "web_search_call",
-        "web_search_result",
-    } == KNOWN_BLOCK_TYPES
+    expected = {
+        bt
+        for bt in get_args(ContentBlock)
+        for bt in get_args(bt.__annotations__["type"])
+    }
+    # Normalize any Literal[...] types in block types to their string values.
+    # This ensures all entries are plain strings, not Literal objects.
+    expected = {
+        t
+        if isinstance(t, str)
+        else t.__args__[0]
+        if hasattr(t, "__args__") and len(t.__args__) == 1
+        else t
+        for t in expected
+    }
+    assert expected == KNOWN_BLOCK_TYPES
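The rewritten assertion derives the expected set from the `ContentBlock` union itself, so a newly added block type updates the test automatically. The core extraction trick in isolation, with toy TypedDicts standing in for the real block types:

from typing import Literal, Union, get_args
from typing_extensions import TypedDict

class TextBlock(TypedDict):
    type: Literal["text"]

class ImageBlock(TypedDict):
    type: Literal["image", "image_url"]

Block = Union[TextBlock, ImageBlock]

types_seen = {
    value
    for block_type in get_args(Block)                          # each TypedDict member
    for value in get_args(block_type.__annotations__["type"])  # each Literal value
}
assert types_seen == {"text", "image", "image_url"}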