Merge branch 'standard_outputs_copy' into mdrxy/ollama_v1 + updates

This commit is contained in:
Mason Daugherty 2025-08-04 12:57:38 -04:00
commit cc56b8dbd3
No known key found for this signature in database
11 changed files with 358 additions and 262 deletions

View File

@ -25,7 +25,7 @@ from pydantic import (
Field,
field_validator,
)
from typing_extensions import TypeAlias, override
from typing_extensions import override
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
@ -79,8 +79,8 @@ if TYPE_CHECKING:
def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
if hasattr(error, "response"):
response = error.response
response = getattr(error, "response", None)
if response is not None:
metadata: dict = {}
if hasattr(response, "headers"):
try:
@ -90,7 +90,7 @@ def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
if hasattr(response, "status_code"):
metadata["status_code"] = response.status_code
if hasattr(error, "request_id"):
metadata["request_id"] = error.request_id
metadata["request_id"] = error.request_id # type: ignore[arg-type]
# Permit response_metadata without model_name, model_provider fields
generations = [AIMessageV1(content=[], response_metadata=metadata)] # type: ignore[arg-type]
else:
@ -118,7 +118,7 @@ def _format_for_tracing(messages: Sequence[MessageV1]) -> list[MessageV1]:
for idx, block in enumerate(message.content):
# Update image content blocks to OpenAI # Chat Completions format.
if (
block["type"] == "image"
block.get("type") == "image"
and is_data_content_block(block) # type: ignore[arg-type] # permit unnecessary runtime check
and block.get("source_type") != "id"
):
@ -338,7 +338,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
@property
@override
def InputType(self) -> TypeAlias:
def InputType(self) -> Any:
"""Get the input type for this runnable."""
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
@ -458,7 +458,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
chunks: list[AIMessageChunkV1] = []
try:
for msg in self._stream(input_messages, **kwargs):
run_manager.on_llm_new_token(msg.text or "")
run_manager.on_llm_new_token(msg.text)
chunks.append(msg)
except BaseException as e:
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
@ -525,7 +525,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
chunks: list[AIMessageChunkV1] = []
try:
async for msg in self._astream(input_messages, **kwargs):
await run_manager.on_llm_new_token(msg.text or "")
await run_manager.on_llm_new_token(msg.text)
chunks.append(msg)
except BaseException as e:
await run_manager.on_llm_error(
@ -602,9 +602,12 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
# TODO: replace this with something for new messages
input_messages = _normalize_messages_v1(messages)
for msg in self._stream(input_messages, **kwargs):
run_manager.on_llm_new_token(msg.text or "")
run_manager.on_llm_new_token(msg.text)
chunks.append(msg)
yield msg
if msg.chunk_position != "last":
yield (AIMessageChunkV1([], chunk_position="last"))
except BaseException as e:
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
raise
@ -673,9 +676,11 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
input_messages,
**kwargs,
):
await run_manager.on_llm_new_token(msg.text or "")
await run_manager.on_llm_new_token(msg.text)
chunks.append(msg)
yield msg
if msg.chunk_position != "last":
yield (AIMessageChunkV1([], chunk_position="last"))
except BaseException as e:
await run_manager.on_llm_error(e, response=_generate_response_from_error(e))
raise
@ -716,22 +721,23 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
ls_params["ls_stop"] = stop
# model
if hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name
model = (
kwargs.get("model")
or getattr(self, "model", None)
or getattr(self, "model_name", None)
)
if isinstance(model, str):
ls_params["ls_model_name"] = model
# temperature
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
ls_params["ls_temperature"] = kwargs["temperature"]
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
ls_params["ls_temperature"] = self.temperature
temperature = kwargs.get("temperature") or getattr(self, "temperature", None)
if isinstance(temperature, (int, float)):
ls_params["ls_temperature"] = temperature
# max_tokens
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
ls_params["ls_max_tokens"] = self.max_tokens
max_tokens = kwargs.get("max_tokens") or getattr(self, "max_tokens", None)
if isinstance(max_tokens, int):
ls_params["ls_max_tokens"] = max_tokens
return ls_params
@ -806,7 +812,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
],
*,
tool_choice: Optional[Union[str]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessageV1]:
"""Bind tools to the model.

View File

@ -103,7 +103,7 @@ The module defines several types of content blocks, including:
""" # noqa: E501
import warnings
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Optional, TypeGuard, Union
from uuid import uuid4
from typing_extensions import NotRequired, TypedDict, get_args, get_origin
@ -844,8 +844,6 @@ ContentBlock = Union[
TextContentBlock,
ToolCall,
ToolCallChunk,
Citation,
NonStandardAnnotation,
InvalidToolCall,
ReasoningContentBlock,
NonStandardContentBlock,
@ -884,7 +882,24 @@ def _extract_typedict_type_values(union_type: Any) -> set[str]:
return result
KNOWN_BLOCK_TYPES = _extract_typedict_type_values(ContentBlock)
KNOWN_BLOCK_TYPES = {
"text",
"text-plain",
"tool_call",
"invalid_tool_call",
"tool_call_chunk",
"reasoning",
"non_standard",
"image",
"audio",
"file",
"video",
"code_interpreter_call",
"code_interpreter_output",
"code_interpreter_result",
"web_search_call",
"web_search_result",
}
def is_data_content_block(block: dict) -> bool:
@ -914,6 +929,28 @@ def is_data_content_block(block: dict) -> bool:
)
def is_tool_call_block(block: ContentBlock) -> TypeGuard[ToolCall]:
"""Type guard to check if a content block is a tool call."""
return block.get("type") == "tool_call"
def is_tool_call_chunk(block: ContentBlock) -> TypeGuard[ToolCallChunk]:
"""Type guard to check if a content block is a tool call chunk."""
return block.get("type") == "tool_call_chunk"
def is_text_block(block: ContentBlock) -> TypeGuard[TextContentBlock]:
"""Type guard to check if a content block is a text block."""
return block.get("type") == "text"
def is_invalid_tool_call_block(
block: ContentBlock,
) -> TypeGuard[InvalidToolCall]:
"""Type guard to check if a content block is an invalid tool call."""
return block.get("type") == "invalid_tool_call"
def convert_to_openai_image_block(block: dict[str, Any]) -> dict:
"""Convert image content block to format expected by OpenAI Chat Completions API."""
if "url" in block:

