Merge branch 'standard_outputs_copy' into mdrxy/ollama_v1

This commit is contained in:
Mason Daugherty 2025-07-30 17:41:16 -04:00 committed by GitHub
commit 588fe46601
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 4011 additions and 1129 deletions

View File

@ -3,7 +3,7 @@
import asyncio
import re
import time
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncIterator, Iterable, Iterator
from typing import Any, Optional, Union, cast
from typing_extensions import override
@ -13,7 +13,11 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.language_models.v1.chat_models import BaseChatModelV1
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig
@ -367,3 +371,69 @@ class ParrotFakeChatModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "parrot-fake-chat-model"
class GenericFakeChatModelV1(BaseChatModelV1):
    """Fake v1 chat model that replays canned responses, for interface tests.

    ``_invoke`` pulls the next item from ``messages`` and ``_stream`` walks
    ``message_chunks``. A plain ``str`` item is wrapped in the matching v1
    message type; an item that is already a message is passed through
    unchanged. The input messages themselves are never inspected.
    """

    # Source of responses for ``_invoke``; one item is consumed per call.
    messages: Optional[Iterator[Union[AIMessageV1, str]]] = None
    # Source of chunks for ``_stream``; iterated from start on each stream.
    message_chunks: Optional[Iterable[Union[AIMessageChunkV1, str]]] = None

    @override
    def _invoke(
        self,
        messages: list[MessageV1],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AIMessageV1:
        """Return the next canned response.

        Args:
            messages: Input messages (ignored).
            run_manager: Callback manager (ignored).
            **kwargs: Extra arguments (ignored).

        Returns:
            The next item of ``self.messages``, wrapped in ``AIMessageV1``
            when it is a string.

        Raises:
            ValueError: If ``self.messages`` was never configured.
        """
        responses = self.messages
        if responses is None:
            error_msg = "Messages iterator is not set."
            raise ValueError(error_msg)
        response = next(responses)
        if isinstance(response, str):
            return AIMessageV1(content=response)
        return response

    @override
    def _stream(
        self,
        messages: list[MessageV1],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[AIMessageChunkV1]:
        """Yield every canned chunk in order.

        Args:
            messages: Input messages (ignored).
            run_manager: Callback manager (ignored).
            **kwargs: Extra arguments (ignored).

        Yields:
            Each item of ``self.message_chunks``, wrapped in
            ``AIMessageChunkV1`` when it is a string.

        Raises:
            ValueError: If ``self.message_chunks`` was never configured.
                Because this is a generator, the error surfaces on the first
                iteration rather than at call time.
        """
        pending = self.message_chunks
        if pending is None:
            error_msg = "Message chunks iterator is not set."
            raise ValueError(error_msg)
        for item in pending:
            yield AIMessageChunkV1(item) if isinstance(item, str) else item

    @property
    def _llm_type(self) -> str:
        """Identifier string reported for this model type."""
        return "generic-fake-chat-model"
class ParrotFakeChatModelV1(BaseChatModelV1):
    """Fake v1 chat model that echoes back the last message it is given.

    * Chat model should be usable in both sync and async tests
    """

    @override
    def _invoke(
        self,
        messages: list[MessageV1],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AIMessageV1:
        """Echo the final input message.

        Args:
            messages: Conversation so far; only the last entry is used.
            stop: Stop sequences (ignored).
            run_manager: Callback manager (ignored).
            **kwargs: Extra arguments (ignored).

        Returns:
            The last message itself when it is already an ``AIMessageV1``,
            otherwise a new ``AIMessageV1`` carrying the same content.
        """
        last_message = messages[-1]
        if not isinstance(last_message, AIMessageV1):
            return AIMessageV1(content=last_message.content)
        return last_message

    @property
    def _llm_type(self) -> str:
        """Identifier string reported for this model type."""
        return "parrot-fake-chat-model"

View File

@ -3,9 +3,10 @@
import json
import uuid
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args
from typing import Any, Literal, Optional, Union, cast, get_args
from pydantic import BaseModel
from typing_extensions import TypedDict
import langchain_core.messages.content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
@ -329,7 +330,11 @@ class AIMessageChunk(AIMessage):
@property
def reasoning(self) -> Optional[str]:
"""Extract all reasoning text from the AI message as a string."""
text_blocks = [block for block in self.content if block["type"] == "reasoning"]
text_blocks = [
block
for block in self.content
if block["type"] == "reasoning" and "reasoning" in block
]
if text_blocks:
return "".join(block["reasoning"] for block in text_blocks)
return None

View File

@ -11,12 +11,14 @@ from typing import (
Optional,
TypeVar,
Union,
cast,
)
from typing_extensions import override
from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.messages.v1 import AIMessage, MessageV1, MessageV1Types
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import run_in_executor
@ -25,14 +27,16 @@ if TYPE_CHECKING:
from langchain_core.prompt_values import PromptValue
T = TypeVar("T")
OutputParserLike = Runnable[LanguageModelOutput, T]
OutputParserLike = Runnable[Union[LanguageModelOutput, AIMessage], T]
class BaseLLMOutputParser(ABC, Generic[T]):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
@ -46,7 +50,7 @@ class BaseLLMOutputParser(ABC, Generic[T]):
"""
async def aparse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Async parse a list of candidate model Generations into a specific format.
@ -71,7 +75,7 @@ class BaseGenerationOutputParser(
@override
def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage]
return Union[str, AnyMessage, MessageV1]
@property
@override
@ -84,7 +88,7 @@ class BaseGenerationOutputParser(
@override
def invoke(
self,
input: Union[str, BaseMessage],
input: Union[str, BaseMessage, MessageV1],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> T:
@ -97,9 +101,16 @@ class BaseGenerationOutputParser(
config,
run_type="parser",
)
if isinstance(input, MessageV1Types):
return self._call_with_config(
lambda inner_input: self.parse_result(inner_input),
input,
config,
run_type="parser",
)
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
cast("str", input),
config,
run_type="parser",
)
@ -120,6 +131,13 @@ class BaseGenerationOutputParser(
config,
run_type="parser",
)
if isinstance(input, MessageV1Types):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(inner_input),
input,
config,
run_type="parser",
)
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
@ -129,7 +147,7 @@ class BaseGenerationOutputParser(
class BaseOutputParser(
BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
BaseLLMOutputParser, RunnableSerializable[Union[LanguageModelOutput, AIMessage], T]
):
"""Base class to parse the output of an LLM call.
@ -162,7 +180,7 @@ class BaseOutputParser(
@override
def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage]
return Union[str, AnyMessage, MessageV1]
@property
@override
@ -189,7 +207,7 @@ class BaseOutputParser(
@override
def invoke(
self,
input: Union[str, BaseMessage],
input: Union[str, BaseMessage, MessageV1],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> T:
@ -202,9 +220,16 @@ class BaseOutputParser(
config,
run_type="parser",
)
if isinstance(input, MessageV1Types):
return self._call_with_config(
lambda inner_input: self.parse_result(inner_input),
input,
config,
run_type="parser",
)
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
cast("str", input),
config,
run_type="parser",
)
@ -212,7 +237,7 @@ class BaseOutputParser(
@override
async def ainvoke(
self,
input: Union[str, BaseMessage],
input: Union[str, BaseMessage, MessageV1],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> T:
@ -225,15 +250,24 @@ class BaseOutputParser(
config,
run_type="parser",
)
if isinstance(input, MessageV1Types):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(inner_input),
input,
config,
run_type="parser",
)
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
cast("str", input),
config,
run_type="parser",
)
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
@ -248,6 +282,8 @@ class BaseOutputParser(
Returns:
Structured output.
"""
if isinstance(result, AIMessage):
return self.parse(result.text or "")
return self.parse(result[0].text)
@abstractmethod
@ -262,7 +298,7 @@ class BaseOutputParser(
"""
async def aparse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Async parse a list of candidate model Generations into a specific format.

View File

@ -13,6 +13,7 @@ from pydantic.v1 import BaseModel
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import Generation
@ -53,7 +54,9 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
return pydantic_object.schema()
return None
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@ -70,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
text = result[0].text
text = result.text or "" if isinstance(result, AIMessage) else result[0].text
text = text.strip()
if partial:
try:

View File

@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, TypeVar, Union
from typing_extensions import override
from langchain_core.messages import BaseMessage
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
if TYPE_CHECKING:
@ -71,7 +72,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
@override
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
) -> Iterator[list[str]]:
buffer = ""
for chunk in input:
@ -81,6 +82,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
if not isinstance(chunk_content, str):
continue
buffer += chunk_content
elif isinstance(chunk, AIMessage):
buffer += chunk.text or ""
else:
# add current chunk to buffer
buffer += chunk
@ -105,7 +108,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
@override
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
) -> AsyncIterator[list[str]]:
buffer = ""
async for chunk in input:
@ -115,6 +118,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
if not isinstance(chunk_content, str):
continue
buffer += chunk_content
elif isinstance(chunk, AIMessage):
buffer += chunk.text or ""
else:
# add current chunk to buffer
buffer += chunk

View File

@ -11,6 +11,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers import (
BaseCumulativeTransformOutputParser,
BaseGenerationOutputParser,
@ -26,7 +27,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Whether to only return the arguments to the function call."""
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@ -39,6 +42,12 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
if isinstance(result, AIMessage):
msg = (
"This output parser does not support v1 AIMessages. Use "
"JsonOutputToolsParser instead."
)
raise TypeError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
@ -77,7 +86,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@ -90,6 +101,12 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
if isinstance(result, AIMessage):
msg = (
"This output parser does not support v1 AIMessages. Use "
"JsonOutputToolsParser instead."
)
raise TypeError(msg)
if len(result) != 1:
msg = f"Expected exactly one result, but got {len(result)}"
raise OutputParserException(msg)
@ -160,7 +177,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@ -254,7 +273,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
return values
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@ -294,7 +315,9 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
"""The name of the attribute to return."""
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:

View File

@ -4,7 +4,7 @@ import copy
import json
import logging
from json import JSONDecodeError
from typing import Annotated, Any, Optional
from typing import Annotated, Any, Optional, Union
from pydantic import SkipValidation, ValidationError
@ -12,6 +12,7 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils.json import parse_partial_json
@ -156,7 +157,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
If no tool calls are found, None will be returned.
"""
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
@ -173,31 +176,45 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
tool_calls = [dict(tc) for tc in message.tool_calls]
if isinstance(result, list):
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = (
"This output parser can only be used with a chat generation or "
"v1 AIMessage."
)
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
tool_calls = [dict(tc) for tc in message.tool_calls]
for tool_call in tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(
message.additional_kwargs["tool_calls"]
)
except KeyError:
return []
tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
elif result.tool_calls:
# v1 message
tool_calls = [dict(tc) for tc in result.tool_calls]
for tool_call in tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
except KeyError:
return []
tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
return []
# for backwards compatibility
for tc in tool_calls:
tc["type"] = tc.pop("name")
if self.first_tool_only:
return tool_calls[0] if tool_calls else None
return tool_calls
@ -220,7 +237,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
key_name: str
"""The type of tools to return."""
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
@ -234,32 +253,47 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
Returns:
The parsed tool calls.
"""
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
if isinstance(result, list):
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
for tool_call in parsed_tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(
message.additional_kwargs["tool_calls"]
)
except KeyError:
if self.first_tool_only:
return None
return []
parsed_tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
elif result.tool_calls:
# v1 message
parsed_tool_calls = [dict(tc) for tc in result.tool_calls]
for tool_call in parsed_tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
except KeyError:
if self.first_tool_only:
return None
return []
parsed_tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
if self.first_tool_only:
return None
return []
# For backwards compatibility
for tc in parsed_tool_calls:
tc["type"] = tc.pop("name")
if self.first_tool_only:
parsed_result = list(
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
@ -299,7 +333,9 @@ class PydanticToolsParser(JsonOutputToolsParser):
# TODO: Support more granular streaming of objects. Currently only streams once all
# Pydantic object fields are present.
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a list of Pydantic objects.
Args:
@ -337,12 +373,19 @@ class PydanticToolsParser(JsonOutputToolsParser):
except (ValidationError, ValueError):
if partial:
continue
has_max_tokens_stop_reason = any(
generation.message.response_metadata.get("stop_reason")
== "max_tokens"
for generation in result
if isinstance(generation, ChatGeneration)
)
has_max_tokens_stop_reason = False
if isinstance(result, list):
has_max_tokens_stop_reason = any(
generation.message.response_metadata.get("stop_reason")
== "max_tokens"
for generation in result
if isinstance(generation, ChatGeneration)
)
else:
# v1 message
has_max_tokens_stop_reason = (
result.response_metadata.get("stop_reason") == "max_tokens"
)
if has_max_tokens_stop_reason:
logger.exception(_MAX_TOKENS_ERROR)
raise

View File

@ -1,13 +1,14 @@
"""Output parsers using Pydantic."""
import json
from typing import Annotated, Generic, Optional
from typing import Annotated, Generic, Optional, Union
import pydantic
from pydantic import SkipValidation
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.utils.pydantic import (
@ -43,7 +44,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return OutputParserException(msg, llm_output=json_string)
def parse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Optional[TBaseModel]:
"""Parse the result of an LLM call to a pydantic object.

View File

@ -12,6 +12,7 @@ from typing import (
from typing_extensions import override
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.messages.v1 import AIMessage, AIMessageChunk
from langchain_core.output_parsers.base import BaseOutputParser, T
from langchain_core.outputs import (
ChatGeneration,
@ -32,23 +33,27 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
def _transform(
self,
input: Iterator[Union[str, BaseMessage]],
input: Iterator[Union[str, BaseMessage, AIMessage]],
) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
elif isinstance(chunk, AIMessage):
yield self.parse_result(chunk)
else:
yield self.parse_result([Generation(text=chunk)])
async def _atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
yield await run_in_executor(
None, self.parse_result, [ChatGeneration(message=chunk)]
)
elif isinstance(chunk, AIMessage):
yield await run_in_executor(None, self.parse_result, chunk)
else:
yield await run_in_executor(
None, self.parse_result, [Generation(text=chunk)]
@ -57,7 +62,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
@override
def transform(
self,
input: Iterator[Union[str, BaseMessage]],
input: Iterator[Union[str, BaseMessage, AIMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[T]:
@ -78,7 +83,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
@override
async def atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[T]:
@ -125,23 +130,42 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
raise NotImplementedError
@override
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
def _transform(
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
) -> Iterator[Any]:
prev_parsed = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
None
)
for chunk in input:
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
if isinstance(chunk, BaseMessageChunk):
chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.model_dump())
)
elif isinstance(chunk, AIMessageChunk):
chunk_gen = chunk
elif isinstance(chunk, AIMessage):
chunk_gen = AIMessageChunk(
content=chunk.content,
id=chunk.id,
name=chunk.name,
lc_version=chunk.lc_version,
response_metadata=chunk.response_metadata,
usage_metadata=chunk.usage_metadata,
parsed=chunk.parsed,
)
else:
chunk_gen = GenerationChunk(text=chunk)
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
parsed = self.parse_result([acc_gen], partial=True)
if isinstance(acc_gen, AIMessageChunk):
parsed = self.parse_result(acc_gen, partial=True)
else:
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
@ -151,24 +175,41 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
@override
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
) -> AsyncIterator[T]:
prev_parsed = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
None
)
async for chunk in input:
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
if isinstance(chunk, BaseMessageChunk):
chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.model_dump())
)
elif isinstance(chunk, AIMessageChunk):
chunk_gen = chunk
elif isinstance(chunk, AIMessage):
chunk_gen = AIMessageChunk(
content=chunk.content,
id=chunk.id,
name=chunk.name,
lc_version=chunk.lc_version,
response_metadata=chunk.response_metadata,
usage_metadata=chunk.usage_metadata,
parsed=chunk.parsed,
)
else:
chunk_gen = GenerationChunk(text=chunk)
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
parsed = await self.aparse_result([acc_gen], partial=True)
if isinstance(acc_gen, AIMessageChunk):
parsed = await self.aparse_result(acc_gen, partial=True)
else:
parsed = await self.aparse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield await run_in_executor(None, self._diff, prev_parsed, parsed)

View File

@ -12,6 +12,8 @@ from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict
@ -240,21 +242,28 @@ class XMLOutputParser(BaseTransformOutputParser):
@override
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
) -> Iterator[AddableDict]:
streaming_parser = _StreamingParser(self.parser)
for chunk in input:
yield from streaming_parser.parse(chunk)
if isinstance(chunk, AIMessage):
yield from streaming_parser.parse(convert_from_v1_message(chunk))
else:
yield from streaming_parser.parse(chunk)
streaming_parser.close()
@override
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
) -> AsyncIterator[AddableDict]:
streaming_parser = _StreamingParser(self.parser)
async for chunk in input:
for output in streaming_parser.parse(chunk):
yield output
if isinstance(chunk, AIMessage):
for output in streaming_parser.parse(convert_from_v1_message(chunk)):
yield output
else:
for output in streaming_parser.parse(chunk):
yield output
streaming_parser.close()
def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:

View File

@ -1,10 +1,14 @@
"""Module to test base parser implementations."""
from typing import Union
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModelV1
from langchain_core.messages import AIMessage
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.output_parsers import (
BaseGenerationOutputParser,
BaseTransformOutputParser,
@ -20,7 +24,7 @@ def test_base_generation_parser() -> None:
@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> str:
"""Parse a list of model Generations into a specific format.
@ -32,16 +36,22 @@ def test_base_generation_parser() -> None:
partial: Whether to allow partial results. This is used for parsers
that support streaming
"""
if len(result) != 1:
msg = "This output parser can only be used with a single generation."
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
if isinstance(result, AIMessageV1):
content = result.text or ""
else:
if len(result) != 1:
msg = (
"This output parser can only be used with a single generation."
)
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
assert isinstance(generation.message.content, str)
content = generation.message.content
content = generation.message.content
assert isinstance(content, str)
return content.swapcase()
@ -49,6 +59,10 @@ def test_base_generation_parser() -> None:
chain = model | StrInvertCase()
assert chain.invoke("") == "HeLLO"
model_v1 = GenericFakeChatModelV1(messages=iter([AIMessageV1("hEllo")]))
chain_v1 = model_v1 | StrInvertCase()
assert chain_v1.invoke("") == "HeLLO"
def test_base_transform_output_parser() -> None:
"""Test base transform output parser."""
@ -62,7 +76,7 @@ def test_base_transform_output_parser() -> None:
@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> str:
"""Parse a list of model Generations into a specific format.
@ -74,15 +88,22 @@ def test_base_transform_output_parser() -> None:
partial: Whether to allow partial results. This is used for parsers
that support streaming
"""
if len(result) != 1:
msg = "This output parser can only be used with a single generation."
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
content = generation.message.content
if isinstance(result, AIMessageV1):
content = result.text or ""
else:
if len(result) != 1:
msg = (
"This output parser can only be used with a single generation."
)
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
assert isinstance(generation.message.content, str)
content = generation.message.content
assert isinstance(content, str)
return content.swapcase()
@ -91,3 +112,8 @@ def test_base_transform_output_parser() -> None:
# inputs to models are ignored, response is hard-coded in model definition
chunks = list(chain.stream(""))
assert chunks == ["HELLO", " ", "WORLD"]
model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
chain_v1 = model_v1 | StrInvertCase()
chunks = list(chain_v1.stream(""))
assert chunks == ["HELLO", " ", "WORLD"]

View File

@ -10,6 +10,8 @@ from langchain_core.messages import (
BaseMessage,
ToolCallChunk,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
JsonOutputToolsParser,
@ -331,6 +333,14 @@ for message in STREAMED_MESSAGES:
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)
# V1 counterparts of STREAMED_MESSAGES_WITH_TOOL_CALLS: every chunk has empty
# content and reuses the tool_call_chunks of the chunk it was derived from.
STREAMED_MESSAGES_V1 = [
    AIMessageChunkV1(content=[], tool_call_chunks=source_chunk.tool_call_chunks)
    for source_chunk in STREAMED_MESSAGES_WITH_TOOL_CALLS
]
EXPECTED_STREAMED_JSON = [
{},
{"names": ["suz"]},
@ -398,6 +408,19 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
assert actual == expected
def test_partial_json_output_parser_v1() -> None:
    """Stream V1 tool-call chunks through ``JsonOutputToolsParser``.

    The streamed output must be an empty list followed by one single-element
    list per entry of ``EXPECTED_STREAMED_JSON``, each wrapping the partial
    arguments in a ``NameCollector`` tool-call dict.
    """

    def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
        for message_chunk in STREAMED_MESSAGES_V1:
            yield message_chunk

    parser = JsonOutputToolsParser()
    streamed = list((input_iter | parser).stream(None))

    tool_call_snapshots = [
        [{"type": "NameCollector", "args": partial_args}]
        for partial_args in EXPECTED_STREAMED_JSON
    ]
    assert streamed == [[], *tool_call_snapshots]
@pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None:
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@ -410,6 +433,20 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None
assert actual == expected
async def test_partial_json_output_parser_async_v1() -> None:
    """Async variant: stream V1 chunks through ``JsonOutputToolsParser``.

    The collected ``astream`` output must be an empty list followed by one
    single-element ``NameCollector`` tool-call list per entry of
    ``EXPECTED_STREAMED_JSON``.
    """

    async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
        for message_chunk in STREAMED_MESSAGES_V1:
            yield message_chunk

    parser = JsonOutputToolsParser()
    pipeline = input_iter | parser
    streamed = [partial async for partial in pipeline.astream(None)]

    tool_call_snapshots = [
        [{"type": "NameCollector", "args": partial_args}]
        for partial_args in EXPECTED_STREAMED_JSON
    ]
    assert streamed == [[], *tool_call_snapshots]
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls)
@ -429,6 +466,26 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
assert actual == expected
def test_partial_json_output_parser_return_id_v1() -> None:
    """Stream V1 chunks through ``JsonOutputToolsParser(return_id=True)``.

    Every non-empty partial result must carry the tool-call id alongside the
    ``NameCollector`` type and the partial arguments.
    """

    def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
        for message_chunk in STREAMED_MESSAGES_V1:
            yield message_chunk

    parser = JsonOutputToolsParser(return_id=True)
    streamed = list((input_iter | parser).stream(None))

    tool_call_id = "call_OwL7f5PEPJTYzw9sQlNJtCZl"
    tool_call_snapshots = [
        [{"type": "NameCollector", "args": partial_args, "id": tool_call_id}]
        for partial_args in EXPECTED_STREAMED_JSON
    ]
    assert streamed == [[], *tool_call_snapshots]
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls)
@ -439,6 +496,17 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
assert actual == expected
def test_partial_json_output_key_parser_v1() -> None:
    """Stream V1 chunks through ``JsonOutputKeyToolsParser`` keyed on a tool.

    With ``key_name="NameCollector"`` the streamed output must be an empty
    list followed by each entry of ``EXPECTED_STREAMED_JSON`` wrapped in a
    single-element list.
    """

    def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
        for message_chunk in STREAMED_MESSAGES_V1:
            yield message_chunk

    parser = JsonOutputKeyToolsParser(key_name="NameCollector")
    streamed = list((input_iter | parser).stream(None))

    args_snapshots = [[partial_args] for partial_args in EXPECTED_STREAMED_JSON]
    assert streamed == [[], *args_snapshots]
@pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> None:
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@ -450,6 +518,18 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) ->
assert actual == expected
async def test_partial_json_output_parser_key_async_v1() -> None:
    """Async variant of the keyed V1 streaming test.

    ``astream`` output of ``JsonOutputKeyToolsParser(key_name="NameCollector")``
    must be an empty list followed by each entry of ``EXPECTED_STREAMED_JSON``
    wrapped in a single-element list.
    """

    async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
        for message_chunk in STREAMED_MESSAGES_V1:
            yield message_chunk

    parser = JsonOutputKeyToolsParser(key_name="NameCollector")
    pipeline = input_iter | parser
    streamed = [partial async for partial in pipeline.astream(None)]

    args_snapshots = [[partial_args] for partial_args in EXPECTED_STREAMED_JSON]
    assert streamed == [[], *args_snapshots]
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls)
@ -461,6 +541,17 @@ def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> N
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
def test_partial_json_output_key_parser_first_only_v1() -> None:
    """Stream V1 chunks through a keyed parser with ``first_tool_only=True``.

    The streamed values must equal ``EXPECTED_STREAMED_JSON`` exactly, i.e.
    the bare partial arguments without any list wrapping.
    """

    def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
        for message_chunk in STREAMED_MESSAGES_V1:
            yield message_chunk

    parser = JsonOutputKeyToolsParser(key_name="NameCollector", first_tool_only=True)
    pipeline = input_iter | parser
    streamed = list(pipeline.stream(None))
    assert streamed == EXPECTED_STREAMED_JSON
@pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_key_async_first_only(
*,
@ -475,6 +566,18 @@ async def test_partial_json_output_parser_key_async_first_only(
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
async def test_partial_json_output_parser_key_async_first_only_v1() -> None:
    """Async variant of the keyed, first-tool-only V1 streaming test.

    The values collected from ``astream`` must equal ``EXPECTED_STREAMED_JSON``.
    """

    async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
        for message_chunk in STREAMED_MESSAGES_V1:
            yield message_chunk

    parser = JsonOutputKeyToolsParser(key_name="NameCollector", first_tool_only=True)
    pipeline = input_iter | parser
    streamed = [partial async for partial in pipeline.astream(None)]
    assert streamed == EXPECTED_STREAMED_JSON
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_first_only(
*, use_tool_calls: bool
@ -531,6 +634,42 @@ def test_json_output_key_tools_parser_multiple_tools_first_only(
assert output_no_id == {"a": 1}
def test_json_output_key_tools_parser_multiple_tools_first_only_v1() -> None:
    """Keyed parser must pick the matching tool even when it is not first.

    The V1 message lists an ``other`` tool call before the ``func`` call; with
    ``first_tool_only=True`` the parser is expected to return the ``func``
    call (with or without its id) rather than ``None``.
    """
    tool_calls = [
        {
            "type": "tool_call",
            "id": "call_other",
            "name": "other",
            "args": {"b": 2},
        },
        {"type": "tool_call", "id": "call_func", "name": "func", "args": {"a": 1}},
    ]
    message = AIMessageV1(content=[], tool_calls=tool_calls)

    # return_id=True: the full tool-call dict for "func" comes back.
    with_id = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    ).parse_result(message)
    assert with_id is not None
    assert with_id["type"] == "func"
    assert with_id["args"] == {"a": 1}
    assert "id" in with_id

    # return_id=False: only the arguments of the "func" call come back.
    args_only = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=False
    ).parse_result(message)
    assert args_only == {"a": 1}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_no_match(
*, use_tool_calls: bool
@ -583,6 +722,44 @@ def test_json_output_key_tools_parser_multiple_tools_no_match(
assert output_no_id is None
def test_json_output_key_tools_parser_multiple_tools_no_match_v1() -> None:
    """Keyed parser returns ``None`` when no tool call matches ``key_name``.

    Checked with ``first_tool_only=True`` for both ``return_id=True`` and
    ``return_id=False`` against a V1 message holding two unrelated tool calls.
    """
    tool_calls = [
        {
            "type": "tool_call",
            "id": "call_other",
            "name": "other",
            "args": {"b": 2},
        },
        {
            "type": "tool_call",
            "id": "call_another",
            "name": "another",
            "args": {"c": 3},
        },
    ]
    message = AIMessageV1(content=[], tool_calls=tool_calls)

    # Same expectation regardless of whether ids are requested.
    for return_id in (True, False):
        parser = JsonOutputKeyToolsParser(
            key_name="nonexistent", first_tool_only=True, return_id=return_id
        )
        assert parser.parse_result(message) is None
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_matching_tools(
*, use_tool_calls: bool
@ -643,6 +820,42 @@ def test_json_output_key_tools_parser_multiple_matching_tools(
assert output_all[1]["args"] == {"a": 3}
def test_json_output_key_tools_parser_multiple_matching_tools_v1() -> None:
    """Keyed parser behaviour when several tool calls share ``key_name``.

    The V1 message holds two ``func`` calls separated by an ``other`` call.
    ``first_tool_only=True`` must yield the first ``func`` call;
    ``first_tool_only=False`` must yield both, in message order.
    """
    tool_calls = [
        {"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
        {
            "type": "tool_call",
            "id": "call_other",
            "name": "other",
            "args": {"b": 2},
        },
        {"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
    ]
    message = AIMessageV1(content=[], tool_calls=tool_calls)

    # first_tool_only=True -> only the first matching call.
    first_match = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    ).parse_result(message)
    assert first_match is not None
    assert first_match["type"] == "func"
    assert first_match["args"] == {"a": 1}

    # first_tool_only=False -> every matching call.
    all_matches = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=True
    ).parse_result(message)
    assert len(all_matches) == 2
    assert all_matches[0]["args"] == {"a": 1}
    assert all_matches[1]["args"] == {"a": 3}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
def create_message() -> AIMessage:
@ -671,6 +884,35 @@ def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) ->
assert output_all == []
@pytest.mark.parametrize(
    "empty_message",
    [
        AIMessageV1(content=[], tool_calls=[]),
        AIMessageV1(content="", tool_calls=[]),
    ],
)
def test_json_output_key_tools_parser_empty_results_v1(
    empty_message: AIMessageV1,
) -> None:
    """Keyed parser output for V1 messages that carry no tool calls.

    ``first_tool_only=True`` must give ``None``; ``first_tool_only=False``
    must give an empty list.
    """
    single = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    ).parse_result(empty_message)
    assert single is None

    every = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=True
    ).parse_result(empty_message)
    assert every == []
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_parameter_combinations(
*, use_tool_calls: bool
@ -746,6 +988,56 @@ def test_json_output_key_tools_parser_parameter_combinations(
assert output4 == [{"a": 1}, {"a": 3}]
def test_json_output_key_tools_parser_parameter_combinations_v1() -> None:
    """Exercise every ``first_tool_only`` x ``return_id`` combination on V1 input.

    The message holds one ``other`` call followed by two ``func`` calls; each
    combination is parsed with ``key_name="func"`` and checked separately.
    """
    tool_calls = [
        {
            "type": "tool_call",
            "id": "call_other",
            "name": "other",
            "args": {"b": 2},
        },
        {"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
        {"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
    ]
    result = AIMessageV1(content=[], tool_calls=tool_calls)

    def parse(*, first_tool_only: bool, return_id: bool) -> Any:
        # Build a fresh parser per combination, as each check is independent.
        parser = JsonOutputKeyToolsParser(
            key_name="func", first_tool_only=first_tool_only, return_id=return_id
        )
        return parser.parse_result(result)

    # first_tool_only=True, return_id=True
    first_with_id = parse(first_tool_only=True, return_id=True)
    assert first_with_id["type"] == "func"
    assert first_with_id["args"] == {"a": 1}
    assert "id" in first_with_id

    # first_tool_only=True, return_id=False
    first_args = parse(first_tool_only=True, return_id=False)
    assert first_args == {"a": 1}

    # first_tool_only=False, return_id=True
    all_with_id = parse(first_tool_only=False, return_id=True)
    assert len(all_with_id) == 2
    assert all("id" in item for item in all_with_id)
    assert all_with_id[0]["args"] == {"a": 1}
    assert all_with_id[1]["args"] == {"a": 3}

    # first_tool_only=False, return_id=False
    all_args = parse(first_tool_only=False, return_id=False)
    assert all_args == [{"a": 1}, {"a": 3}]
class Person(BaseModel):
age: int
hair_color: str
@ -788,6 +1080,18 @@ def test_partial_pydantic_output_parser() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC
def test_partial_pydantic_output_parser_v1() -> None:
    """Stream V1 chunks through ``PydanticToolsParser`` for ``NameCollector``.

    With ``first_tool_only=True`` the streamed values must equal
    ``EXPECTED_STREAMED_PYDANTIC``.
    """

    def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
        for message_chunk in STREAMED_MESSAGES_V1:
            yield message_chunk

    parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
    pipeline = input_iter | parser
    streamed = list(pipeline.stream(None))
    assert streamed == EXPECTED_STREAMED_PYDANTIC
async def test_partial_pydantic_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@ -800,6 +1104,19 @@ async def test_partial_pydantic_output_parser_async() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC
async def test_partial_pydantic_output_parser_async_v1() -> None:
    """Async variant of the V1 ``PydanticToolsParser`` streaming test.

    The values collected from ``astream`` must equal
    ``EXPECTED_STREAMED_PYDANTIC``.
    """

    async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
        for message_chunk in STREAMED_MESSAGES_V1:
            yield message_chunk

    parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
    pipeline = input_iter | parser
    streamed = [partial async for partial in pipeline.astream(None)]
    assert streamed == EXPECTED_STREAMED_PYDANTIC
def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic
@ -870,20 +1187,22 @@ def test_parse_with_different_pydantic_2_proper() -> None:
def test_max_tokens_error(caplog: Any) -> None:
    """Check that a ``max_tokens`` stop reason is logged when parsing fails.

    For both the legacy ``AIMessage`` and ``AIMessageV1``, a message carrying a
    ``NameCollector`` tool call and ``stop_reason == "max_tokens"`` in its
    response metadata is passed to ``PydanticToolsParser``. Each invocation
    must raise ``ValidationError`` and leave an ERROR-level log record that
    mentions the `max_tokens` stop reason.

    Args:
        caplog: pytest's log-capturing fixture.
    """
    parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
    for msg_class in [AIMessage, AIMessageV1]:
        # Drop records captured so far: otherwise the log emitted for an
        # earlier message class would satisfy the assertion for this one.
        caplog.clear()
        message = msg_class(
            content="",
            tool_calls=[
                {
                    "type": "tool_call",
                    "id": "call_OwL7f5PE",
                    "name": "NameCollector",
                    "args": {"names": ["suz", "jerm"]},
                }
            ],
            response_metadata={"stop_reason": "max_tokens"},
        )
        with pytest.raises(ValidationError):
            _ = parser.invoke(message)
        assert any(
            "`max_tokens` stop reason" in msg and record.levelname == "ERROR"
            for record, msg in zip(caplog.records, caplog.messages)
        )

View File

@ -8,6 +8,8 @@ from langchain_core.messages import (
AIMessage,
BaseMessage,
)
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ChatGeneration, Generation
from typing_extensions import override
@ -83,10 +85,12 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
@override
def parse_result(
self,
result: list[Generation],
result: Union[list[Generation], AIMessageV1],
*,
partial: bool = False,
) -> Union[AgentAction, AgentFinish]:
if isinstance(result, AIMessageV1):
result = [ChatGeneration(message=convert_from_v1_message(result))]
if not isinstance(result[0], ChatGeneration):
msg = "This output parser only works on ChatGeneration output"
raise ValueError(msg) # noqa: TRY004

View File

@ -2,6 +2,8 @@ from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ChatGeneration, Generation
from typing_extensions import override
@ -57,10 +59,12 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
@override
def parse_result(
self,
result: list[Generation],
result: Union[list[Generation], AIMessageV1],
*,
partial: bool = False,
) -> Union[list[AgentAction], AgentFinish]:
if isinstance(result, AIMessageV1):
result = [ChatGeneration(message=convert_from_v1_message(result))]
if not isinstance(result[0], ChatGeneration):
msg = "This output parser only works on ChatGeneration output"
raise ValueError(msg) # noqa: TRY004

View File

@ -9,6 +9,8 @@ from langchain_core.messages import (
BaseMessage,
ToolCall,
)
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ChatGeneration, Generation
from typing_extensions import override
@ -101,10 +103,12 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
@override
def parse_result(
self,
result: list[Generation],
result: Union[list[Generation], AIMessageV1],
*,
partial: bool = False,
) -> Union[list[AgentAction], AgentFinish]:
if isinstance(result, AIMessageV1):
result = [ChatGeneration(message=convert_from_v1_message(result))]
if not isinstance(result[0], ChatGeneration):
msg = "This output parser only works on ChatGeneration output"
raise ValueError(msg) # noqa: TRY004