Merge branch 'standard_outputs_copy' into mdrxy/ollama_v1

This commit is contained in:
Mason Daugherty 2025-07-30 17:41:16 -04:00 committed by GitHub
commit 588fe46601
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 4011 additions and 1129 deletions

View File

@ -3,7 +3,7 @@
import asyncio import asyncio
import re import re
import time import time
from collections.abc import AsyncIterator, Iterator from collections.abc import AsyncIterator, Iterable, Iterator
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from typing_extensions import override from typing_extensions import override
@ -13,7 +13,11 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.language_models.v1.chat_models import BaseChatModelV1
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
@ -367,3 +371,69 @@ class ParrotFakeChatModel(BaseChatModel):
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
return "parrot-fake-chat-model" return "parrot-fake-chat-model"
class GenericFakeChatModelV1(BaseChatModelV1):
    """Generic fake chat model that can be used to test the chat model interface."""

    # Canned full responses; each ``_invoke`` call consumes exactly one item.
    messages: Optional[Iterator[Union[AIMessageV1, str]]] = None
    # Canned stream chunks replayed in order by ``_stream``.
    message_chunks: Optional[Iterable[Union[AIMessageChunkV1, str]]] = None

    @override
    def _invoke(
        self,
        messages: list[MessageV1],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AIMessageV1:
        """Top Level call."""
        if self.messages is None:
            error_msg = "Messages iterator is not set."
            raise ValueError(error_msg)
        next_message = next(self.messages)
        if isinstance(next_message, str):
            # Bare strings are promoted to full AI messages.
            return AIMessageV1(content=next_message)
        return next_message

    @override
    def _stream(
        self,
        messages: list[MessageV1],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[AIMessageChunkV1]:
        """Top Level call."""
        if self.message_chunks is None:
            error_msg = "Message chunks iterator is not set."
            raise ValueError(error_msg)
        for piece in self.message_chunks:
            # Strings are wrapped into chunks; chunks pass through untouched.
            yield AIMessageChunkV1(piece) if isinstance(piece, str) else piece

    @property
    def _llm_type(self) -> str:
        return "generic-fake-chat-model"
class ParrotFakeChatModelV1(BaseChatModelV1):
    """Fake chat model that parrots back the content of the last input message.

    * Chat model should be usable in both sync and async tests
    """

    @override
    def _invoke(
        self,
        messages: list[MessageV1],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AIMessageV1:
        """Top Level call."""
        # An AI message is echoed back as-is; anything else is re-wrapped so the
        # return type is always AIMessageV1.
        if isinstance(messages[-1], AIMessageV1):
            return messages[-1]
        return AIMessageV1(content=messages[-1].content)

    @property
    def _llm_type(self) -> str:
        # NOTE(review): same identifier as the non-V1 ParrotFakeChatModel —
        # confirm the collision is intentional.
        return "parrot-fake-chat-model"

View File

@ -3,9 +3,10 @@
import json import json
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args from typing import Any, Literal, Optional, Union, cast, get_args
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import TypedDict
import langchain_core.messages.content_blocks as types import langchain_core.messages.content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
@ -329,7 +330,11 @@ class AIMessageChunk(AIMessage):
@property @property
def reasoning(self) -> Optional[str]: def reasoning(self) -> Optional[str]:
"""Extract all reasoning text from the AI message as a string.""" """Extract all reasoning text from the AI message as a string."""
text_blocks = [block for block in self.content if block["type"] == "reasoning"] text_blocks = [
block
for block in self.content
if block["type"] == "reasoning" and "reasoning" in block
]
if text_blocks: if text_blocks:
return "".join(block["reasoning"] for block in text_blocks) return "".join(block["reasoning"] for block in text_blocks)
return None return None

View File

@ -11,12 +11,14 @@ from typing import (
Optional, Optional,
TypeVar, TypeVar,
Union, Union,
cast,
) )
from typing_extensions import override from typing_extensions import override
from langchain_core.language_models import LanguageModelOutput from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.messages.v1 import AIMessage, MessageV1, MessageV1Types
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor
@ -25,14 +27,16 @@ if TYPE_CHECKING:
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import PromptValue
T = TypeVar("T") T = TypeVar("T")
OutputParserLike = Runnable[LanguageModelOutput, T] OutputParserLike = Runnable[Union[LanguageModelOutput, AIMessage], T]
class BaseLLMOutputParser(ABC, Generic[T]): class BaseLLMOutputParser(ABC, Generic[T]):
"""Abstract base class for parsing the outputs of a model.""" """Abstract base class for parsing the outputs of a model."""
@abstractmethod @abstractmethod
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T: def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format. """Parse a list of candidate model Generations into a specific format.
Args: Args:
@ -46,7 +50,7 @@ class BaseLLMOutputParser(ABC, Generic[T]):
""" """
async def aparse_result( async def aparse_result(
self, result: list[Generation], *, partial: bool = False self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T: ) -> T:
"""Async parse a list of candidate model Generations into a specific format. """Async parse a list of candidate model Generations into a specific format.
@ -71,7 +75,7 @@ class BaseGenerationOutputParser(
@override @override
def InputType(self) -> Any: def InputType(self) -> Any:
"""Return the input type for the parser.""" """Return the input type for the parser."""
return Union[str, AnyMessage] return Union[str, AnyMessage, MessageV1]
@property @property
@override @override
@ -84,7 +88,7 @@ class BaseGenerationOutputParser(
@override @override
def invoke( def invoke(
self, self,
input: Union[str, BaseMessage], input: Union[str, BaseMessage, MessageV1],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> T: ) -> T:
@ -97,9 +101,16 @@ class BaseGenerationOutputParser(
config, config,
run_type="parser", run_type="parser",
) )
if isinstance(input, MessageV1Types):
return self._call_with_config(
lambda inner_input: self.parse_result(inner_input),
input,
config,
run_type="parser",
)
return self._call_with_config( return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]), lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input, cast("str", input),
config, config,
run_type="parser", run_type="parser",
) )
@ -120,6 +131,13 @@ class BaseGenerationOutputParser(
config, config,
run_type="parser", run_type="parser",
) )
if isinstance(input, MessageV1Types):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(inner_input),
input,
config,
run_type="parser",
)
return await self._acall_with_config( return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]), lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input, input,
@ -129,7 +147,7 @@ class BaseGenerationOutputParser(
class BaseOutputParser( class BaseOutputParser(
BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T] BaseLLMOutputParser, RunnableSerializable[Union[LanguageModelOutput, AIMessage], T]
): ):
"""Base class to parse the output of an LLM call. """Base class to parse the output of an LLM call.
@ -162,7 +180,7 @@ class BaseOutputParser(
@override @override
def InputType(self) -> Any: def InputType(self) -> Any:
"""Return the input type for the parser.""" """Return the input type for the parser."""
return Union[str, AnyMessage] return Union[str, AnyMessage, MessageV1]
@property @property
@override @override
@ -189,7 +207,7 @@ class BaseOutputParser(
@override @override
def invoke( def invoke(
self, self,
input: Union[str, BaseMessage], input: Union[str, BaseMessage, MessageV1],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> T: ) -> T:
@ -202,9 +220,16 @@ class BaseOutputParser(
config, config,
run_type="parser", run_type="parser",
) )
if isinstance(input, MessageV1Types):
return self._call_with_config(
lambda inner_input: self.parse_result(inner_input),
input,
config,
run_type="parser",
)
return self._call_with_config( return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]), lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input, cast("str", input),
config, config,
run_type="parser", run_type="parser",
) )
@ -212,7 +237,7 @@ class BaseOutputParser(
@override @override
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, BaseMessage], input: Union[str, BaseMessage, MessageV1],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> T: ) -> T:
@ -225,15 +250,24 @@ class BaseOutputParser(
config, config,
run_type="parser", run_type="parser",
) )
if isinstance(input, MessageV1Types):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(inner_input),
input,
config,
run_type="parser",
)
return await self._acall_with_config( return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]), lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input, cast("str", input),
config, config,
run_type="parser", run_type="parser",
) )
@override @override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T: def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format. """Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which The return value is parsed from only the first Generation in the result, which
@ -248,6 +282,8 @@ class BaseOutputParser(
Returns: Returns:
Structured output. Structured output.
""" """
if isinstance(result, AIMessage):
return self.parse(result.text or "")
return self.parse(result[0].text) return self.parse(result[0].text)
@abstractmethod @abstractmethod
@ -262,7 +298,7 @@ class BaseOutputParser(
""" """
async def aparse_result( async def aparse_result(
self, result: list[Generation], *, partial: bool = False self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T: ) -> T:
"""Async parse a list of candidate model Generations into a specific format. """Async parse a list of candidate model Generations into a specific format.

View File

@ -13,6 +13,7 @@ from pydantic.v1 import BaseModel
from typing_extensions import override from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import Generation from langchain_core.outputs import Generation
@ -53,7 +54,9 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
return pydantic_object.schema() return pydantic_object.schema()
return None return None
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:
@ -70,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
Raises: Raises:
OutputParserException: If the output is not valid JSON. OutputParserException: If the output is not valid JSON.
""" """
text = result[0].text text = result.text or "" if isinstance(result, AIMessage) else result[0].text
text = text.strip() text = text.strip()
if partial: if partial:
try: try:

View File

@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, TypeVar, Union
from typing_extensions import override from typing_extensions import override
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser from langchain_core.output_parsers.transform import BaseTransformOutputParser
if TYPE_CHECKING: if TYPE_CHECKING:
@ -71,7 +72,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
@override @override
def _transform( def _transform(
self, input: Iterator[Union[str, BaseMessage]] self, input: Iterator[Union[str, BaseMessage, AIMessage]]
) -> Iterator[list[str]]: ) -> Iterator[list[str]]:
buffer = "" buffer = ""
for chunk in input: for chunk in input:
@ -81,6 +82,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
if not isinstance(chunk_content, str): if not isinstance(chunk_content, str):
continue continue
buffer += chunk_content buffer += chunk_content
elif isinstance(chunk, AIMessage):
buffer += chunk.text or ""
else: else:
# add current chunk to buffer # add current chunk to buffer
buffer += chunk buffer += chunk
@ -105,7 +108,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
@override @override
async def _atransform( async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]] self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
) -> AsyncIterator[list[str]]: ) -> AsyncIterator[list[str]]:
buffer = "" buffer = ""
async for chunk in input: async for chunk in input:
@ -115,6 +118,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
if not isinstance(chunk_content, str): if not isinstance(chunk_content, str):
continue continue
buffer += chunk_content buffer += chunk_content
elif isinstance(chunk, AIMessage):
buffer += chunk.text or ""
else: else:
# add current chunk to buffer # add current chunk to buffer
buffer += chunk buffer += chunk

View File

@ -11,6 +11,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import override from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers import ( from langchain_core.output_parsers import (
BaseCumulativeTransformOutputParser, BaseCumulativeTransformOutputParser,
BaseGenerationOutputParser, BaseGenerationOutputParser,
@ -26,7 +27,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Whether to only return the arguments to the function call.""" """Whether to only return the arguments to the function call."""
@override @override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:
@ -39,6 +42,12 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
Raises: Raises:
OutputParserException: If the output is not valid JSON. OutputParserException: If the output is not valid JSON.
""" """
if isinstance(result, AIMessage):
msg = (
"This output parser does not support v1 AIMessages. Use "
"JsonOutputToolsParser instead."
)
raise TypeError(msg)
generation = result[0] generation = result[0]
if not isinstance(generation, ChatGeneration): if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation." msg = "This output parser can only be used with a chat generation."
@ -77,7 +86,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _diff(self, prev: Optional[Any], next: Any) -> Any: def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:
@ -90,6 +101,12 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
Raises: Raises:
OutputParserException: If the output is not valid JSON. OutputParserException: If the output is not valid JSON.
""" """
if isinstance(result, AIMessage):
msg = (
"This output parser does not support v1 AIMessages. Use "
"JsonOutputToolsParser instead."
)
raise TypeError(msg)
if len(result) != 1: if len(result) != 1:
msg = f"Expected exactly one result, but got {len(result)}" msg = f"Expected exactly one result, but got {len(result)}"
raise OutputParserException(msg) raise OutputParserException(msg)
@ -160,7 +177,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str key_name: str
"""The name of the key to return.""" """The name of the key to return."""
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:
@ -254,7 +273,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
return values return values
@override @override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:
@ -294,7 +315,9 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
"""The name of the attribute to return.""" """The name of the attribute to return."""
@override @override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:

View File

@ -4,7 +4,7 @@ import copy
import json import json
import logging import logging
from json import JSONDecodeError from json import JSONDecodeError
from typing import Annotated, Any, Optional from typing import Annotated, Any, Optional, Union
from pydantic import SkipValidation, ValidationError from pydantic import SkipValidation, ValidationError
@ -12,6 +12,7 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import invalid_tool_call from langchain_core.messages.tool import invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils.json import parse_partial_json from langchain_core.utils.json import parse_partial_json
@ -156,7 +157,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
If no tool calls are found, None will be returned. If no tool calls are found, None will be returned.
""" """
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a list of tool calls. """Parse the result of an LLM call to a list of tool calls.
Args: Args:
@ -173,9 +176,13 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
Raises: Raises:
OutputParserException: If the output is not valid JSON. OutputParserException: If the output is not valid JSON.
""" """
if isinstance(result, list):
generation = result[0] generation = result[0]
if not isinstance(generation, ChatGeneration): if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation." msg = (
"This output parser can only be used with a chat generation or "
"v1 AIMessage."
)
raise OutputParserException(msg) raise OutputParserException(msg)
message = generation.message message = generation.message
if isinstance(message, AIMessage) and message.tool_calls: if isinstance(message, AIMessage) and message.tool_calls:
@ -185,7 +192,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
_ = tool_call.pop("id") _ = tool_call.pop("id")
else: else:
try: try:
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"]) raw_tool_calls = copy.deepcopy(
message.additional_kwargs["tool_calls"]
)
except KeyError: except KeyError:
return [] return []
tool_calls = parse_tool_calls( tool_calls = parse_tool_calls(
@ -194,10 +203,18 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
strict=self.strict, strict=self.strict,
return_id=self.return_id, return_id=self.return_id,
) )
elif result.tool_calls:
# v1 message
tool_calls = [dict(tc) for tc in result.tool_calls]
for tool_call in tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
return []
# for backwards compatibility # for backwards compatibility
for tc in tool_calls: for tc in tool_calls:
tc["type"] = tc.pop("name") tc["type"] = tc.pop("name")
if self.first_tool_only: if self.first_tool_only:
return tool_calls[0] if tool_calls else None return tool_calls[0] if tool_calls else None
return tool_calls return tool_calls
@ -220,7 +237,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
key_name: str key_name: str
"""The type of tools to return.""" """The type of tools to return."""
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a list of tool calls. """Parse the result of an LLM call to a list of tool calls.
Args: Args:
@ -234,6 +253,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
Returns: Returns:
The parsed tool calls. The parsed tool calls.
""" """
if isinstance(result, list):
generation = result[0] generation = result[0]
if not isinstance(generation, ChatGeneration): if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation." msg = "This output parser can only be used with a chat generation."
@ -246,7 +266,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
_ = tool_call.pop("id") _ = tool_call.pop("id")
else: else:
try: try:
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"]) raw_tool_calls = copy.deepcopy(
message.additional_kwargs["tool_calls"]
)
except KeyError: except KeyError:
if self.first_tool_only: if self.first_tool_only:
return None return None
@ -257,9 +279,21 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
strict=self.strict, strict=self.strict,
return_id=self.return_id, return_id=self.return_id,
) )
elif result.tool_calls:
# v1 message
parsed_tool_calls = [dict(tc) for tc in result.tool_calls]
for tool_call in parsed_tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
if self.first_tool_only:
return None
return []
# For backwards compatibility # For backwards compatibility
for tc in parsed_tool_calls: for tc in parsed_tool_calls:
tc["type"] = tc.pop("name") tc["type"] = tc.pop("name")
if self.first_tool_only: if self.first_tool_only:
parsed_result = list( parsed_result = list(
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls) filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
@ -299,7 +333,9 @@ class PydanticToolsParser(JsonOutputToolsParser):
# TODO: Support more granular streaming of objects. Currently only streams once all # TODO: Support more granular streaming of objects. Currently only streams once all
# Pydantic object fields are present. # Pydantic object fields are present.
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a list of Pydantic objects. """Parse the result of an LLM call to a list of Pydantic objects.
Args: Args:
@ -337,12 +373,19 @@ class PydanticToolsParser(JsonOutputToolsParser):
except (ValidationError, ValueError): except (ValidationError, ValueError):
if partial: if partial:
continue continue
has_max_tokens_stop_reason = False
if isinstance(result, list):
has_max_tokens_stop_reason = any( has_max_tokens_stop_reason = any(
generation.message.response_metadata.get("stop_reason") generation.message.response_metadata.get("stop_reason")
== "max_tokens" == "max_tokens"
for generation in result for generation in result
if isinstance(generation, ChatGeneration) if isinstance(generation, ChatGeneration)
) )
else:
# v1 message
has_max_tokens_stop_reason = (
result.response_metadata.get("stop_reason") == "max_tokens"
)
if has_max_tokens_stop_reason: if has_max_tokens_stop_reason:
logger.exception(_MAX_TOKENS_ERROR) logger.exception(_MAX_TOKENS_ERROR)
raise raise

View File

@ -1,13 +1,14 @@
"""Output parsers using Pydantic.""" """Output parsers using Pydantic."""
import json import json
from typing import Annotated, Generic, Optional from typing import Annotated, Generic, Optional, Union
import pydantic import pydantic
from pydantic import SkipValidation from pydantic import SkipValidation
from typing_extensions import override from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation from langchain_core.outputs import Generation
from langchain_core.utils.pydantic import ( from langchain_core.utils.pydantic import (
@ -43,7 +44,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return OutputParserException(msg, llm_output=json_string) return OutputParserException(msg, llm_output=json_string)
def parse_result( def parse_result(
self, result: list[Generation], *, partial: bool = False self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Optional[TBaseModel]: ) -> Optional[TBaseModel]:
"""Parse the result of an LLM call to a pydantic object. """Parse the result of an LLM call to a pydantic object.

View File

@ -12,6 +12,7 @@ from typing import (
from typing_extensions import override from typing_extensions import override
from langchain_core.messages import BaseMessage, BaseMessageChunk from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.messages.v1 import AIMessage, AIMessageChunk
from langchain_core.output_parsers.base import BaseOutputParser, T from langchain_core.output_parsers.base import BaseOutputParser, T
from langchain_core.outputs import ( from langchain_core.outputs import (
ChatGeneration, ChatGeneration,
@ -32,23 +33,27 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
def _transform( def _transform(
self, self,
input: Iterator[Union[str, BaseMessage]], input: Iterator[Union[str, BaseMessage, AIMessage]],
) -> Iterator[T]: ) -> Iterator[T]:
for chunk in input: for chunk in input:
if isinstance(chunk, BaseMessage): if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)]) yield self.parse_result([ChatGeneration(message=chunk)])
elif isinstance(chunk, AIMessage):
yield self.parse_result(chunk)
else: else:
yield self.parse_result([Generation(text=chunk)]) yield self.parse_result([Generation(text=chunk)])
async def _atransform( async def _atransform(
self, self,
input: AsyncIterator[Union[str, BaseMessage]], input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
async for chunk in input: async for chunk in input:
if isinstance(chunk, BaseMessage): if isinstance(chunk, BaseMessage):
yield await run_in_executor( yield await run_in_executor(
None, self.parse_result, [ChatGeneration(message=chunk)] None, self.parse_result, [ChatGeneration(message=chunk)]
) )
elif isinstance(chunk, AIMessage):
yield await run_in_executor(None, self.parse_result, chunk)
else: else:
yield await run_in_executor( yield await run_in_executor(
None, self.parse_result, [Generation(text=chunk)] None, self.parse_result, [Generation(text=chunk)]
@ -57,7 +62,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
@override @override
def transform( def transform(
self, self,
input: Iterator[Union[str, BaseMessage]], input: Iterator[Union[str, BaseMessage, AIMessage]],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[T]: ) -> Iterator[T]:
@ -78,7 +83,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
@override @override
async def atransform( async def atransform(
self, self,
input: AsyncIterator[Union[str, BaseMessage]], input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
@ -125,22 +130,41 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
raise NotImplementedError raise NotImplementedError
@override @override
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: def _transform(
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
) -> Iterator[Any]:
prev_parsed = None prev_parsed = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
None
)
for chunk in input: for chunk in input:
chunk_gen: Union[GenerationChunk, ChatGenerationChunk] chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
if isinstance(chunk, BaseMessageChunk): if isinstance(chunk, BaseMessageChunk):
chunk_gen = ChatGenerationChunk(message=chunk) chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage): elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk( chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.model_dump()) message=BaseMessageChunk(**chunk.model_dump())
) )
elif isinstance(chunk, AIMessageChunk):
chunk_gen = chunk
elif isinstance(chunk, AIMessage):
chunk_gen = AIMessageChunk(
content=chunk.content,
id=chunk.id,
name=chunk.name,
lc_version=chunk.lc_version,
response_metadata=chunk.response_metadata,
usage_metadata=chunk.usage_metadata,
parsed=chunk.parsed,
)
else: else:
chunk_gen = GenerationChunk(text=chunk) chunk_gen = GenerationChunk(text=chunk)
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator] acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
if isinstance(acc_gen, AIMessageChunk):
parsed = self.parse_result(acc_gen, partial=True)
else:
parsed = self.parse_result([acc_gen], partial=True) parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed: if parsed is not None and parsed != prev_parsed:
if self.diff: if self.diff:
@ -151,23 +175,40 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
@override @override
async def _atransform( async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]] self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
prev_parsed = None prev_parsed = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
None
)
async for chunk in input: async for chunk in input:
chunk_gen: Union[GenerationChunk, ChatGenerationChunk] chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
if isinstance(chunk, BaseMessageChunk): if isinstance(chunk, BaseMessageChunk):
chunk_gen = ChatGenerationChunk(message=chunk) chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage): elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk( chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.model_dump()) message=BaseMessageChunk(**chunk.model_dump())
) )
elif isinstance(chunk, AIMessageChunk):
chunk_gen = chunk
elif isinstance(chunk, AIMessage):
chunk_gen = AIMessageChunk(
content=chunk.content,
id=chunk.id,
name=chunk.name,
lc_version=chunk.lc_version,
response_metadata=chunk.response_metadata,
usage_metadata=chunk.usage_metadata,
parsed=chunk.parsed,
)
else: else:
chunk_gen = GenerationChunk(text=chunk) chunk_gen = GenerationChunk(text=chunk)
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator] acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
if isinstance(acc_gen, AIMessageChunk):
parsed = await self.aparse_result(acc_gen, partial=True)
else:
parsed = await self.aparse_result([acc_gen], partial=True) parsed = await self.aparse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed: if parsed is not None and parsed != prev_parsed:
if self.diff: if self.diff:

View File

@ -12,6 +12,8 @@ from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import AIMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict from langchain_core.runnables.utils import AddableDict
@ -240,19 +242,26 @@ class XMLOutputParser(BaseTransformOutputParser):
@override @override
def _transform( def _transform(
self, input: Iterator[Union[str, BaseMessage]] self, input: Iterator[Union[str, BaseMessage, AIMessage]]
) -> Iterator[AddableDict]: ) -> Iterator[AddableDict]:
streaming_parser = _StreamingParser(self.parser) streaming_parser = _StreamingParser(self.parser)
for chunk in input: for chunk in input:
if isinstance(chunk, AIMessage):
yield from streaming_parser.parse(convert_from_v1_message(chunk))
else:
yield from streaming_parser.parse(chunk) yield from streaming_parser.parse(chunk)
streaming_parser.close() streaming_parser.close()
@override @override
async def _atransform( async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]] self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
) -> AsyncIterator[AddableDict]: ) -> AsyncIterator[AddableDict]:
streaming_parser = _StreamingParser(self.parser) streaming_parser = _StreamingParser(self.parser)
async for chunk in input: async for chunk in input:
if isinstance(chunk, AIMessage):
for output in streaming_parser.parse(convert_from_v1_message(chunk)):
yield output
else:
for output in streaming_parser.parse(chunk): for output in streaming_parser.parse(chunk):
yield output yield output
streaming_parser.close() streaming_parser.close()

