mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00
Merge branch 'standard_outputs_copy' into mdrxy/ollama_v1 + updates
This commit is contained in:
commit
cc56b8dbd3
@ -25,7 +25,7 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
field_validator,
|
field_validator,
|
||||||
)
|
)
|
||||||
from typing_extensions import TypeAlias, override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.caches import BaseCache
|
from langchain_core.caches import BaseCache
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -79,8 +79,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
|
def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
|
||||||
if hasattr(error, "response"):
|
response = getattr(error, "response", None)
|
||||||
response = error.response
|
if response is not None:
|
||||||
metadata: dict = {}
|
metadata: dict = {}
|
||||||
if hasattr(response, "headers"):
|
if hasattr(response, "headers"):
|
||||||
try:
|
try:
|
||||||
@ -90,7 +90,7 @@ def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
|
|||||||
if hasattr(response, "status_code"):
|
if hasattr(response, "status_code"):
|
||||||
metadata["status_code"] = response.status_code
|
metadata["status_code"] = response.status_code
|
||||||
if hasattr(error, "request_id"):
|
if hasattr(error, "request_id"):
|
||||||
metadata["request_id"] = error.request_id
|
metadata["request_id"] = error.request_id # type: ignore[arg-type]
|
||||||
# Permit response_metadata without model_name, model_provider fields
|
# Permit response_metadata without model_name, model_provider fields
|
||||||
generations = [AIMessageV1(content=[], response_metadata=metadata)] # type: ignore[arg-type]
|
generations = [AIMessageV1(content=[], response_metadata=metadata)] # type: ignore[arg-type]
|
||||||
else:
|
else:
|
||||||
@ -118,7 +118,7 @@ def _format_for_tracing(messages: Sequence[MessageV1]) -> list[MessageV1]:
|
|||||||
for idx, block in enumerate(message.content):
|
for idx, block in enumerate(message.content):
|
||||||
# Update image content blocks to OpenAI # Chat Completions format.
|
# Update image content blocks to OpenAI # Chat Completions format.
|
||||||
if (
|
if (
|
||||||
block["type"] == "image"
|
block.get("type") == "image"
|
||||||
and is_data_content_block(block) # type: ignore[arg-type] # permit unnecessary runtime check
|
and is_data_content_block(block) # type: ignore[arg-type] # permit unnecessary runtime check
|
||||||
and block.get("source_type") != "id"
|
and block.get("source_type") != "id"
|
||||||
):
|
):
|
||||||
@ -338,7 +338,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
@override
|
@override
|
||||||
def InputType(self) -> TypeAlias:
|
def InputType(self) -> Any:
|
||||||
"""Get the input type for this runnable."""
|
"""Get the input type for this runnable."""
|
||||||
from langchain_core.prompt_values import (
|
from langchain_core.prompt_values import (
|
||||||
ChatPromptValueConcrete,
|
ChatPromptValueConcrete,
|
||||||
@ -458,7 +458,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
|
|||||||
chunks: list[AIMessageChunkV1] = []
|
chunks: list[AIMessageChunkV1] = []
|
||||||
try:
|
try:
|
||||||
for msg in self._stream(input_messages, **kwargs):
|
for msg in self._stream(input_messages, **kwargs):
|
||||||
run_manager.on_llm_new_token(msg.text or "")
|
run_manager.on_llm_new_token(msg.text)
|
||||||
chunks.append(msg)
|
chunks.append(msg)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
||||||
@ -525,7 +525,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
|
|||||||
chunks: list[AIMessageChunkV1] = []
|
chunks: list[AIMessageChunkV1] = []
|
||||||
try:
|
try:
|
||||||
async for msg in self._astream(input_messages, **kwargs):
|
async for msg in self._astream(input_messages, **kwargs):
|
||||||
await run_manager.on_llm_new_token(msg.text or "")
|
await run_manager.on_llm_new_token(msg.text)
|
||||||
chunks.append(msg)
|
chunks.append(msg)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_llm_error(
|
await run_manager.on_llm_error(
|
||||||
@ -602,9 +602,12 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
|
|||||||
# TODO: replace this with something for new messages
|
# TODO: replace this with something for new messages
|
||||||
input_messages = _normalize_messages_v1(messages)
|
input_messages = _normalize_messages_v1(messages)
|
||||||
for msg in self._stream(input_messages, **kwargs):
|
for msg in self._stream(input_messages, **kwargs):
|
||||||
run_manager.on_llm_new_token(msg.text or "")
|
run_manager.on_llm_new_token(msg.text)
|
||||||
chunks.append(msg)
|
chunks.append(msg)
|
||||||
yield msg
|
yield msg
|
||||||
|
|
||||||
|
if msg.chunk_position != "last":
|
||||||
|
yield (AIMessageChunkV1([], chunk_position="last"))
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
||||||
raise
|
raise
|
||||||
@ -673,9 +676,11 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
|
|||||||
input_messages,
|
input_messages,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
await run_manager.on_llm_new_token(msg.text or "")
|
await run_manager.on_llm_new_token(msg.text)
|
||||||
chunks.append(msg)
|
chunks.append(msg)
|
||||||
yield msg
|
yield msg
|
||||||
|
if msg.chunk_position != "last":
|
||||||
|
yield (AIMessageChunkV1([], chunk_position="last"))
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
await run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
||||||
raise
|
raise
|
||||||
@ -716,22 +721,23 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
|
|||||||
ls_params["ls_stop"] = stop
|
ls_params["ls_stop"] = stop
|
||||||
|
|
||||||
# model
|
# model
|
||||||
if hasattr(self, "model") and isinstance(self.model, str):
|
model = (
|
||||||
ls_params["ls_model_name"] = self.model
|
kwargs.get("model")
|
||||||
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
|
or getattr(self, "model", None)
|
||||||
ls_params["ls_model_name"] = self.model_name
|
or getattr(self, "model_name", None)
|
||||||
|
)
|
||||||
|
if isinstance(model, str):
|
||||||
|
ls_params["ls_model_name"] = model
|
||||||
|
|
||||||
# temperature
|
# temperature
|
||||||
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
|
temperature = kwargs.get("temperature") or getattr(self, "temperature", None)
|
||||||
ls_params["ls_temperature"] = kwargs["temperature"]
|
if isinstance(temperature, (int, float)):
|
||||||
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
|
ls_params["ls_temperature"] = temperature
|
||||||
ls_params["ls_temperature"] = self.temperature
|
|
||||||
|
|
||||||
# max_tokens
|
# max_tokens
|
||||||
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
|
max_tokens = kwargs.get("max_tokens") or getattr(self, "max_tokens", None)
|
||||||
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
|
if isinstance(max_tokens, int):
|
||||||
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
|
ls_params["ls_max_tokens"] = max_tokens
|
||||||
ls_params["ls_max_tokens"] = self.max_tokens
|
|
||||||
|
|
||||||
return ls_params
|
return ls_params
|
||||||
|
|
||||||
@ -806,7 +812,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
|
|||||||
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
||||||
],
|
],
|
||||||
*,
|
*,
|
||||||
tool_choice: Optional[Union[str]] = None,
|
tool_choice: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, AIMessageV1]:
|
) -> Runnable[LanguageModelInput, AIMessageV1]:
|
||||||
"""Bind tools to the model.
|
"""Bind tools to the model.
|
||||||
|
@ -103,7 +103,7 @@ The module defines several types of content blocks, including:
|
|||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Literal, Optional, Union
|
from typing import Any, Literal, Optional, TypeGuard, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from typing_extensions import NotRequired, TypedDict, get_args, get_origin
|
from typing_extensions import NotRequired, TypedDict, get_args, get_origin
|
||||||
@ -844,8 +844,6 @@ ContentBlock = Union[
|
|||||||
TextContentBlock,
|
TextContentBlock,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolCallChunk,
|
ToolCallChunk,
|
||||||
Citation,
|
|
||||||
NonStandardAnnotation,
|
|
||||||
InvalidToolCall,
|
InvalidToolCall,
|
||||||
ReasoningContentBlock,
|
ReasoningContentBlock,
|
||||||
NonStandardContentBlock,
|
NonStandardContentBlock,
|
||||||
@ -884,7 +882,24 @@ def _extract_typedict_type_values(union_type: Any) -> set[str]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
KNOWN_BLOCK_TYPES = _extract_typedict_type_values(ContentBlock)
|
KNOWN_BLOCK_TYPES = {
|
||||||
|
"text",
|
||||||
|
"text-plain",
|
||||||
|
"tool_call",
|
||||||
|
"invalid_tool_call",
|
||||||
|
"tool_call_chunk",
|
||||||
|
"reasoning",
|
||||||
|
"non_standard",
|
||||||
|
"image",
|
||||||
|
"audio",
|
||||||
|
"file",
|
||||||
|
"video",
|
||||||
|
"code_interpreter_call",
|
||||||
|
"code_interpreter_output",
|
||||||
|
"code_interpreter_result",
|
||||||
|
"web_search_call",
|
||||||
|
"web_search_result",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def is_data_content_block(block: dict) -> bool:
|
def is_data_content_block(block: dict) -> bool:
|
||||||
@ -914,6 +929,28 @@ def is_data_content_block(block: dict) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_tool_call_block(block: ContentBlock) -> TypeGuard[ToolCall]:
|
||||||
|
"""Type guard to check if a content block is a tool call."""
|
||||||
|
return block.get("type") == "tool_call"
|
||||||
|
|
||||||
|
|
||||||
|
def is_tool_call_chunk(block: ContentBlock) -> TypeGuard[ToolCallChunk]:
|
||||||
|
"""Type guard to check if a content block is a tool call chunk."""
|
||||||
|
return block.get("type") == "tool_call_chunk"
|
||||||
|
|
||||||
|
|
||||||
|
def is_text_block(block: ContentBlock) -> TypeGuard[TextContentBlock]:
|
||||||
|
"""Type guard to check if a content block is a text block."""
|
||||||
|
return block.get("type") == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def is_invalid_tool_call_block(
|
||||||
|
block: ContentBlock,
|
||||||
|
) -> TypeGuard[InvalidToolCall]:
|
||||||
|
"""Type guard to check if a content block is an invalid tool call."""
|
||||||
|
return block.get("type") == "invalid_tool_call"
|
||||||
|
|
||||||
|
|
||||||
def convert_to_openai_image_block(block: dict[str, Any]) -> dict:
|
def convert_to_openai_image_block(block: dict[str, Any]) -> dict:
|
||||||
"""Convert image content block to format expected by OpenAI Chat Completions API."""
|
"""Convert image content block to format expected by OpenAI Chat Completions API."""
|
||||||
if "url" in block:
|
if "url" in block:
|
||||||
|
@ -4,10 +4,9 @@ Each message has content that may be comprised of content blocks, defined under
|
|||||||
``langchain_core.messages.content_blocks``.
|
``langchain_core.messages.content_blocks``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Literal, Optional, TypeGuard, Union, cast, get_args
|
from typing import Any, Literal, Optional, Union, cast, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
@ -20,32 +19,12 @@ from langchain_core.messages.ai import (
|
|||||||
add_usage,
|
add_usage,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.base import merge_content
|
from langchain_core.messages.base import merge_content
|
||||||
from langchain_core.messages.content_blocks import create_text_block
|
|
||||||
from langchain_core.messages.tool import ToolCallChunk
|
|
||||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
from langchain_core.utils._merge import merge_dicts
|
||||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
|
||||||
from langchain_core.utils.json import parse_partial_json
|
from langchain_core.utils.json import parse_partial_json
|
||||||
|
|
||||||
|
|
||||||
def is_tool_call_block(block: types.ContentBlock) -> TypeGuard[types.ToolCall]:
|
|
||||||
"""Type guard to check if a content block is a tool call."""
|
|
||||||
return block.get("type") == "tool_call"
|
|
||||||
|
|
||||||
|
|
||||||
def is_text_block(block: types.ContentBlock) -> TypeGuard[types.TextContentBlock]:
|
|
||||||
"""Type guard to check if a content block is a text block."""
|
|
||||||
return block.get("type") == "text"
|
|
||||||
|
|
||||||
|
|
||||||
def is_invalid_tool_call_block(
|
|
||||||
block: types.ContentBlock,
|
|
||||||
) -> TypeGuard[types.InvalidToolCall]:
|
|
||||||
"""Type guard to check if a content block is an invalid tool call."""
|
|
||||||
return block.get("type") == "invalid_tool_call"
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_id(id_val: Optional[str]) -> str:
|
def _ensure_id(id_val: Optional[str]) -> str:
|
||||||
"""Ensure the ID is a valid string, generating a new UUID if not provided.
|
"""Ensure the ID is a valid string, generating a new UUID if not provided.
|
||||||
|
|
||||||
@ -169,7 +148,7 @@ class AIMessage:
|
|||||||
parsed: Optional auto-parsed message contents, if applicable.
|
parsed: Optional auto-parsed message contents, if applicable.
|
||||||
"""
|
"""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
self.content = [create_text_block(content)]
|
self.content = [types.create_text_block(content)]
|
||||||
else:
|
else:
|
||||||
self.content = content
|
self.content = content
|
||||||
|
|
||||||
@ -188,33 +167,46 @@ class AIMessage:
|
|||||||
content_tool_calls = {
|
content_tool_calls = {
|
||||||
block["id"]
|
block["id"]
|
||||||
for block in self.content
|
for block in self.content
|
||||||
if block.get("type") == "tool_call" and "id" in block
|
if types.is_tool_call_block(block) and "id" in block
|
||||||
}
|
}
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
if "id" in tool_call and tool_call["id"] in content_tool_calls:
|
if "id" in tool_call and tool_call["id"] in content_tool_calls:
|
||||||
continue
|
continue
|
||||||
self.content.append(tool_call)
|
self.content.append(tool_call)
|
||||||
|
if invalid_tool_calls:
|
||||||
|
content_tool_calls = {
|
||||||
|
block["id"]
|
||||||
|
for block in self.content
|
||||||
|
if types.is_invalid_tool_call_block(block) and "id" in block
|
||||||
|
}
|
||||||
|
for invalid_tool_call in invalid_tool_calls:
|
||||||
|
if (
|
||||||
|
"id" in invalid_tool_call
|
||||||
|
and invalid_tool_call["id"] in content_tool_calls
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
self.content.append(invalid_tool_call)
|
||||||
self._tool_calls: list[types.ToolCall] = [
|
self._tool_calls: list[types.ToolCall] = [
|
||||||
block for block in self.content if is_tool_call_block(block)
|
block for block in self.content if types.is_tool_call_block(block)
|
||||||
|
]
|
||||||
|
self._invalid_tool_calls = [
|
||||||
|
block for block in self.content if types.is_invalid_tool_call_block(block)
|
||||||
]
|
]
|
||||||
self.invalid_tool_calls = invalid_tool_calls or []
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> Optional[str]:
|
def text(self) -> str:
|
||||||
"""Extract all text content from the AI message as a string."""
|
"""Extract all text content from the AI message as a string."""
|
||||||
text_blocks = [block for block in self.content if is_text_block(block)]
|
return "".join(
|
||||||
if text_blocks:
|
block["text"] for block in self.content if types.is_text_block(block)
|
||||||
return "".join(block["text"] for block in text_blocks)
|
)
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_calls(self) -> list[types.ToolCall]: # update once we fix branch
|
def tool_calls(self) -> list[types.ToolCall]:
|
||||||
"""Get the tool calls made by the AI."""
|
"""Get the tool calls made by the AI."""
|
||||||
if self._tool_calls:
|
if not self._tool_calls:
|
||||||
return self._tool_calls
|
self._tool_calls = [
|
||||||
tool_calls = [block for block in self.content if is_tool_call_block(block)]
|
block for block in self.content if types.is_tool_call_block(block)
|
||||||
if tool_calls:
|
]
|
||||||
self._tool_calls = tool_calls
|
|
||||||
return self._tool_calls
|
return self._tool_calls
|
||||||
|
|
||||||
@tool_calls.setter
|
@tool_calls.setter
|
||||||
@ -222,6 +214,17 @@ class AIMessage:
|
|||||||
"""Set the tool calls for the AI message."""
|
"""Set the tool calls for the AI message."""
|
||||||
self._tool_calls = value
|
self._tool_calls = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
|
||||||
|
"""Get the invalid tool calls made by the AI."""
|
||||||
|
if not self._invalid_tool_calls:
|
||||||
|
self._invalid_tool_calls = [
|
||||||
|
block
|
||||||
|
for block in self.content
|
||||||
|
if types.is_invalid_tool_call_block(block)
|
||||||
|
]
|
||||||
|
return self._invalid_tool_calls
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AIMessageChunk(AIMessage):
|
class AIMessageChunk(AIMessage):
|
||||||
@ -246,17 +249,10 @@ class AIMessageChunk(AIMessage):
|
|||||||
when deserializing messages.
|
when deserializing messages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
|
|
||||||
"""List of partial tool call data.
|
|
||||||
|
|
||||||
Emitted by the model during streaming, this field contains
|
|
||||||
tool call chunks that may not yet be complete. It is used to reconstruct
|
|
||||||
tool calls from the streamed content.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
content: Union[str, list[types.ContentBlock]],
|
content: Union[str, list[types.ContentBlock]],
|
||||||
|
*,
|
||||||
id: Optional[str] = None,
|
id: Optional[str] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
lc_version: str = "v1",
|
lc_version: str = "v1",
|
||||||
@ -264,6 +260,7 @@ class AIMessageChunk(AIMessage):
|
|||||||
usage_metadata: Optional[UsageMetadata] = None,
|
usage_metadata: Optional[UsageMetadata] = None,
|
||||||
tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
|
tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
|
||||||
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
|
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||||
|
chunk_position: Optional[Literal["last"]] = None,
|
||||||
):
|
):
|
||||||
"""Initialize an AI message.
|
"""Initialize an AI message.
|
||||||
|
|
||||||
@ -276,6 +273,8 @@ class AIMessageChunk(AIMessage):
|
|||||||
usage_metadata: Optional metadata about token usage.
|
usage_metadata: Optional metadata about token usage.
|
||||||
tool_call_chunks: Optional list of partial tool call data.
|
tool_call_chunks: Optional list of partial tool call data.
|
||||||
parsed: Optional auto-parsed message contents, if applicable.
|
parsed: Optional auto-parsed message contents, if applicable.
|
||||||
|
chunk_position: Optional position of the chunk in the stream. If "last",
|
||||||
|
tool calls will be parsed when aggregated into a stream.
|
||||||
"""
|
"""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
self.content = [{"type": "text", "text": content, "index": 0}]
|
self.content = [{"type": "text", "text": content, "index": 0}]
|
||||||
@ -287,115 +286,51 @@ class AIMessageChunk(AIMessage):
|
|||||||
self.lc_version = lc_version
|
self.lc_version = lc_version
|
||||||
self.usage_metadata = usage_metadata
|
self.usage_metadata = usage_metadata
|
||||||
self.parsed = parsed
|
self.parsed = parsed
|
||||||
|
self.chunk_position = chunk_position
|
||||||
if response_metadata is None:
|
if response_metadata is None:
|
||||||
self.response_metadata = {}
|
self.response_metadata = {}
|
||||||
else:
|
else:
|
||||||
self.response_metadata = response_metadata
|
self.response_metadata = response_metadata
|
||||||
if tool_call_chunks is None:
|
|
||||||
self.tool_call_chunks: list[types.ToolCallChunk] = []
|
if tool_call_chunks:
|
||||||
else:
|
content_tool_call_chunks = {
|
||||||
self.tool_call_chunks = tool_call_chunks
|
block["id"]
|
||||||
|
for block in self.content
|
||||||
|
if types.is_tool_call_chunk(block) and "id" in block
|
||||||
|
}
|
||||||
|
for chunk in tool_call_chunks:
|
||||||
|
if "id" in chunk and chunk["id"] in content_tool_call_chunks:
|
||||||
|
continue
|
||||||
|
self.content.append(chunk)
|
||||||
|
self._tool_call_chunks = [
|
||||||
|
block for block in self.content if types.is_tool_call_chunk(block)
|
||||||
|
]
|
||||||
|
|
||||||
self._tool_calls: list[types.ToolCall] = []
|
self._tool_calls: list[types.ToolCall] = []
|
||||||
self.invalid_tool_calls: list[types.InvalidToolCall] = []
|
self._invalid_tool_calls: list[types.InvalidToolCall] = []
|
||||||
self._init_tool_calls()
|
|
||||||
|
|
||||||
def _init_tool_calls(self) -> None:
|
|
||||||
"""Initialize tool calls from tool call chunks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
values: The values to validate.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the tool call chunks are malformed.
|
|
||||||
"""
|
|
||||||
self._tool_calls = []
|
|
||||||
self.invalid_tool_calls = []
|
|
||||||
if not self.tool_call_chunks:
|
|
||||||
if self._tool_calls:
|
|
||||||
self.tool_call_chunks = [
|
|
||||||
create_tool_call_chunk(
|
|
||||||
name=tc["name"],
|
|
||||||
args=json.dumps(tc["args"]),
|
|
||||||
id=tc["id"],
|
|
||||||
index=None,
|
|
||||||
)
|
|
||||||
for tc in self._tool_calls
|
|
||||||
]
|
|
||||||
if self.invalid_tool_calls:
|
|
||||||
tool_call_chunks = self.tool_call_chunks
|
|
||||||
tool_call_chunks.extend(
|
|
||||||
[
|
|
||||||
create_tool_call_chunk(
|
|
||||||
name=tc["name"], args=tc["args"], id=tc["id"], index=None
|
|
||||||
)
|
|
||||||
for tc in self.invalid_tool_calls
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.tool_call_chunks = tool_call_chunks
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
invalid_tool_calls = []
|
|
||||||
|
|
||||||
def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
|
|
||||||
invalid_tool_calls.append(
|
|
||||||
create_invalid_tool_call(
|
|
||||||
name=chunk.get("name", ""),
|
|
||||||
args=chunk.get("args", ""),
|
|
||||||
id=chunk.get("id", ""),
|
|
||||||
error=None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for chunk in self.tool_call_chunks:
|
|
||||||
try:
|
|
||||||
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type]
|
|
||||||
if isinstance(args_, dict):
|
|
||||||
tool_calls.append(
|
|
||||||
create_tool_call(
|
|
||||||
name=chunk.get("name") or "",
|
|
||||||
args=args_,
|
|
||||||
id=chunk.get("id", ""),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
add_chunk_to_invalid_tool_calls(chunk)
|
|
||||||
except Exception:
|
|
||||||
add_chunk_to_invalid_tool_calls(chunk)
|
|
||||||
self._tool_calls = tool_calls
|
|
||||||
self.invalid_tool_calls = invalid_tool_calls
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> Optional[str]:
|
def tool_call_chunks(self) -> list[types.ToolCallChunk]:
|
||||||
"""Extract all text content from the AI message as a string."""
|
"""Get the tool calls made by the AI."""
|
||||||
text_blocks = [block for block in self.content if is_text_block(block)]
|
if not self._tool_call_chunks:
|
||||||
if text_blocks:
|
self._tool_call_chunks = [
|
||||||
return "".join(block["text"] for block in text_blocks)
|
block for block in self.content if types.is_tool_call_chunk(block)
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def reasoning(self) -> Optional[str]:
|
|
||||||
"""Extract all reasoning text from the AI message as a string."""
|
|
||||||
text_blocks = [
|
|
||||||
block
|
|
||||||
for block in self.content
|
|
||||||
if block.get("type") == "reasoning" and "reasoning" in block
|
|
||||||
]
|
]
|
||||||
if text_blocks:
|
return cast("list[types.ToolCallChunk]", self._tool_call_chunks)
|
||||||
return "".join(
|
|
||||||
cast("types.ReasoningContentBlock", block).get("reasoning", "")
|
|
||||||
for block in text_blocks
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_calls(self) -> list[types.ToolCall]:
|
def tool_calls(self) -> list[types.ToolCall]:
|
||||||
"""Get the tool calls made by the AI."""
|
"""Get the tool calls made by the AI."""
|
||||||
if self._tool_calls:
|
if not self._tool_calls:
|
||||||
return self._tool_calls
|
parsed_content = _init_tool_calls(self.content)
|
||||||
tool_calls = [block for block in self.content if is_tool_call_block(block)]
|
self._tool_calls = [
|
||||||
if tool_calls:
|
block for block in parsed_content if types.is_tool_call_block(block)
|
||||||
self._tool_calls = tool_calls
|
]
|
||||||
|
self._invalid_tool_calls = [
|
||||||
|
block
|
||||||
|
for block in parsed_content
|
||||||
|
if types.is_invalid_tool_call_block(block)
|
||||||
|
]
|
||||||
return self._tool_calls
|
return self._tool_calls
|
||||||
|
|
||||||
@tool_calls.setter
|
@tool_calls.setter
|
||||||
@ -403,6 +338,21 @@ class AIMessageChunk(AIMessage):
|
|||||||
"""Set the tool calls for the AI message."""
|
"""Set the tool calls for the AI message."""
|
||||||
self._tool_calls = value
|
self._tool_calls = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
|
||||||
|
"""Get the invalid tool calls made by the AI."""
|
||||||
|
if not self._invalid_tool_calls:
|
||||||
|
parsed_content = _init_tool_calls(self.content)
|
||||||
|
self._tool_calls = [
|
||||||
|
block for block in parsed_content if types.is_tool_call_block(block)
|
||||||
|
]
|
||||||
|
self._invalid_tool_calls = [
|
||||||
|
block
|
||||||
|
for block in parsed_content
|
||||||
|
if types.is_invalid_tool_call_block(block)
|
||||||
|
]
|
||||||
|
return self._invalid_tool_calls
|
||||||
|
|
||||||
def __add__(self, other: Any) -> "AIMessageChunk":
|
def __add__(self, other: Any) -> "AIMessageChunk":
|
||||||
"""Add ``AIMessageChunk`` to this one."""
|
"""Add ``AIMessageChunk`` to this one."""
|
||||||
if isinstance(other, AIMessageChunk):
|
if isinstance(other, AIMessageChunk):
|
||||||
@ -417,49 +367,76 @@ class AIMessageChunk(AIMessage):
|
|||||||
def to_message(self) -> "AIMessage":
|
def to_message(self) -> "AIMessage":
|
||||||
"""Convert this ``AIMessageChunk`` to an AIMessage."""
|
"""Convert this ``AIMessageChunk`` to an AIMessage."""
|
||||||
return AIMessage(
|
return AIMessage(
|
||||||
content=self.content,
|
content=_init_tool_calls(self.content),
|
||||||
id=self.id,
|
id=self.id,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
lc_version=self.lc_version,
|
lc_version=self.lc_version,
|
||||||
response_metadata=self.response_metadata,
|
response_metadata=self.response_metadata,
|
||||||
usage_metadata=self.usage_metadata,
|
usage_metadata=self.usage_metadata,
|
||||||
tool_calls=self.tool_calls,
|
|
||||||
invalid_tool_calls=self.invalid_tool_calls,
|
|
||||||
parsed=self.parsed,
|
parsed=self.parsed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_tool_calls(content: list[types.ContentBlock]) -> list[types.ContentBlock]:
|
||||||
|
"""Parse tool call chunks in content into tool calls."""
|
||||||
|
new_content = []
|
||||||
|
for block in content:
|
||||||
|
if not types.is_tool_call_chunk(block):
|
||||||
|
new_content.append(block)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
args_ = (
|
||||||
|
parse_partial_json(cast("str", block.get("args") or ""))
|
||||||
|
if block.get("args")
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
if isinstance(args_, dict):
|
||||||
|
new_content.append(
|
||||||
|
create_tool_call(
|
||||||
|
name=cast("str", block.get("name") or ""),
|
||||||
|
args=args_,
|
||||||
|
id=cast("str", block.get("id", "")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_content.append(
|
||||||
|
create_invalid_tool_call(
|
||||||
|
name=cast("str", block.get("name", "")),
|
||||||
|
args=cast("str", block.get("args", "")),
|
||||||
|
id=cast("str", block.get("id", "")),
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
new_content.append(
|
||||||
|
create_invalid_tool_call(
|
||||||
|
name=cast("str", block.get("name", "")),
|
||||||
|
args=cast("str", block.get("args", "")),
|
||||||
|
id=cast("str", block.get("id", "")),
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return new_content
|
||||||
|
|
||||||
|
|
||||||
def add_ai_message_chunks(
|
def add_ai_message_chunks(
|
||||||
left: AIMessageChunk, *others: AIMessageChunk
|
left: AIMessageChunk, *others: AIMessageChunk
|
||||||
) -> AIMessageChunk:
|
) -> AIMessageChunk:
|
||||||
"""Add multiple ``AIMessageChunks`` together."""
|
"""Add multiple ``AIMessageChunks`` together."""
|
||||||
if not others:
|
if not others:
|
||||||
return left
|
return left
|
||||||
content = merge_content(
|
content = cast(
|
||||||
|
"list[types.ContentBlock]",
|
||||||
|
merge_content(
|
||||||
cast("list[str | dict[Any, Any]]", left.content),
|
cast("list[str | dict[Any, Any]]", left.content),
|
||||||
*(cast("list[str | dict[Any, Any]]", o.content) for o in others),
|
*(cast("list[str | dict[Any, Any]]", o.content) for o in others),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
response_metadata = merge_dicts(
|
response_metadata = merge_dicts(
|
||||||
cast("dict", left.response_metadata),
|
cast("dict", left.response_metadata),
|
||||||
*(cast("dict", o.response_metadata) for o in others),
|
*(cast("dict", o.response_metadata) for o in others),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Merge tool call chunks
|
|
||||||
if raw_tool_calls := merge_lists(
|
|
||||||
left.tool_call_chunks, *(o.tool_call_chunks for o in others)
|
|
||||||
):
|
|
||||||
tool_call_chunks = [
|
|
||||||
create_tool_call_chunk(
|
|
||||||
name=rtc.get("name"),
|
|
||||||
args=rtc.get("args"),
|
|
||||||
index=rtc.get("index"),
|
|
||||||
id=rtc.get("id"),
|
|
||||||
)
|
|
||||||
for rtc in raw_tool_calls
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
tool_call_chunks = []
|
|
||||||
|
|
||||||
# Token usage
|
# Token usage
|
||||||
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
|
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
|
||||||
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
|
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
|
||||||
@ -501,13 +478,19 @@ def add_ai_message_chunks(
|
|||||||
chunk_id = id_
|
chunk_id = id_
|
||||||
break
|
break
|
||||||
|
|
||||||
|
chunk_position: Optional[Literal["last"]] = (
|
||||||
|
"last" if any(x.chunk_position == "last" for x in [left, *others]) else None
|
||||||
|
)
|
||||||
|
if chunk_position == "last":
|
||||||
|
content = _init_tool_calls(content)
|
||||||
|
|
||||||
return left.__class__(
|
return left.__class__(
|
||||||
content=cast("list[types.ContentBlock]", content),
|
content=content,
|
||||||
tool_call_chunks=tool_call_chunks,
|
|
||||||
response_metadata=cast("ResponseMetadata", response_metadata),
|
response_metadata=cast("ResponseMetadata", response_metadata),
|
||||||
usage_metadata=usage_metadata,
|
usage_metadata=usage_metadata,
|
||||||
parsed=parsed,
|
parsed=parsed,
|
||||||
id=chunk_id,
|
id=chunk_id,
|
||||||
|
chunk_position=chunk_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -579,9 +562,7 @@ class HumanMessage:
|
|||||||
Concatenated string of all text blocks in the message.
|
Concatenated string of all text blocks in the message.
|
||||||
"""
|
"""
|
||||||
return "".join(
|
return "".join(
|
||||||
cast("types.TextContentBlock", block)["text"]
|
block["text"] for block in self.content if types.is_text_block(block)
|
||||||
for block in self.content
|
|
||||||
if block.get("type") == "text"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -660,9 +641,7 @@ class SystemMessage:
|
|||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
"""Extract all text content from the system message."""
|
"""Extract all text content from the system message."""
|
||||||
return "".join(
|
return "".join(
|
||||||
cast("types.TextContentBlock", block)["text"]
|
block["text"] for block in self.content if types.is_text_block(block)
|
||||||
for block in self.content
|
|
||||||
if block.get("type") == "text"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -754,9 +733,7 @@ class ToolMessage:
|
|||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
"""Extract all text content from the tool message."""
|
"""Extract all text content from the tool message."""
|
||||||
return "".join(
|
return "".join(
|
||||||
cast("types.TextContentBlock", block)["text"]
|
block["text"] for block in self.content if types.is_text_block(block)
|
||||||
for block in self.content
|
|
||||||
if block.get("type") == "text"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
@ -283,7 +283,7 @@ class BaseOutputParser(
|
|||||||
Structured output.
|
Structured output.
|
||||||
"""
|
"""
|
||||||
if isinstance(result, AIMessage):
|
if isinstance(result, AIMessage):
|
||||||
return self.parse(result.text or "")
|
return self.parse(result.text)
|
||||||
return self.parse(result[0].text)
|
return self.parse(result[0].text)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -73,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
Raises:
|
Raises:
|
||||||
OutputParserException: If the output is not valid JSON.
|
OutputParserException: If the output is not valid JSON.
|
||||||
"""
|
"""
|
||||||
text = result.text or "" if isinstance(result, AIMessage) else result[0].text
|
text = result.text if isinstance(result, AIMessage) else result[0].text
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
if partial:
|
if partial:
|
||||||
try:
|
try:
|
||||||
|
@ -83,7 +83,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|||||||
continue
|
continue
|
||||||
buffer += chunk_content
|
buffer += chunk_content
|
||||||
elif isinstance(chunk, AIMessage):
|
elif isinstance(chunk, AIMessage):
|
||||||
buffer += chunk.text or ""
|
buffer += chunk.text
|
||||||
else:
|
else:
|
||||||
# add current chunk to buffer
|
# add current chunk to buffer
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
@ -119,7 +119,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|||||||
continue
|
continue
|
||||||
buffer += chunk_content
|
buffer += chunk_content
|
||||||
elif isinstance(chunk, AIMessage):
|
elif isinstance(chunk, AIMessage):
|
||||||
buffer += chunk.text or ""
|
buffer += chunk.text
|
||||||
else:
|
else:
|
||||||
# add current chunk to buffer
|
# add current chunk to buffer
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
|
@ -14,7 +14,10 @@ from langchain_core.language_models import (
|
|||||||
ParrotFakeChatModel,
|
ParrotFakeChatModel,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models._utils import _normalize_messages
|
from langchain_core.language_models._utils import _normalize_messages
|
||||||
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
|
from langchain_core.language_models.fake_chat_models import (
|
||||||
|
FakeListChatModelError,
|
||||||
|
GenericFakeChatModelV1,
|
||||||
|
)
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
@ -22,6 +25,7 @@ from langchain_core.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.outputs.llm_result import LLMResult
|
from langchain_core.outputs.llm_result import LLMResult
|
||||||
from langchain_core.tracers import LogStreamCallbackHandler
|
from langchain_core.tracers import LogStreamCallbackHandler
|
||||||
@ -654,3 +658,93 @@ def test_normalize_messages_edge_cases() -> None:
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
assert messages == _normalize_messages(messages)
|
assert messages == _normalize_messages(messages)
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_v1() -> None:
|
||||||
|
chunks = [
|
||||||
|
AIMessageChunkV1(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "reasoning",
|
||||||
|
"reasoning": "Let's call a tool.",
|
||||||
|
"index": 0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
),
|
||||||
|
AIMessageChunkV1(
|
||||||
|
[],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
"args": "",
|
||||||
|
"name": "tool_name",
|
||||||
|
"id": "call_123",
|
||||||
|
"index": 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunkV1(
|
||||||
|
[],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
"args": '{"a',
|
||||||
|
"name": "",
|
||||||
|
"id": "",
|
||||||
|
"index": 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunkV1(
|
||||||
|
[],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
"args": '": 1}',
|
||||||
|
"name": "",
|
||||||
|
"id": "",
|
||||||
|
"index": 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
full: Optional[AIMessageChunkV1] = None
|
||||||
|
for chunk in chunks:
|
||||||
|
full = chunk if full is None else full + chunk
|
||||||
|
|
||||||
|
assert isinstance(full, AIMessageChunkV1)
|
||||||
|
assert full.content == [
|
||||||
|
{
|
||||||
|
"type": "reasoning",
|
||||||
|
"reasoning": "Let's call a tool.",
|
||||||
|
"index": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
"args": '{"a": 1}',
|
||||||
|
"name": "tool_name",
|
||||||
|
"id": "call_123",
|
||||||
|
"index": 1,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
llm = GenericFakeChatModelV1(message_chunks=chunks)
|
||||||
|
|
||||||
|
full = None
|
||||||
|
for chunk in llm.stream("anything"):
|
||||||
|
full = chunk if full is None else full + chunk
|
||||||
|
|
||||||
|
assert isinstance(full, AIMessageChunkV1)
|
||||||
|
assert full.content == [
|
||||||
|
{
|
||||||
|
"type": "reasoning",
|
||||||
|
"reasoning": "Let's call a tool.",
|
||||||
|
"index": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"args": {"a": 1},
|
||||||
|
"name": "tool_name",
|
||||||
|
"id": "call_123",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
@ -37,7 +37,7 @@ def test_base_generation_parser() -> None:
|
|||||||
that support streaming
|
that support streaming
|
||||||
"""
|
"""
|
||||||
if isinstance(result, AIMessageV1):
|
if isinstance(result, AIMessageV1):
|
||||||
content = result.text or ""
|
content = result.text
|
||||||
else:
|
else:
|
||||||
if len(result) != 1:
|
if len(result) != 1:
|
||||||
msg = (
|
msg = (
|
||||||
@ -89,7 +89,7 @@ def test_base_transform_output_parser() -> None:
|
|||||||
that support streaming
|
that support streaming
|
||||||
"""
|
"""
|
||||||
if isinstance(result, AIMessageV1):
|
if isinstance(result, AIMessageV1):
|
||||||
content = result.text or ""
|
content = result.text
|
||||||
else:
|
else:
|
||||||
if len(result) != 1:
|
if len(result) != 1:
|
||||||
msg = (
|
msg = (
|
||||||
@ -116,4 +116,4 @@ def test_base_transform_output_parser() -> None:
|
|||||||
model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
|
model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
|
||||||
chain_v1 = model_v1 | StrInvertCase()
|
chain_v1 = model_v1 | StrInvertCase()
|
||||||
chunks = list(chain_v1.stream(""))
|
chunks = list(chain_v1.stream(""))
|
||||||
assert chunks == ["HELLO", " ", "WORLD"]
|
assert chunks == ["HELLO", " ", "WORLD", ""]
|
||||||
|
@ -2543,12 +2543,6 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/ToolCallChunk',
|
'$ref': '#/$defs/ToolCallChunk',
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/Citation',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/NonStandardAnnotation',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/InvalidToolCall',
|
'$ref': '#/$defs/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
@ -2675,12 +2669,6 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/ToolCallChunk',
|
'$ref': '#/$defs/ToolCallChunk',
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/Citation',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/NonStandardAnnotation',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/InvalidToolCall',
|
'$ref': '#/$defs/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
@ -2807,12 +2795,6 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/ToolCallChunk',
|
'$ref': '#/$defs/ToolCallChunk',
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/Citation',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/NonStandardAnnotation',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/InvalidToolCall',
|
'$ref': '#/$defs/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
@ -2901,12 +2883,6 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/ToolCallChunk',
|
'$ref': '#/$defs/ToolCallChunk',
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/Citation',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/NonStandardAnnotation',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/InvalidToolCall',
|
'$ref': '#/$defs/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
@ -3018,12 +2994,6 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/ToolCallChunk',
|
'$ref': '#/$defs/ToolCallChunk',
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/Citation',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'$ref': '#/$defs/NonStandardAnnotation',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/$defs/InvalidToolCall',
|
'$ref': '#/$defs/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
|
@ -11535,6 +11535,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/InvalidToolCall',
|
'$ref': '#/definitions/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
|
dict({
|
||||||
|
'$ref': '#/definitions/ToolCallChunk',
|
||||||
|
}),
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/ReasoningContentBlock',
|
'$ref': '#/definitions/ReasoningContentBlock',
|
||||||
}),
|
}),
|
||||||
@ -11657,6 +11660,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/InvalidToolCall',
|
'$ref': '#/definitions/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
|
dict({
|
||||||
|
'$ref': '#/definitions/ToolCallChunk',
|
||||||
|
}),
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/ReasoningContentBlock',
|
'$ref': '#/definitions/ReasoningContentBlock',
|
||||||
}),
|
}),
|
||||||
@ -11779,6 +11785,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/InvalidToolCall',
|
'$ref': '#/definitions/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
|
dict({
|
||||||
|
'$ref': '#/definitions/ToolCallChunk',
|
||||||
|
}),
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/ReasoningContentBlock',
|
'$ref': '#/definitions/ReasoningContentBlock',
|
||||||
}),
|
}),
|
||||||
@ -11863,6 +11872,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/InvalidToolCall',
|
'$ref': '#/definitions/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
|
dict({
|
||||||
|
'$ref': '#/definitions/ToolCallChunk',
|
||||||
|
}),
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/ReasoningContentBlock',
|
'$ref': '#/definitions/ReasoningContentBlock',
|
||||||
}),
|
}),
|
||||||
@ -11970,6 +11982,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/InvalidToolCall',
|
'$ref': '#/definitions/InvalidToolCall',
|
||||||
}),
|
}),
|
||||||
|
dict({
|
||||||
|
'$ref': '#/definitions/ToolCallChunk',
|
||||||
|
}),
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/ReasoningContentBlock',
|
'$ref': '#/definitions/ReasoningContentBlock',
|
||||||
}),
|
}),
|
||||||
|
@ -3,6 +3,7 @@ import uuid
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing_extensions import get_args
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.load import dumpd, load
|
from langchain_core.load import dumpd, load
|
||||||
@ -30,7 +31,7 @@ from langchain_core.messages import (
|
|||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
|
from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES, ContentBlock
|
||||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||||
@ -1363,23 +1364,19 @@ def test_convert_to_openai_image_block() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_known_block_types() -> None:
|
def test_known_block_types() -> None:
|
||||||
assert {
|
expected = {
|
||||||
"audio",
|
bt
|
||||||
"citation",
|
for bt in get_args(ContentBlock)
|
||||||
"code_interpreter_call",
|
for bt in get_args(bt.__annotations__["type"])
|
||||||
"code_interpreter_output",
|
}
|
||||||
"code_interpreter_result",
|
# Normalize any Literal[...] types in block types to their string values.
|
||||||
"file",
|
# This ensures all entries are plain strings, not Literal objects.
|
||||||
"image",
|
expected = {
|
||||||
"invalid_tool_call",
|
t
|
||||||
"non_standard",
|
if isinstance(t, str)
|
||||||
"non_standard_annotation",
|
else t.__args__[0]
|
||||||
"reasoning",
|
if hasattr(t, "__args__") and len(t.__args__) == 1
|
||||||
"text",
|
else t
|
||||||
"text-plain",
|
for t in expected
|
||||||
"tool_call",
|
}
|
||||||
"tool_call_chunk",
|
assert expected == KNOWN_BLOCK_TYPES
|
||||||
"video",
|
|
||||||
"web_search_call",
|
|
||||||
"web_search_result",
|
|
||||||
} == KNOWN_BLOCK_TYPES
|
|
||||||
|
Loading…
Reference in New Issue
Block a user