View File

@ -4,10 +4,9 @@ Each message has content that may be comprised of content blocks, defined under
``langchain_core.messages.content_blocks``.
"""
import json
import uuid
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, TypeGuard, Union, cast, get_args
from typing import Any, Literal, Optional, Union, cast, get_args
from pydantic import BaseModel
from typing_extensions import TypedDict
@ -20,32 +19,12 @@ from langchain_core.messages.ai import (
add_usage,
)
from langchain_core.messages.base import merge_content
from langchain_core.messages.content_blocks import create_text_block
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.json import parse_partial_json
def is_tool_call_block(block: types.ContentBlock) -> TypeGuard[types.ToolCall]:
"""Type guard to check if a content block is a tool call."""
return block.get("type") == "tool_call"
def is_text_block(block: types.ContentBlock) -> TypeGuard[types.TextContentBlock]:
"""Type guard to check if a content block is a text block."""
return block.get("type") == "text"
def is_invalid_tool_call_block(
block: types.ContentBlock,
) -> TypeGuard[types.InvalidToolCall]:
"""Type guard to check if a content block is an invalid tool call."""
return block.get("type") == "invalid_tool_call"
def _ensure_id(id_val: Optional[str]) -> str:
"""Ensure the ID is a valid string, generating a new UUID if not provided.
@ -169,7 +148,7 @@ class AIMessage:
parsed: Optional auto-parsed message contents, if applicable.
"""
if isinstance(content, str):
self.content = [create_text_block(content)]
self.content = [types.create_text_block(content)]
else:
self.content = content
@ -188,33 +167,46 @@ class AIMessage:
content_tool_calls = {
block["id"]
for block in self.content
if block.get("type") == "tool_call" and "id" in block
if types.is_tool_call_block(block) and "id" in block
}
for tool_call in tool_calls:
if "id" in tool_call and tool_call["id"] in content_tool_calls:
continue
self.content.append(tool_call)
if invalid_tool_calls:
content_tool_calls = {
block["id"]
for block in self.content
if types.is_invalid_tool_call_block(block) and "id" in block
}
for invalid_tool_call in invalid_tool_calls:
if (
"id" in invalid_tool_call
and invalid_tool_call["id"] in content_tool_calls
):
continue
self.content.append(invalid_tool_call)
self._tool_calls: list[types.ToolCall] = [
block for block in self.content if is_tool_call_block(block)
block for block in self.content if types.is_tool_call_block(block)
]
self._invalid_tool_calls = [
block for block in self.content if types.is_invalid_tool_call_block(block)
]
self.invalid_tool_calls = invalid_tool_calls or []
@property
def text(self) -> Optional[str]:
def text(self) -> str:
"""Extract all text content from the AI message as a string."""
text_blocks = [block for block in self.content if is_text_block(block)]
if text_blocks:
return "".join(block["text"] for block in text_blocks)
return None
return "".join(
block["text"] for block in self.content if types.is_text_block(block)
)
@property
def tool_calls(self) -> list[types.ToolCall]: # update once we fix branch
def tool_calls(self) -> list[types.ToolCall]:
"""Get the tool calls made by the AI."""
if self._tool_calls:
return self._tool_calls
tool_calls = [block for block in self.content if is_tool_call_block(block)]
if tool_calls:
self._tool_calls = tool_calls
if not self._tool_calls:
self._tool_calls = [
block for block in self.content if types.is_tool_call_block(block)
]
return self._tool_calls
@tool_calls.setter
@ -222,6 +214,17 @@ class AIMessage:
"""Set the tool calls for the AI message."""
self._tool_calls = value
@property
def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
"""Get the invalid tool calls made by the AI."""
if not self._invalid_tool_calls:
self._invalid_tool_calls = [
block
for block in self.content
if types.is_invalid_tool_call_block(block)
]
return self._invalid_tool_calls
@dataclass
class AIMessageChunk(AIMessage):
@ -246,17 +249,10 @@ class AIMessageChunk(AIMessage):
when deserializing messages.
"""
tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
"""List of partial tool call data.
Emitted by the model during streaming, this field contains
tool call chunks that may not yet be complete. It is used to reconstruct
tool calls from the streamed content.
"""
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
name: Optional[str] = None,
lc_version: str = "v1",
@ -264,6 +260,7 @@ class AIMessageChunk(AIMessage):
usage_metadata: Optional[UsageMetadata] = None,
tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
chunk_position: Optional[Literal["last"]] = None,
):
"""Initialize an AI message.
@ -276,6 +273,8 @@ class AIMessageChunk(AIMessage):
usage_metadata: Optional metadata about token usage.
tool_call_chunks: Optional list of partial tool call data.
parsed: Optional auto-parsed message contents, if applicable.
chunk_position: Optional position of the chunk in the stream. If "last",
tool calls will be parsed when aggregated into a stream.
"""
if isinstance(content, str):
self.content = [{"type": "text", "text": content, "index": 0}]
@ -287,115 +286,51 @@ class AIMessageChunk(AIMessage):
self.lc_version = lc_version
self.usage_metadata = usage_metadata
self.parsed = parsed
self.chunk_position = chunk_position
if response_metadata is None:
self.response_metadata = {}
else:
self.response_metadata = response_metadata
if tool_call_chunks is None:
self.tool_call_chunks: list[types.ToolCallChunk] = []
else:
self.tool_call_chunks = tool_call_chunks
if tool_call_chunks:
content_tool_call_chunks = {
block["id"]
for block in self.content
if types.is_tool_call_chunk(block) and "id" in block
}
for chunk in tool_call_chunks:
if "id" in chunk and chunk["id"] in content_tool_call_chunks:
continue
self.content.append(chunk)
self._tool_call_chunks = [
block for block in self.content if types.is_tool_call_chunk(block)
]
self._tool_calls: list[types.ToolCall] = []
self.invalid_tool_calls: list[types.InvalidToolCall] = []
self._init_tool_calls()
def _init_tool_calls(self) -> None:
"""Initialize tool calls from tool call chunks.
Args:
values: The values to validate.
Raises:
ValueError: If the tool call chunks are malformed.
"""
self._tool_calls = []
self.invalid_tool_calls = []
if not self.tool_call_chunks:
if self._tool_calls:
self.tool_call_chunks = [
create_tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
index=None,
)
for tc in self._tool_calls
]
if self.invalid_tool_calls:
tool_call_chunks = self.tool_call_chunks
tool_call_chunks.extend(
[
create_tool_call_chunk(
name=tc["name"], args=tc["args"], id=tc["id"], index=None
)
for tc in self.invalid_tool_calls
]
)
self.tool_call_chunks = tool_call_chunks
tool_calls = []
invalid_tool_calls = []
def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
invalid_tool_calls.append(
create_invalid_tool_call(
name=chunk.get("name", ""),
args=chunk.get("args", ""),
id=chunk.get("id", ""),
error=None,
)
)
for chunk in self.tool_call_chunks:
try:
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type]
if isinstance(args_, dict):
tool_calls.append(
create_tool_call(
name=chunk.get("name") or "",
args=args_,
id=chunk.get("id", ""),
)
)
else:
add_chunk_to_invalid_tool_calls(chunk)
except Exception:
add_chunk_to_invalid_tool_calls(chunk)
self._tool_calls = tool_calls
self.invalid_tool_calls = invalid_tool_calls
self._invalid_tool_calls: list[types.InvalidToolCall] = []
@property
def text(self) -> Optional[str]:
"""Extract all text content from the AI message as a string."""
text_blocks = [block for block in self.content if is_text_block(block)]
if text_blocks:
return "".join(block["text"] for block in text_blocks)
return None
@property
def reasoning(self) -> Optional[str]:
"""Extract all reasoning text from the AI message as a string."""
text_blocks = [
block
for block in self.content
if block.get("type") == "reasoning" and "reasoning" in block
]
if text_blocks:
return "".join(
cast("types.ReasoningContentBlock", block).get("reasoning", "")
for block in text_blocks
)
return None
def tool_call_chunks(self) -> list[types.ToolCallChunk]:
"""Get the tool calls made by the AI."""
if not self._tool_call_chunks:
self._tool_call_chunks = [
block for block in self.content if types.is_tool_call_chunk(block)
]
return cast("list[types.ToolCallChunk]", self._tool_call_chunks)
@property
def tool_calls(self) -> list[types.ToolCall]:
"""Get the tool calls made by the AI."""
if self._tool_calls:
return self._tool_calls
tool_calls = [block for block in self.content if is_tool_call_block(block)]
if tool_calls:
self._tool_calls = tool_calls
if not self._tool_calls:
parsed_content = _init_tool_calls(self.content)
self._tool_calls = [
block for block in parsed_content if types.is_tool_call_block(block)
]
self._invalid_tool_calls = [
block
for block in parsed_content
if types.is_invalid_tool_call_block(block)
]
return self._tool_calls
@tool_calls.setter
@ -403,6 +338,21 @@ class AIMessageChunk(AIMessage):
"""Set the tool calls for the AI message."""
self._tool_calls = value
@property
def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
"""Get the invalid tool calls made by the AI."""
if not self._invalid_tool_calls:
parsed_content = _init_tool_calls(self.content)
self._tool_calls = [
block for block in parsed_content if types.is_tool_call_block(block)
]
self._invalid_tool_calls = [
block
for block in parsed_content
if types.is_invalid_tool_call_block(block)
]
return self._invalid_tool_calls
def __add__(self, other: Any) -> "AIMessageChunk":
"""Add ``AIMessageChunk`` to this one."""
if isinstance(other, AIMessageChunk):
@ -417,49 +367,76 @@ class AIMessageChunk(AIMessage):
def to_message(self) -> "AIMessage":
"""Convert this ``AIMessageChunk`` to an AIMessage."""
return AIMessage(
content=self.content,
content=_init_tool_calls(self.content),
id=self.id,
name=self.name,
lc_version=self.lc_version,
response_metadata=self.response_metadata,
usage_metadata=self.usage_metadata,
tool_calls=self.tool_calls,
invalid_tool_calls=self.invalid_tool_calls,
parsed=self.parsed,
)
def _init_tool_calls(content: list[types.ContentBlock]) -> list[types.ContentBlock]:
"""Parse tool call chunks in content into tool calls."""
new_content = []
for block in content:
if not types.is_tool_call_chunk(block):
new_content.append(block)
continue
try:
args_ = (
parse_partial_json(cast("str", block.get("args") or ""))
if block.get("args")
else {}
)
if isinstance(args_, dict):
new_content.append(
create_tool_call(
name=cast("str", block.get("name") or ""),
args=args_,
id=cast("str", block.get("id", "")),
)
)
else:
new_content.append(
create_invalid_tool_call(
name=cast("str", block.get("name", "")),
args=cast("str", block.get("args", "")),
id=cast("str", block.get("id", "")),
error=None,
)
)
except Exception:
new_content.append(
create_invalid_tool_call(
name=cast("str", block.get("name", "")),
args=cast("str", block.get("args", "")),
id=cast("str", block.get("id", "")),
error=None,
)
)
return new_content
def add_ai_message_chunks(
left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk:
"""Add multiple ``AIMessageChunks`` together."""
if not others:
return left
content = merge_content(
cast("list[str | dict[Any, Any]]", left.content),
*(cast("list[str | dict[Any, Any]]", o.content) for o in others),
content = cast(
"list[types.ContentBlock]",
merge_content(
cast("list[str | dict[Any, Any]]", left.content),
*(cast("list[str | dict[Any, Any]]", o.content) for o in others),
),
)
response_metadata = merge_dicts(
cast("dict", left.response_metadata),
*(cast("dict", o.response_metadata) for o in others),
)
# Merge tool call chunks
if raw_tool_calls := merge_lists(
left.tool_call_chunks, *(o.tool_call_chunks for o in others)
):
tool_call_chunks = [
create_tool_call_chunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),
id=rtc.get("id"),
)
for rtc in raw_tool_calls
]
else:
tool_call_chunks = []
# Token usage
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
@ -501,13 +478,19 @@ def add_ai_message_chunks(
chunk_id = id_
break
chunk_position: Optional[Literal["last"]] = (
"last" if any(x.chunk_position == "last" for x in [left, *others]) else None
)
if chunk_position == "last":
content = _init_tool_calls(content)
return left.__class__(
content=cast("list[types.ContentBlock]", content),
tool_call_chunks=tool_call_chunks,
content=content,
response_metadata=cast("ResponseMetadata", response_metadata),
usage_metadata=usage_metadata,
parsed=parsed,
id=chunk_id,
chunk_position=chunk_position,
)
@ -579,9 +562,7 @@ class HumanMessage:
Concatenated string of all text blocks in the message.
"""
return "".join(
cast("types.TextContentBlock", block)["text"]
for block in self.content
if block.get("type") == "text"
block["text"] for block in self.content if types.is_text_block(block)
)
@ -660,9 +641,7 @@ class SystemMessage:
def text(self) -> str:
"""Extract all text content from the system message."""
return "".join(
cast("types.TextContentBlock", block)["text"]
for block in self.content
if block.get("type") == "text"
block["text"] for block in self.content if types.is_text_block(block)
)
@ -754,9 +733,7 @@ class ToolMessage:
def text(self) -> str:
"""Extract all text content from the tool message."""
return "".join(
cast("types.TextContentBlock", block)["text"]
for block in self.content
if block.get("type") == "text"
block["text"] for block in self.content if types.is_text_block(block)
)
def __post_init__(self) -> None:

View File

@ -283,7 +283,7 @@ class BaseOutputParser(
Structured output.
"""
if isinstance(result, AIMessage):
return self.parse(result.text or "")
return self.parse(result.text)
return self.parse(result[0].text)
@abstractmethod

View File

@ -73,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
text = result.text or "" if isinstance(result, AIMessage) else result[0].text
text = result.text if isinstance(result, AIMessage) else result[0].text
text = text.strip()
if partial:
try:

View File

@ -83,7 +83,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
continue
buffer += chunk_content
elif isinstance(chunk, AIMessage):
buffer += chunk.text or ""
buffer += chunk.text
else:
# add current chunk to buffer
buffer += chunk
@ -119,7 +119,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
continue
buffer += chunk_content
elif isinstance(chunk, AIMessage):
buffer += chunk.text or ""
buffer += chunk.text
else:
# add current chunk to buffer
buffer += chunk

View File

@ -14,7 +14,10 @@ from langchain_core.language_models import (
ParrotFakeChatModel,
)
from langchain_core.language_models._utils import _normalize_messages
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
from langchain_core.language_models.fake_chat_models import (
FakeListChatModelError,
GenericFakeChatModelV1,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@ -22,6 +25,7 @@ from langchain_core.messages import (
HumanMessage,
SystemMessage,
)
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.tracers import LogStreamCallbackHandler
@ -654,3 +658,93 @@ def test_normalize_messages_edge_cases() -> None:
)
]
assert messages == _normalize_messages(messages)
def test_streaming_v1() -> None:
chunks = [
AIMessageChunkV1(
[
{
"type": "reasoning",
"reasoning": "Let's call a tool.",
"index": 0,
}
]
),
AIMessageChunkV1(
[],
tool_call_chunks=[
{
"type": "tool_call_chunk",
"args": "",
"name": "tool_name",
"id": "call_123",
"index": 1,
},
],
),
AIMessageChunkV1(
[],
tool_call_chunks=[
{
"type": "tool_call_chunk",
"args": '{"a',
"name": "",
"id": "",
"index": 1,
},
],
),
AIMessageChunkV1(
[],
tool_call_chunks=[
{
"type": "tool_call_chunk",
"args": '": 1}',
"name": "",
"id": "",
"index": 1,
},
],
),
]
full: Optional[AIMessageChunkV1] = None
for chunk in chunks:
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
assert full.content == [
{
"type": "reasoning",
"reasoning": "Let's call a tool.",
"index": 0,
},
{
"type": "tool_call_chunk",
"args": '{"a": 1}',
"name": "tool_name",
"id": "call_123",
"index": 1,
},
]
llm = GenericFakeChatModelV1(message_chunks=chunks)
full = None
for chunk in llm.stream("anything"):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
assert full.content == [
{
"type": "reasoning",
"reasoning": "Let's call a tool.",
"index": 0,
},
{
"type": "tool_call",
"args": {"a": 1},
"name": "tool_name",
"id": "call_123",
},
]