View File

@ -1,10 +1,14 @@
"""Module to test base parser implementations.""" """Module to test base parser implementations."""
from typing import Union
from typing_extensions import override from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel from langchain_core.language_models import GenericFakeChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModelV1
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.output_parsers import ( from langchain_core.output_parsers import (
BaseGenerationOutputParser, BaseGenerationOutputParser,
BaseTransformOutputParser, BaseTransformOutputParser,
@ -20,7 +24,7 @@ def test_base_generation_parser() -> None:
@override @override
def parse_result( def parse_result(
self, result: list[Generation], *, partial: bool = False self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> str: ) -> str:
"""Parse a list of model Generations into a specific format. """Parse a list of model Generations into a specific format.
@ -32,16 +36,22 @@ def test_base_generation_parser() -> None:
partial: Whether to allow partial results. This is used for parsers partial: Whether to allow partial results. This is used for parsers
that support streaming that support streaming
""" """
if isinstance(result, AIMessageV1):
content = result.text or ""
else:
if len(result) != 1: if len(result) != 1:
msg = "This output parser can only be used with a single generation." msg = (
"This output parser can only be used with a single generation."
)
raise NotImplementedError(msg) raise NotImplementedError(msg)
generation = result[0] generation = result[0]
if not isinstance(generation, ChatGeneration): if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations # Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation." msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg) raise OutputParserException(msg)
assert isinstance(generation.message.content, str)
content = generation.message.content content = generation.message.content
assert isinstance(content, str) assert isinstance(content, str)
return content.swapcase() return content.swapcase()
@ -49,6 +59,10 @@ def test_base_generation_parser() -> None:
chain = model | StrInvertCase() chain = model | StrInvertCase()
assert chain.invoke("") == "HeLLO" assert chain.invoke("") == "HeLLO"
model_v1 = GenericFakeChatModelV1(messages=iter([AIMessageV1("hEllo")]))
chain_v1 = model_v1 | StrInvertCase()
assert chain_v1.invoke("") == "HeLLO"
def test_base_transform_output_parser() -> None: def test_base_transform_output_parser() -> None:
"""Test base transform output parser.""" """Test base transform output parser."""
@ -62,7 +76,7 @@ def test_base_transform_output_parser() -> None:
@override @override
def parse_result( def parse_result(
self, result: list[Generation], *, partial: bool = False self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> str: ) -> str:
"""Parse a list of model Generations into a specific format. """Parse a list of model Generations into a specific format.
@ -74,15 +88,22 @@ def test_base_transform_output_parser() -> None:
partial: Whether to allow partial results. This is used for parsers partial: Whether to allow partial results. This is used for parsers
that support streaming that support streaming
""" """
if isinstance(result, AIMessageV1):
content = result.text or ""
else:
if len(result) != 1: if len(result) != 1:
msg = "This output parser can only be used with a single generation." msg = (
"This output parser can only be used with a single generation."
)
raise NotImplementedError(msg) raise NotImplementedError(msg)
generation = result[0] generation = result[0]
if not isinstance(generation, ChatGeneration): if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations # Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation." msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg) raise OutputParserException(msg)
assert isinstance(generation.message.content, str)
content = generation.message.content content = generation.message.content
assert isinstance(content, str) assert isinstance(content, str)
return content.swapcase() return content.swapcase()
@ -91,3 +112,8 @@ def test_base_transform_output_parser() -> None:
# inputs to models are ignored, response is hard-coded in model definition # inputs to models are ignored, response is hard-coded in model definition
chunks = list(chain.stream("")) chunks = list(chain.stream(""))
assert chunks == ["HELLO", " ", "WORLD"] assert chunks == ["HELLO", " ", "WORLD"]
model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
chain_v1 = model_v1 | StrInvertCase()
chunks = list(chain_v1.stream(""))
assert chunks == ["HELLO", " ", "WORLD"]

View File

@ -10,6 +10,8 @@ from langchain_core.messages import (
BaseMessage, BaseMessage,
ToolCallChunk, ToolCallChunk,
) )
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.output_parsers.openai_tools import ( from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser, JsonOutputKeyToolsParser,
JsonOutputToolsParser, JsonOutputToolsParser,
@ -331,6 +333,14 @@ for message in STREAMED_MESSAGES:
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message) STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)
STREAMED_MESSAGES_V1 = [
AIMessageChunkV1(
content=[],
tool_call_chunks=chunk.tool_call_chunks,
)
for chunk in STREAMED_MESSAGES_WITH_TOOL_CALLS
]
EXPECTED_STREAMED_JSON = [ EXPECTED_STREAMED_JSON = [
{}, {},
{"names": ["suz"]}, {"names": ["suz"]},
@ -398,6 +408,19 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
assert actual == expected assert actual == expected
def test_partial_json_output_parser_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | JsonOutputToolsParser()
actual = list(chain.stream(None))
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None: async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None:
input_iter = _get_aiter(use_tool_calls=use_tool_calls) input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@ -410,6 +433,20 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None
assert actual == expected assert actual == expected
async def test_partial_json_output_parser_async_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg
chain = input_iter | JsonOutputToolsParser()
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None: def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls) input_iter = _get_iter(use_tool_calls=use_tool_calls)
@ -429,6 +466,26 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
assert actual == expected assert actual == expected
def test_partial_json_output_parser_return_id_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | JsonOutputToolsParser(return_id=True)
actual = list(chain.stream(None))
expected: list = [[]] + [
[
{
"type": "NameCollector",
"args": chunk,
"id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
}
]
for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None: def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls) input_iter = _get_iter(use_tool_calls=use_tool_calls)
@ -439,6 +496,17 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
assert actual == expected assert actual == expected
def test_partial_json_output_key_parser_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
actual = list(chain.stream(None))
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> None: async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> None:
input_iter = _get_aiter(use_tool_calls=use_tool_calls) input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@ -450,6 +518,18 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) ->
assert actual == expected assert actual == expected
async def test_partial_json_output_parser_key_async_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> None: def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls) input_iter = _get_iter(use_tool_calls=use_tool_calls)
@ -461,6 +541,17 @@ def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> N
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
def test_partial_json_output_key_parser_first_only_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_key_async_first_only( async def test_partial_json_output_parser_key_async_first_only(
*, *,
@ -475,6 +566,18 @@ async def test_partial_json_output_parser_key_async_first_only(
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
async def test_partial_json_output_parser_key_async_first_only_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_first_only( def test_json_output_key_tools_parser_multiple_tools_first_only(
*, use_tool_calls: bool *, use_tool_calls: bool
@ -531,6 +634,42 @@ def test_json_output_key_tools_parser_multiple_tools_first_only(
assert output_no_id == {"a": 1} assert output_no_id == {"a": 1}
def test_json_output_key_tools_parser_multiple_tools_first_only_v1() -> None:
message = AIMessageV1(
content=[],
tool_calls=[
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{"type": "tool_call", "id": "call_func", "name": "func", "args": {"a": 1}},
],
)
# Test with return_id=True
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(message)
# Should return the func tool call, not None
assert output is not None
assert output["type"] == "func"
assert output["args"] == {"a": 1}
assert "id" in output
# Test with return_id=False
parser_no_id = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=False
)
output_no_id = parser_no_id.parse_result(message)
# Should return just the args
assert output_no_id == {"a": 1}
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_no_match( def test_json_output_key_tools_parser_multiple_tools_no_match(
*, use_tool_calls: bool *, use_tool_calls: bool
@ -583,6 +722,44 @@ def test_json_output_key_tools_parser_multiple_tools_no_match(
assert output_no_id is None assert output_no_id is None
def test_json_output_key_tools_parser_multiple_tools_no_match_v1() -> None:
message = AIMessageV1(
content=[],
tool_calls=[
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{
"type": "tool_call",
"id": "call_another",
"name": "another",
"args": {"c": 3},
},
],
)
# Test with return_id=True, first_tool_only=True
parser = JsonOutputKeyToolsParser(
key_name="nonexistent", first_tool_only=True, return_id=True
)
output = parser.parse_result(message)
# Should return None when no matches
assert output is None
# Test with return_id=False, first_tool_only=True
parser_no_id = JsonOutputKeyToolsParser(
key_name="nonexistent", first_tool_only=True, return_id=False
)
output_no_id = parser_no_id.parse_result(message)
# Should return None when no matches
assert output_no_id is None
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_matching_tools( def test_json_output_key_tools_parser_multiple_matching_tools(
*, use_tool_calls: bool *, use_tool_calls: bool
@ -643,6 +820,42 @@ def test_json_output_key_tools_parser_multiple_matching_tools(
assert output_all[1]["args"] == {"a": 3} assert output_all[1]["args"] == {"a": 3}
def test_json_output_key_tools_parser_multiple_matching_tools_v1() -> None:
message = AIMessageV1(
content=[],
tool_calls=[
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
],
)
# Test with first_tool_only=True - should return first matching
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(message)
assert output is not None
assert output["type"] == "func"
assert output["args"] == {"a": 1} # First matching tool call
# Test with first_tool_only=False - should return all matching
parser_all = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output_all = parser_all.parse_result(message)
assert len(output_all) == 2
assert output_all[0]["args"] == {"a": 1}
assert output_all[1]["args"] == {"a": 3}
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None: def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
def create_message() -> AIMessage: def create_message() -> AIMessage:
@ -671,6 +884,35 @@ def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) ->
assert output_all == [] assert output_all == []
@pytest.mark.parametrize(
"empty_message",
[
AIMessageV1(content=[], tool_calls=[]),
AIMessageV1(content="", tool_calls=[]),
],
)
def test_json_output_key_tools_parser_empty_results_v1(
empty_message: AIMessageV1,
) -> None:
# Test with first_tool_only=True
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(empty_message)
# Should return None for empty results
assert output is None
# Test with first_tool_only=False
parser_all = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output_all = parser_all.parse_result(empty_message)
# Should return empty list for empty results
assert output_all == []
@pytest.mark.parametrize("use_tool_calls", [False, True]) @pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_parameter_combinations( def test_json_output_key_tools_parser_parameter_combinations(
*, use_tool_calls: bool *, use_tool_calls: bool
@ -746,6 +988,56 @@ def test_json_output_key_tools_parser_parameter_combinations(
assert output4 == [{"a": 1}, {"a": 3}] assert output4 == [{"a": 1}, {"a": 3}]
def test_json_output_key_tools_parser_parameter_combinations_v1() -> None:
"""Test all parameter combinations of JsonOutputKeyToolsParser."""
result = AIMessageV1(
content=[],
tool_calls=[
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
],
)
# Test: first_tool_only=True, return_id=True
parser1 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output1 = parser1.parse_result(result)
assert output1["type"] == "func"
assert output1["args"] == {"a": 1}
assert "id" in output1
# Test: first_tool_only=True, return_id=False
parser2 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=False
)
output2 = parser2.parse_result(result)
assert output2 == {"a": 1}
# Test: first_tool_only=False, return_id=True
parser3 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output3 = parser3.parse_result(result)
assert len(output3) == 2
assert all("id" in item for item in output3)
assert output3[0]["args"] == {"a": 1}
assert output3[1]["args"] == {"a": 3}
# Test: first_tool_only=False, return_id=False
parser4 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=False
)
output4 = parser4.parse_result(result)
assert output4 == [{"a": 1}, {"a": 3}]
class Person(BaseModel): class Person(BaseModel):
age: int age: int
hair_color: str hair_color: str
@ -788,6 +1080,18 @@ def test_partial_pydantic_output_parser() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC assert actual == EXPECTED_STREAMED_PYDANTIC
def test_partial_pydantic_output_parser_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)
actual = list(chain.stream(None))
assert actual == EXPECTED_STREAMED_PYDANTIC
async def test_partial_pydantic_output_parser_async() -> None: async def test_partial_pydantic_output_parser_async() -> None:
for use_tool_calls in [False, True]: for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls=use_tool_calls) input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@ -800,6 +1104,19 @@ async def test_partial_pydantic_output_parser_async() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC assert actual == EXPECTED_STREAMED_PYDANTIC
async def test_partial_pydantic_output_parser_async_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)
actual = [p async for p in chain.astream(None)]
assert actual == EXPECTED_STREAMED_PYDANTIC
def test_parse_with_different_pydantic_2_v1() -> None: def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2.""" """Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic import pydantic
@ -870,10 +1187,12 @@ def test_parse_with_different_pydantic_2_proper() -> None:
def test_max_tokens_error(caplog: Any) -> None: def test_max_tokens_error(caplog: Any) -> None:
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True) parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
message = AIMessage( for msg_class in [AIMessage, AIMessageV1]:
message = msg_class(
content="", content="",
tool_calls=[ tool_calls=[
{ {
"type": "tool_call",
"id": "call_OwL7f5PE", "id": "call_OwL7f5PE",
"name": "NameCollector", "name": "NameCollector",
"args": {"names": ["suz", "jerm"]}, "args": {"names": ["suz", "jerm"]},

View File

@ -8,6 +8,8 @@ from langchain_core.messages import (
AIMessage, AIMessage,
BaseMessage, BaseMessage,
) )
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from typing_extensions import override from typing_extensions import override
@ -83,10 +85,12 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
@override @override
def parse_result( def parse_result(
self, self,
result: list[Generation], result: Union[list[Generation], AIMessageV1],
*, *,
partial: bool = False, partial: bool = False,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
if isinstance(result, AIMessageV1):
result = [ChatGeneration(message=convert_from_v1_message(result))]
if not isinstance(result[0], ChatGeneration): if not isinstance(result[0], ChatGeneration):
msg = "This output parser only works on ChatGeneration output" msg = "This output parser only works on ChatGeneration output"
raise ValueError(msg) # noqa: TRY004 raise ValueError(msg) # noqa: TRY004

View File

@ -2,6 +2,8 @@ from typing import Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from typing_extensions import override from typing_extensions import override
@ -57,10 +59,12 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
@override @override
def parse_result( def parse_result(
self, self,
result: list[Generation], result: Union[list[Generation], AIMessageV1],
*, *,
partial: bool = False, partial: bool = False,
) -> Union[list[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
if isinstance(result, AIMessageV1):
result = [ChatGeneration(message=convert_from_v1_message(result))]
if not isinstance(result[0], ChatGeneration): if not isinstance(result[0], ChatGeneration):
msg = "This output parser only works on ChatGeneration output" msg = "This output parser only works on ChatGeneration output"
raise ValueError(msg) # noqa: TRY004 raise ValueError(msg) # noqa: TRY004

View File

@ -9,6 +9,8 @@ from langchain_core.messages import (
BaseMessage, BaseMessage,
ToolCall, ToolCall,
) )
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from typing_extensions import override from typing_extensions import override
@ -101,10 +103,12 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
@override @override
def parse_result( def parse_result(
self, self,
result: list[Generation], result: Union[list[Generation], AIMessageV1],
*, *,
partial: bool = False, partial: bool = False,
) -> Union[list[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
if isinstance(result, AIMessageV1):
result = [ChatGeneration(message=convert_from_v1_message(result))]
if not isinstance(result[0], ChatGeneration): if not isinstance(result[0], ChatGeneration):
msg = "This output parser only works on ChatGeneration output" msg = "This output parser only works on ChatGeneration output"
raise ValueError(msg) # noqa: TRY004 raise ValueError(msg) # noqa: TRY004