View File

@ -37,7 +37,7 @@ def test_base_generation_parser() -> None:
that support streaming
"""
if isinstance(result, AIMessageV1):
content = result.text or ""
content = result.text
else:
if len(result) != 1:
msg = (
@ -89,7 +89,7 @@ def test_base_transform_output_parser() -> None:
that support streaming
"""
if isinstance(result, AIMessageV1):
content = result.text or ""
content = result.text
else:
if len(result) != 1:
msg = (
@ -116,4 +116,4 @@ def test_base_transform_output_parser() -> None:
model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
chain_v1 = model_v1 | StrInvertCase()
chunks = list(chain_v1.stream(""))
assert chunks == ["HELLO", " ", "WORLD"]
assert chunks == ["HELLO", " ", "WORLD", ""]

View File

@ -2543,12 +2543,6 @@
dict({
'$ref': '#/$defs/ToolCallChunk',
}),
dict({
'$ref': '#/$defs/Citation',
}),
dict({
'$ref': '#/$defs/NonStandardAnnotation',
}),
dict({
'$ref': '#/$defs/InvalidToolCall',
}),
@ -2675,12 +2669,6 @@
dict({
'$ref': '#/$defs/ToolCallChunk',
}),
dict({
'$ref': '#/$defs/Citation',
}),
dict({
'$ref': '#/$defs/NonStandardAnnotation',
}),
dict({
'$ref': '#/$defs/InvalidToolCall',
}),
@ -2807,12 +2795,6 @@
dict({
'$ref': '#/$defs/ToolCallChunk',
}),
dict({
'$ref': '#/$defs/Citation',
}),
dict({
'$ref': '#/$defs/NonStandardAnnotation',
}),
dict({
'$ref': '#/$defs/InvalidToolCall',
}),
@ -2901,12 +2883,6 @@
dict({
'$ref': '#/$defs/ToolCallChunk',
}),
dict({
'$ref': '#/$defs/Citation',
}),
dict({
'$ref': '#/$defs/NonStandardAnnotation',
}),
dict({
'$ref': '#/$defs/InvalidToolCall',
}),
@ -3018,12 +2994,6 @@
dict({
'$ref': '#/$defs/ToolCallChunk',
}),
dict({
'$ref': '#/$defs/Citation',
}),
dict({
'$ref': '#/$defs/NonStandardAnnotation',
}),
dict({
'$ref': '#/$defs/InvalidToolCall',
}),

View File

@ -11535,6 +11535,9 @@
dict({
'$ref': '#/definitions/InvalidToolCall',
}),
dict({
'$ref': '#/definitions/ToolCallChunk',
}),
dict({
'$ref': '#/definitions/ReasoningContentBlock',
}),
@ -11657,6 +11660,9 @@
dict({
'$ref': '#/definitions/InvalidToolCall',
}),
dict({
'$ref': '#/definitions/ToolCallChunk',
}),
dict({
'$ref': '#/definitions/ReasoningContentBlock',
}),
@ -11779,6 +11785,9 @@
dict({
'$ref': '#/definitions/InvalidToolCall',
}),
dict({
'$ref': '#/definitions/ToolCallChunk',
}),
dict({
'$ref': '#/definitions/ReasoningContentBlock',
}),
@ -11863,6 +11872,9 @@
dict({
'$ref': '#/definitions/InvalidToolCall',
}),
dict({
'$ref': '#/definitions/ToolCallChunk',
}),
dict({
'$ref': '#/definitions/ReasoningContentBlock',
}),
@ -11970,6 +11982,9 @@
dict({
'$ref': '#/definitions/InvalidToolCall',
}),
dict({
'$ref': '#/definitions/ToolCallChunk',
}),
dict({
'$ref': '#/definitions/ReasoningContentBlock',
}),

View File

@ -3,6 +3,7 @@ import uuid
from typing import Optional, Union
import pytest
from typing_extensions import get_args
from langchain_core.documents import Document
from langchain_core.load import dumpd, load
@ -30,7 +31,7 @@ from langchain_core.messages import (
messages_from_dict,
messages_to_dict,
)
from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES, ContentBlock
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
@ -1363,23 +1364,19 @@ def test_convert_to_openai_image_block() -> None:
def test_known_block_types() -> None:
assert {
"audio",
"citation",
"code_interpreter_call",
"code_interpreter_output",
"code_interpreter_result",
"file",
"image",
"invalid_tool_call",
"non_standard",
"non_standard_annotation",
"reasoning",
"text",
"text-plain",
"tool_call",
"tool_call_chunk",
"video",
"web_search_call",
"web_search_result",
} == KNOWN_BLOCK_TYPES
expected = {
bt
for bt in get_args(ContentBlock)
for bt in get_args(bt.__annotations__["type"])
}
# Normalize any Literal[...] types in block types to their string values.
# This ensures all entries are plain strings, not Literal objects.
expected = {
t
if isinstance(t, str)
else t.__args__[0]
if hasattr(t, "__args__") and len(t.__args__) == 1
else t
for t in expected
}
assert expected == KNOWN_BLOCK_TYPES