mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 15:46:47 +00:00
Merge branch 'standard_outputs_copy' into mdrxy/ollama_v1
This commit is contained in:
commit
588fe46601
@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from collections.abc import AsyncIterator, Iterable, Iterator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from typing_extensions import override
|
||||
@ -13,7 +13,11 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
||||
from langchain_core.language_models.v1.chat_models import BaseChatModelV1
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.messages.v1 import MessageV1
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@ -367,3 +371,69 @@ class ParrotFakeChatModel(BaseChatModel):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "parrot-fake-chat-model"
|
||||
|
||||
|
||||
class GenericFakeChatModelV1(BaseChatModelV1):
    """Fake v1 chat model that replays pre-supplied responses.

    ``_invoke`` takes the next item from the ``messages`` attribute and
    ``_stream`` walks the ``message_chunks`` attribute. Plain strings are
    wrapped in the matching v1 message type; anything else is passed through
    unchanged.
    """

    # Canned replies for ``_invoke``; one item is consumed per call.
    messages: Optional[Iterator[Union[AIMessageV1, str]]] = None
    # Canned chunks for ``_stream``.
    message_chunks: Optional[Iterable[Union[AIMessageChunkV1, str]]] = None

    @override
    def _invoke(
        self,
        messages: list[MessageV1],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AIMessageV1:
        """Return the next canned reply.

        The ``messages`` argument, ``run_manager`` and ``kwargs`` are not
        consulted; the reply comes solely from the ``messages`` attribute.

        Raises:
            ValueError: If the ``messages`` attribute is ``None``.
        """
        replies = self.messages
        if replies is None:
            error_msg = "Messages iterator is not set."
            raise ValueError(error_msg)
        reply = next(replies)
        if isinstance(reply, str):
            # Bare strings become the content of a fresh v1 AI message.
            return AIMessageV1(content=reply)
        return reply

    @override
    def _stream(
        self,
        messages: list[MessageV1],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[AIMessageChunkV1]:
        """Yield every canned chunk in order.

        Because this is a generator, the ``ValueError`` below surfaces when
        iteration starts rather than when the method is called.

        Raises:
            ValueError: If the ``message_chunks`` attribute is ``None``.
        """
        chunks = self.message_chunks
        if chunks is None:
            error_msg = "Message chunks iterator is not set."
            raise ValueError(error_msg)
        for piece in chunks:
            # Strings are promoted to chunk objects; chunks pass through as-is.
            yield AIMessageChunkV1(piece) if isinstance(piece, str) else piece

    @property
    def _llm_type(self) -> str:
        """Identifier string for this fake model."""
        return "generic-fake-chat-model"
|
||||
|
||||
|
||||
class ParrotFakeChatModelV1(BaseChatModelV1):
    """Fake v1 chat model that echoes the last input message back.

    * Chat model should be usable in both sync and async tests
    """

    @override
    def _invoke(
        self,
        messages: list[MessageV1],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AIMessageV1:
        """Return the final input message as an AI message.

        If the last message is already a v1 ``AIMessage`` it is returned
        unchanged; otherwise its ``content`` is copied into a new one.
        ``stop``, ``run_manager`` and ``kwargs`` are not used.
        """
        last = messages[-1]
        if not isinstance(last, AIMessageV1):
            return AIMessageV1(content=last.content)
        return last

    @property
    def _llm_type(self) -> str:
        """Identifier string for this fake model."""
        return "parrot-fake-chat-model"
|
||||
|
@ -3,9 +3,10 @@
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args
|
||||
from typing import Any, Literal, Optional, Union, cast, get_args
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import langchain_core.messages.content_blocks as types
|
||||
from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
|
||||
@ -329,7 +330,11 @@ class AIMessageChunk(AIMessage):
|
||||
@property
def reasoning(self) -> Optional[str]:
    """Extract all reasoning text from the AI message as a string.

    Returns:
        The ``"reasoning"`` values of the qualifying content blocks joined in
        order, or ``None`` when no block qualifies.
    """
    # Only blocks typed "reasoning" that actually carry a "reasoning" key are
    # kept, so a reasoning block without that key cannot raise ``KeyError``
    # in the join below.
    text_blocks = [
        block
        for block in self.content
        if block["type"] == "reasoning" and "reasoning" in block
    ]
    if text_blocks:
        return "".join(block["reasoning"] for block in text_blocks)
    return None
|
||||
|
@ -11,12 +11,14 @@ from typing import (
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.language_models import LanguageModelOutput
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.messages.v1 import AIMessage, MessageV1, MessageV1Types
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
@ -25,14 +27,16 @@ if TYPE_CHECKING:
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
T = TypeVar("T")
|
||||
OutputParserLike = Runnable[LanguageModelOutput, T]
|
||||
OutputParserLike = Runnable[Union[LanguageModelOutput, AIMessage], T]
|
||||
|
||||
|
||||
class BaseLLMOutputParser(ABC, Generic[T]):
|
||||
"""Abstract base class for parsing the outputs of a model."""
|
||||
|
||||
@abstractmethod
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
@ -46,7 +50,7 @@ class BaseLLMOutputParser(ABC, Generic[T]):
|
||||
"""
|
||||
|
||||
async def aparse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Async parse a list of candidate model Generations into a specific format.
|
||||
|
||||
@ -71,7 +75,7 @@ class BaseGenerationOutputParser(
|
||||
@override
|
||||
def InputType(self) -> Any:
|
||||
"""Return the input type for the parser."""
|
||||
return Union[str, AnyMessage]
|
||||
return Union[str, AnyMessage, MessageV1]
|
||||
|
||||
@property
|
||||
@override
|
||||
@ -84,7 +88,7 @@ class BaseGenerationOutputParser(
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: Union[str, BaseMessage],
|
||||
input: Union[str, BaseMessage, MessageV1],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
@ -97,9 +101,16 @@ class BaseGenerationOutputParser(
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
if isinstance(input, MessageV1Types):
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result(inner_input),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
cast("str", input),
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
@ -120,6 +131,13 @@ class BaseGenerationOutputParser(
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
if isinstance(input, MessageV1Types):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(inner_input),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
@ -129,7 +147,7 @@ class BaseGenerationOutputParser(
|
||||
|
||||
|
||||
class BaseOutputParser(
|
||||
BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
|
||||
BaseLLMOutputParser, RunnableSerializable[Union[LanguageModelOutput, AIMessage], T]
|
||||
):
|
||||
"""Base class to parse the output of an LLM call.
|
||||
|
||||
@ -162,7 +180,7 @@ class BaseOutputParser(
|
||||
@override
|
||||
def InputType(self) -> Any:
|
||||
"""Return the input type for the parser."""
|
||||
return Union[str, AnyMessage]
|
||||
return Union[str, AnyMessage, MessageV1]
|
||||
|
||||
@property
|
||||
@override
|
||||
@ -189,7 +207,7 @@ class BaseOutputParser(
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: Union[str, BaseMessage],
|
||||
input: Union[str, BaseMessage, MessageV1],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
@ -202,9 +220,16 @@ class BaseOutputParser(
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
if isinstance(input, MessageV1Types):
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result(inner_input),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
cast("str", input),
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
@ -212,7 +237,7 @@ class BaseOutputParser(
|
||||
@override
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, BaseMessage],
|
||||
input: Union[str, BaseMessage, MessageV1],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> T:
|
||||
@ -225,15 +250,24 @@ class BaseOutputParser(
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
if isinstance(input, MessageV1Types):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(inner_input),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
cast("str", input),
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
@override
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
@ -248,6 +282,8 @@ class BaseOutputParser(
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
if isinstance(result, AIMessage):
|
||||
return self.parse(result.text or "")
|
||||
return self.parse(result[0].text)
|
||||
|
||||
@abstractmethod
|
||||
@ -262,7 +298,7 @@ class BaseOutputParser(
|
||||
"""
|
||||
|
||||
async def aparse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Async parse a list of candidate model Generations into a specific format.
|
||||
|
||||
|
@ -13,6 +13,7 @@ from pydantic.v1 import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages.v1 import AIMessage
|
||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
@ -53,7 +54,9 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
return pydantic_object.schema()
|
||||
return None
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@ -70,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
"""
|
||||
text = result[0].text
|
||||
text = result.text or "" if isinstance(result, AIMessage) else result[0].text
|
||||
text = text.strip()
|
||||
if partial:
|
||||
try:
|
||||
|
@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.v1 import AIMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -71,7 +72,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
|
||||
@override
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage]]
|
||||
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> Iterator[list[str]]:
|
||||
buffer = ""
|
||||
for chunk in input:
|
||||
@ -81,6 +82,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
if not isinstance(chunk_content, str):
|
||||
continue
|
||||
buffer += chunk_content
|
||||
elif isinstance(chunk, AIMessage):
|
||||
buffer += chunk.text or ""
|
||||
else:
|
||||
# add current chunk to buffer
|
||||
buffer += chunk
|
||||
@ -105,7 +108,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
|
||||
@override
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> AsyncIterator[list[str]]:
|
||||
buffer = ""
|
||||
async for chunk in input:
|
||||
@ -115,6 +118,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
if not isinstance(chunk_content, str):
|
||||
continue
|
||||
buffer += chunk_content
|
||||
elif isinstance(chunk, AIMessage):
|
||||
buffer += chunk.text or ""
|
||||
else:
|
||||
# add current chunk to buffer
|
||||
buffer += chunk
|
||||
|
@ -11,6 +11,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages.v1 import AIMessage
|
||||
from langchain_core.output_parsers import (
|
||||
BaseCumulativeTransformOutputParser,
|
||||
BaseGenerationOutputParser,
|
||||
@ -26,7 +27,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
"""Whether to only return the arguments to the function call."""
|
||||
|
||||
@override
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@ -39,6 +42,12 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
"""
|
||||
if isinstance(result, AIMessage):
|
||||
msg = (
|
||||
"This output parser does not support v1 AIMessages. Use "
|
||||
"JsonOutputToolsParser instead."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
@ -77,7 +86,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@ -90,6 +101,12 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
"""
|
||||
if isinstance(result, AIMessage):
|
||||
msg = (
|
||||
"This output parser does not support v1 AIMessages. Use "
|
||||
"JsonOutputToolsParser instead."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
if len(result) != 1:
|
||||
msg = f"Expected exactly one result, but got {len(result)}"
|
||||
raise OutputParserException(msg)
|
||||
@ -160,7 +177,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
key_name: str
|
||||
"""The name of the key to return."""
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@ -254,7 +273,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
return values
|
||||
|
||||
@override
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@ -294,7 +315,9 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
||||
"""The name of the attribute to return."""
|
||||
|
||||
@override
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
|
@ -4,7 +4,7 @@ import copy
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Annotated, Any, Optional
|
||||
from typing import Annotated, Any, Optional, Union
|
||||
|
||||
from pydantic import SkipValidation, ValidationError
|
||||
|
||||
@ -12,6 +12,7 @@ from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
from langchain_core.messages.tool import invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
@ -156,7 +157,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
If no tool calls are found, None will be returned.
|
||||
"""
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a list of tool calls.
|
||||
|
||||
Args:
|
||||
@ -173,31 +176,45 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
"""
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
tool_calls = [dict(tc) for tc in message.tool_calls]
|
||||
if isinstance(result, list):
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = (
|
||||
"This output parser can only be used with a chat generation or "
|
||||
"v1 AIMessage."
|
||||
)
|
||||
raise OutputParserException(msg)
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
tool_calls = [dict(tc) for tc in message.tool_calls]
|
||||
for tool_call in tool_calls:
|
||||
if not self.return_id:
|
||||
_ = tool_call.pop("id")
|
||||
else:
|
||||
try:
|
||||
raw_tool_calls = copy.deepcopy(
|
||||
message.additional_kwargs["tool_calls"]
|
||||
)
|
||||
except KeyError:
|
||||
return []
|
||||
tool_calls = parse_tool_calls(
|
||||
raw_tool_calls,
|
||||
partial=partial,
|
||||
strict=self.strict,
|
||||
return_id=self.return_id,
|
||||
)
|
||||
elif result.tool_calls:
|
||||
# v1 message
|
||||
tool_calls = [dict(tc) for tc in result.tool_calls]
|
||||
for tool_call in tool_calls:
|
||||
if not self.return_id:
|
||||
_ = tool_call.pop("id")
|
||||
else:
|
||||
try:
|
||||
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
|
||||
except KeyError:
|
||||
return []
|
||||
tool_calls = parse_tool_calls(
|
||||
raw_tool_calls,
|
||||
partial=partial,
|
||||
strict=self.strict,
|
||||
return_id=self.return_id,
|
||||
)
|
||||
return []
|
||||
|
||||
# for backwards compatibility
|
||||
for tc in tool_calls:
|
||||
tc["type"] = tc.pop("name")
|
||||
|
||||
if self.first_tool_only:
|
||||
return tool_calls[0] if tool_calls else None
|
||||
return tool_calls
|
||||
@ -220,7 +237,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
key_name: str
|
||||
"""The type of tools to return."""
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a list of tool calls.
|
||||
|
||||
Args:
|
||||
@ -234,32 +253,47 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
Returns:
|
||||
The parsed tool calls.
|
||||
"""
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
|
||||
if isinstance(result, list):
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
|
||||
for tool_call in parsed_tool_calls:
|
||||
if not self.return_id:
|
||||
_ = tool_call.pop("id")
|
||||
else:
|
||||
try:
|
||||
raw_tool_calls = copy.deepcopy(
|
||||
message.additional_kwargs["tool_calls"]
|
||||
)
|
||||
except KeyError:
|
||||
if self.first_tool_only:
|
||||
return None
|
||||
return []
|
||||
parsed_tool_calls = parse_tool_calls(
|
||||
raw_tool_calls,
|
||||
partial=partial,
|
||||
strict=self.strict,
|
||||
return_id=self.return_id,
|
||||
)
|
||||
elif result.tool_calls:
|
||||
# v1 message
|
||||
parsed_tool_calls = [dict(tc) for tc in result.tool_calls]
|
||||
for tool_call in parsed_tool_calls:
|
||||
if not self.return_id:
|
||||
_ = tool_call.pop("id")
|
||||
else:
|
||||
try:
|
||||
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
|
||||
except KeyError:
|
||||
if self.first_tool_only:
|
||||
return None
|
||||
return []
|
||||
parsed_tool_calls = parse_tool_calls(
|
||||
raw_tool_calls,
|
||||
partial=partial,
|
||||
strict=self.strict,
|
||||
return_id=self.return_id,
|
||||
)
|
||||
if self.first_tool_only:
|
||||
return None
|
||||
return []
|
||||
|
||||
# For backwards compatibility
|
||||
for tc in parsed_tool_calls:
|
||||
tc["type"] = tc.pop("name")
|
||||
|
||||
if self.first_tool_only:
|
||||
parsed_result = list(
|
||||
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
|
||||
@ -299,7 +333,9 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
||||
|
||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||
# Pydantic object fields are present.
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a list of Pydantic objects.
|
||||
|
||||
Args:
|
||||
@ -337,12 +373,19 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
||||
except (ValidationError, ValueError):
|
||||
if partial:
|
||||
continue
|
||||
has_max_tokens_stop_reason = any(
|
||||
generation.message.response_metadata.get("stop_reason")
|
||||
== "max_tokens"
|
||||
for generation in result
|
||||
if isinstance(generation, ChatGeneration)
|
||||
)
|
||||
has_max_tokens_stop_reason = False
|
||||
if isinstance(result, list):
|
||||
has_max_tokens_stop_reason = any(
|
||||
generation.message.response_metadata.get("stop_reason")
|
||||
== "max_tokens"
|
||||
for generation in result
|
||||
if isinstance(generation, ChatGeneration)
|
||||
)
|
||||
else:
|
||||
# v1 message
|
||||
has_max_tokens_stop_reason = (
|
||||
result.response_metadata.get("stop_reason") == "max_tokens"
|
||||
)
|
||||
if has_max_tokens_stop_reason:
|
||||
logger.exception(_MAX_TOKENS_ERROR)
|
||||
raise
|
||||
|
@ -1,13 +1,14 @@
|
||||
"""Output parsers using Pydantic."""
|
||||
|
||||
import json
|
||||
from typing import Annotated, Generic, Optional
|
||||
from typing import Annotated, Generic, Optional, Union
|
||||
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages.v1 import AIMessage
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.utils.pydantic import (
|
||||
@ -43,7 +44,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
return OutputParserException(msg, llm_output=json_string)
|
||||
|
||||
def parse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Optional[TBaseModel]:
|
||||
"""Parse the result of an LLM call to a pydantic object.
|
||||
|
||||
|
@ -12,6 +12,7 @@ from typing import (
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.messages.v1 import AIMessage, AIMessageChunk
|
||||
from langchain_core.output_parsers.base import BaseOutputParser, T
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
@ -32,23 +33,27 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
|
||||
def _transform(
    self,
    input: Iterator[Union[str, BaseMessage, AIMessage]],
) -> Iterator[T]:
    """Parse each incoming chunk on its own and yield the parsed value.

    Args:
        input: Stream of raw strings, ``BaseMessage`` objects, or v1
            ``AIMessage`` objects.

    Yields:
        The result of ``self.parse_result`` for each chunk.
    """
    for chunk in input:
        if isinstance(chunk, BaseMessage):
            # Legacy messages are wrapped in a single-item generation list.
            yield self.parse_result([ChatGeneration(message=chunk)])
        elif isinstance(chunk, AIMessage):
            # v1 messages are handed to parse_result directly.
            yield self.parse_result(chunk)
        else:
            yield self.parse_result([Generation(text=chunk)])
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[Union[str, BaseMessage]],
|
||||
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
|
||||
) -> AsyncIterator[T]:
|
||||
async for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
yield await run_in_executor(
|
||||
None, self.parse_result, [ChatGeneration(message=chunk)]
|
||||
)
|
||||
elif isinstance(chunk, AIMessage):
|
||||
yield await run_in_executor(None, self.parse_result, chunk)
|
||||
else:
|
||||
yield await run_in_executor(
|
||||
None, self.parse_result, [Generation(text=chunk)]
|
||||
@ -57,7 +62,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
@override
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Union[str, BaseMessage]],
|
||||
input: Iterator[Union[str, BaseMessage, AIMessage]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[T]:
|
||||
@ -78,7 +83,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
@override
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Union[str, BaseMessage]],
|
||||
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[T]:
|
||||
@ -125,23 +130,42 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> Iterator[Any]:
|
||||
prev_parsed = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
|
||||
None
|
||||
)
|
||||
for chunk in input:
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.model_dump())
|
||||
)
|
||||
elif isinstance(chunk, AIMessageChunk):
|
||||
chunk_gen = chunk
|
||||
elif isinstance(chunk, AIMessage):
|
||||
chunk_gen = AIMessageChunk(
|
||||
content=chunk.content,
|
||||
id=chunk.id,
|
||||
name=chunk.name,
|
||||
lc_version=chunk.lc_version,
|
||||
response_metadata=chunk.response_metadata,
|
||||
usage_metadata=chunk.usage_metadata,
|
||||
parsed=chunk.parsed,
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
||||
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||
|
||||
parsed = self.parse_result([acc_gen], partial=True)
|
||||
if isinstance(acc_gen, AIMessageChunk):
|
||||
parsed = self.parse_result(acc_gen, partial=True)
|
||||
else:
|
||||
parsed = self.parse_result([acc_gen], partial=True)
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
if self.diff:
|
||||
yield self._diff(prev_parsed, parsed)
|
||||
@ -151,24 +175,41 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
|
||||
@override
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> AsyncIterator[T]:
|
||||
prev_parsed = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
|
||||
None
|
||||
)
|
||||
async for chunk in input:
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.model_dump())
|
||||
)
|
||||
elif isinstance(chunk, AIMessageChunk):
|
||||
chunk_gen = chunk
|
||||
elif isinstance(chunk, AIMessage):
|
||||
chunk_gen = AIMessageChunk(
|
||||
content=chunk.content,
|
||||
id=chunk.id,
|
||||
name=chunk.name,
|
||||
lc_version=chunk.lc_version,
|
||||
response_metadata=chunk.response_metadata,
|
||||
usage_metadata=chunk.usage_metadata,
|
||||
parsed=chunk.parsed,
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
||||
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||
|
||||
parsed = await self.aparse_result([acc_gen], partial=True)
|
||||
if isinstance(acc_gen, AIMessageChunk):
|
||||
parsed = await self.aparse_result(acc_gen, partial=True)
|
||||
else:
|
||||
parsed = await self.aparse_result([acc_gen], partial=True)
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
if self.diff:
|
||||
yield await run_in_executor(None, self._diff, prev_parsed, parsed)
|
||||
|
@ -12,6 +12,8 @@ from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.utils import convert_from_v1_message
|
||||
from langchain_core.messages.v1 import AIMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
from langchain_core.runnables.utils import AddableDict
|
||||
|
||||
@ -240,21 +242,28 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
|
||||
@override
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage]]
|
||||
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> Iterator[AddableDict]:
|
||||
streaming_parser = _StreamingParser(self.parser)
|
||||
for chunk in input:
|
||||
yield from streaming_parser.parse(chunk)
|
||||
if isinstance(chunk, AIMessage):
|
||||
yield from streaming_parser.parse(convert_from_v1_message(chunk))
|
||||
else:
|
||||
yield from streaming_parser.parse(chunk)
|
||||
streaming_parser.close()
|
||||
|
||||
@override
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> AsyncIterator[AddableDict]:
|
||||
streaming_parser = _StreamingParser(self.parser)
|
||||
async for chunk in input:
|
||||
for output in streaming_parser.parse(chunk):
|
||||
yield output
|
||||
if isinstance(chunk, AIMessage):
|
||||
for output in streaming_parser.parse(convert_from_v1_message(chunk)):
|
||||
yield output
|
||||
else:
|
||||
for output in streaming_parser.parse(chunk):
|
||||
yield output
|
||||
streaming_parser.close()
|
||||
|
||||
def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
|
||||
|
@ -1,10 +1,14 @@
|
||||
"""Module to test base parser implementations."""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModelV1
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.output_parsers import (
|
||||
BaseGenerationOutputParser,
|
||||
BaseTransformOutputParser,
|
||||
@ -20,7 +24,7 @@ def test_base_generation_parser() -> None:
|
||||
|
||||
@override
|
||||
def parse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||
) -> str:
|
||||
"""Parse a list of model Generations into a specific format.
|
||||
|
||||
@ -32,16 +36,22 @@ def test_base_generation_parser() -> None:
|
||||
partial: Whether to allow partial results. This is used for parsers
|
||||
that support streaming
|
||||
"""
|
||||
if len(result) != 1:
|
||||
msg = "This output parser can only be used with a single generation."
|
||||
raise NotImplementedError(msg)
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
# Say that this one only works with chat generations
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
if isinstance(result, AIMessageV1):
|
||||
content = result.text or ""
|
||||
else:
|
||||
if len(result) != 1:
|
||||
msg = (
|
||||
"This output parser can only be used with a single generation."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
# Say that this one only works with chat generations
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
assert isinstance(generation.message.content, str)
|
||||
content = generation.message.content
|
||||
|
||||
content = generation.message.content
|
||||
assert isinstance(content, str)
|
||||
return content.swapcase()
|
||||
|
||||
@ -49,6 +59,10 @@ def test_base_generation_parser() -> None:
|
||||
chain = model | StrInvertCase()
|
||||
assert chain.invoke("") == "HeLLO"
|
||||
|
||||
model_v1 = GenericFakeChatModelV1(messages=iter([AIMessageV1("hEllo")]))
|
||||
chain_v1 = model_v1 | StrInvertCase()
|
||||
assert chain_v1.invoke("") == "HeLLO"
|
||||
|
||||
|
||||
def test_base_transform_output_parser() -> None:
|
||||
"""Test base transform output parser."""
|
||||
@ -62,7 +76,7 @@ def test_base_transform_output_parser() -> None:
|
||||
|
||||
@override
|
||||
def parse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||
) -> str:
|
||||
"""Parse a list of model Generations into a specific format.
|
||||
|
||||
@ -74,15 +88,22 @@ def test_base_transform_output_parser() -> None:
|
||||
partial: Whether to allow partial results. This is used for parsers
|
||||
that support streaming
|
||||
"""
|
||||
if len(result) != 1:
|
||||
msg = "This output parser can only be used with a single generation."
|
||||
raise NotImplementedError(msg)
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
# Say that this one only works with chat generations
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
content = generation.message.content
|
||||
if isinstance(result, AIMessageV1):
|
||||
content = result.text or ""
|
||||
else:
|
||||
if len(result) != 1:
|
||||
msg = (
|
||||
"This output parser can only be used with a single generation."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
# Say that this one only works with chat generations
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
assert isinstance(generation.message.content, str)
|
||||
content = generation.message.content
|
||||
|
||||
assert isinstance(content, str)
|
||||
return content.swapcase()
|
||||
|
||||
@ -91,3 +112,8 @@ def test_base_transform_output_parser() -> None:
|
||||
# inputs to models are ignored, response is hard-coded in model definition
|
||||
chunks = list(chain.stream(""))
|
||||
assert chunks == ["HELLO", " ", "WORLD"]
|
||||
|
||||
model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
|
||||
chain_v1 = model_v1 | StrInvertCase()
|
||||
chunks = list(chain_v1.stream(""))
|
||||
assert chunks == ["HELLO", " ", "WORLD"]
|
||||
|
@ -10,6 +10,8 @@ from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
JsonOutputToolsParser,
|
||||
@ -331,6 +333,14 @@ for message in STREAMED_MESSAGES:
|
||||
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)
|
||||
|
||||
|
||||
STREAMED_MESSAGES_V1 = [
|
||||
AIMessageChunkV1(
|
||||
content=[],
|
||||
tool_call_chunks=chunk.tool_call_chunks,
|
||||
)
|
||||
for chunk in STREAMED_MESSAGES_WITH_TOOL_CALLS
|
||||
]
|
||||
|
||||
EXPECTED_STREAMED_JSON = [
|
||||
{},
|
||||
{"names": ["suz"]},
|
||||
@ -398,6 +408,19 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_partial_json_output_parser_v1() -> None:
|
||||
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||
yield from STREAMED_MESSAGES_V1
|
||||
|
||||
chain = input_iter | JsonOutputToolsParser()
|
||||
|
||||
actual = list(chain.stream(None))
|
||||
expected: list = [[]] + [
|
||||
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
|
||||
]
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None:
|
||||
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
|
||||
@ -410,6 +433,20 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None
|
||||
assert actual == expected
|
||||
|
||||
|
||||
async def test_partial_json_output_parser_async_v1() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
|
||||
for msg in STREAMED_MESSAGES_V1:
|
||||
yield msg
|
||||
|
||||
chain = input_iter | JsonOutputToolsParser()
|
||||
|
||||
actual = [p async for p in chain.astream(None)]
|
||||
expected: list = [[]] + [
|
||||
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
|
||||
]
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
|
||||
input_iter = _get_iter(use_tool_calls=use_tool_calls)
|
||||
@ -429,6 +466,26 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_partial_json_output_parser_return_id_v1() -> None:
|
||||
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||
yield from STREAMED_MESSAGES_V1
|
||||
|
||||
chain = input_iter | JsonOutputToolsParser(return_id=True)
|
||||
|
||||
actual = list(chain.stream(None))
|
||||
expected: list = [[]] + [
|
||||
[
|
||||
{
|
||||
"type": "NameCollector",
|
||||
"args": chunk,
|
||||
"id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
|
||||
}
|
||||
]
|
||||
for chunk in EXPECTED_STREAMED_JSON
|
||||
]
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
|
||||
input_iter = _get_iter(use_tool_calls=use_tool_calls)
|
||||
@ -439,6 +496,17 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_partial_json_output_key_parser_v1() -> None:
|
||||
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||
yield from STREAMED_MESSAGES_V1
|
||||
|
||||
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
|
||||
|
||||
actual = list(chain.stream(None))
|
||||
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> None:
|
||||
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
|
||||
@ -450,6 +518,18 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) ->
|
||||
assert actual == expected
|
||||
|
||||
|
||||
async def test_partial_json_output_parser_key_async_v1() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
|
||||
for msg in STREAMED_MESSAGES_V1:
|
||||
yield msg
|
||||
|
||||
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
|
||||
|
||||
actual = [p async for p in chain.astream(None)]
|
||||
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> None:
|
||||
input_iter = _get_iter(use_tool_calls=use_tool_calls)
|
||||
@ -461,6 +541,17 @@ def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> N
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
def test_partial_json_output_key_parser_first_only_v1() -> None:
|
||||
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||
yield from STREAMED_MESSAGES_V1
|
||||
|
||||
chain = input_iter | JsonOutputKeyToolsParser(
|
||||
key_name="NameCollector", first_tool_only=True
|
||||
)
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
async def test_partial_json_output_parser_key_async_first_only(
|
||||
*,
|
||||
@ -475,6 +566,18 @@ async def test_partial_json_output_parser_key_async_first_only(
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
async def test_partial_json_output_parser_key_async_first_only_v1() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
|
||||
for msg in STREAMED_MESSAGES_V1:
|
||||
yield msg
|
||||
|
||||
chain = input_iter | JsonOutputKeyToolsParser(
|
||||
key_name="NameCollector", first_tool_only=True
|
||||
)
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_multiple_tools_first_only(
|
||||
*, use_tool_calls: bool
|
||||
@ -531,6 +634,42 @@ def test_json_output_key_tools_parser_multiple_tools_first_only(
|
||||
assert output_no_id == {"a": 1}
|
||||
|
||||
|
||||
def test_json_output_key_tools_parser_multiple_tools_first_only_v1() -> None:
|
||||
message = AIMessageV1(
|
||||
content=[],
|
||||
tool_calls=[
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_other",
|
||||
"name": "other",
|
||||
"args": {"b": 2},
|
||||
},
|
||||
{"type": "tool_call", "id": "call_func", "name": "func", "args": {"a": 1}},
|
||||
],
|
||||
)
|
||||
|
||||
# Test with return_id=True
|
||||
parser = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=True
|
||||
)
|
||||
output = parser.parse_result(message)
|
||||
|
||||
# Should return the func tool call, not None
|
||||
assert output is not None
|
||||
assert output["type"] == "func"
|
||||
assert output["args"] == {"a": 1}
|
||||
assert "id" in output
|
||||
|
||||
# Test with return_id=False
|
||||
parser_no_id = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=False
|
||||
)
|
||||
output_no_id = parser_no_id.parse_result(message)
|
||||
|
||||
# Should return just the args
|
||||
assert output_no_id == {"a": 1}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_multiple_tools_no_match(
|
||||
*, use_tool_calls: bool
|
||||
@ -583,6 +722,44 @@ def test_json_output_key_tools_parser_multiple_tools_no_match(
|
||||
assert output_no_id is None
|
||||
|
||||
|
||||
def test_json_output_key_tools_parser_multiple_tools_no_match_v1() -> None:
|
||||
message = AIMessageV1(
|
||||
content=[],
|
||||
tool_calls=[
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_other",
|
||||
"name": "other",
|
||||
"args": {"b": 2},
|
||||
},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_another",
|
||||
"name": "another",
|
||||
"args": {"c": 3},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Test with return_id=True, first_tool_only=True
|
||||
parser = JsonOutputKeyToolsParser(
|
||||
key_name="nonexistent", first_tool_only=True, return_id=True
|
||||
)
|
||||
output = parser.parse_result(message)
|
||||
|
||||
# Should return None when no matches
|
||||
assert output is None
|
||||
|
||||
# Test with return_id=False, first_tool_only=True
|
||||
parser_no_id = JsonOutputKeyToolsParser(
|
||||
key_name="nonexistent", first_tool_only=True, return_id=False
|
||||
)
|
||||
output_no_id = parser_no_id.parse_result(message)
|
||||
|
||||
# Should return None when no matches
|
||||
assert output_no_id is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_multiple_matching_tools(
|
||||
*, use_tool_calls: bool
|
||||
@ -643,6 +820,42 @@ def test_json_output_key_tools_parser_multiple_matching_tools(
|
||||
assert output_all[1]["args"] == {"a": 3}
|
||||
|
||||
|
||||
def test_json_output_key_tools_parser_multiple_matching_tools_v1() -> None:
|
||||
message = AIMessageV1(
|
||||
content=[],
|
||||
tool_calls=[
|
||||
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_other",
|
||||
"name": "other",
|
||||
"args": {"b": 2},
|
||||
},
|
||||
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
|
||||
],
|
||||
)
|
||||
|
||||
# Test with first_tool_only=True - should return first matching
|
||||
parser = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=True
|
||||
)
|
||||
output = parser.parse_result(message)
|
||||
|
||||
assert output is not None
|
||||
assert output["type"] == "func"
|
||||
assert output["args"] == {"a": 1} # First matching tool call
|
||||
|
||||
# Test with first_tool_only=False - should return all matching
|
||||
parser_all = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=False, return_id=True
|
||||
)
|
||||
output_all = parser_all.parse_result(message)
|
||||
|
||||
assert len(output_all) == 2
|
||||
assert output_all[0]["args"] == {"a": 1}
|
||||
assert output_all[1]["args"] == {"a": 3}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
|
||||
def create_message() -> AIMessage:
|
||||
@ -671,6 +884,35 @@ def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) ->
|
||||
assert output_all == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"empty_message",
|
||||
[
|
||||
AIMessageV1(content=[], tool_calls=[]),
|
||||
AIMessageV1(content="", tool_calls=[]),
|
||||
],
|
||||
)
|
||||
def test_json_output_key_tools_parser_empty_results_v1(
|
||||
empty_message: AIMessageV1,
|
||||
) -> None:
|
||||
# Test with first_tool_only=True
|
||||
parser = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=True
|
||||
)
|
||||
output = parser.parse_result(empty_message)
|
||||
|
||||
# Should return None for empty results
|
||||
assert output is None
|
||||
|
||||
# Test with first_tool_only=False
|
||||
parser_all = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=False, return_id=True
|
||||
)
|
||||
output_all = parser_all.parse_result(empty_message)
|
||||
|
||||
# Should return empty list for empty results
|
||||
assert output_all == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_parameter_combinations(
|
||||
*, use_tool_calls: bool
|
||||
@ -746,6 +988,56 @@ def test_json_output_key_tools_parser_parameter_combinations(
|
||||
assert output4 == [{"a": 1}, {"a": 3}]
|
||||
|
||||
|
||||
def test_json_output_key_tools_parser_parameter_combinations_v1() -> None:
|
||||
"""Test all parameter combinations of JsonOutputKeyToolsParser."""
|
||||
result = AIMessageV1(
|
||||
content=[],
|
||||
tool_calls=[
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_other",
|
||||
"name": "other",
|
||||
"args": {"b": 2},
|
||||
},
|
||||
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
|
||||
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
|
||||
],
|
||||
)
|
||||
|
||||
# Test: first_tool_only=True, return_id=True
|
||||
parser1 = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=True
|
||||
)
|
||||
output1 = parser1.parse_result(result)
|
||||
assert output1["type"] == "func"
|
||||
assert output1["args"] == {"a": 1}
|
||||
assert "id" in output1
|
||||
|
||||
# Test: first_tool_only=True, return_id=False
|
||||
parser2 = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=False
|
||||
)
|
||||
output2 = parser2.parse_result(result)
|
||||
assert output2 == {"a": 1}
|
||||
|
||||
# Test: first_tool_only=False, return_id=True
|
||||
parser3 = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=False, return_id=True
|
||||
)
|
||||
output3 = parser3.parse_result(result)
|
||||
assert len(output3) == 2
|
||||
assert all("id" in item for item in output3)
|
||||
assert output3[0]["args"] == {"a": 1}
|
||||
assert output3[1]["args"] == {"a": 3}
|
||||
|
||||
# Test: first_tool_only=False, return_id=False
|
||||
parser4 = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=False, return_id=False
|
||||
)
|
||||
output4 = parser4.parse_result(result)
|
||||
assert output4 == [{"a": 1}, {"a": 3}]
|
||||
|
||||
|
||||
class Person(BaseModel):
|
||||
age: int
|
||||
hair_color: str
|
||||
@ -788,6 +1080,18 @@ def test_partial_pydantic_output_parser() -> None:
|
||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||
|
||||
|
||||
def test_partial_pydantic_output_parser_v1() -> None:
|
||||
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||
yield from STREAMED_MESSAGES_V1
|
||||
|
||||
chain = input_iter | PydanticToolsParser(
|
||||
tools=[NameCollector], first_tool_only=True
|
||||
)
|
||||
|
||||
actual = list(chain.stream(None))
|
||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||
|
||||
|
||||
async def test_partial_pydantic_output_parser_async() -> None:
|
||||
for use_tool_calls in [False, True]:
|
||||
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
|
||||
@ -800,6 +1104,19 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||
|
||||
|
||||
async def test_partial_pydantic_output_parser_async_v1() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
|
||||
for msg in STREAMED_MESSAGES_V1:
|
||||
yield msg
|
||||
|
||||
chain = input_iter | PydanticToolsParser(
|
||||
tools=[NameCollector], first_tool_only=True
|
||||
)
|
||||
|
||||
actual = [p async for p in chain.astream(None)]
|
||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||
|
||||
|
||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||
import pydantic
|
||||
@ -870,20 +1187,22 @@ def test_parse_with_different_pydantic_2_proper() -> None:
|
||||
|
||||
def test_max_tokens_error(caplog: Any) -> None:
|
||||
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "NameCollector",
|
||||
"args": {"names": ["suz", "jerm"]},
|
||||
}
|
||||
],
|
||||
response_metadata={"stop_reason": "max_tokens"},
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
_ = parser.invoke(message)
|
||||
assert any(
|
||||
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
|
||||
for record, msg in zip(caplog.records, caplog.messages)
|
||||
)
|
||||
for msg_class in [AIMessage, AIMessageV1]:
|
||||
message = msg_class(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "NameCollector",
|
||||
"args": {"names": ["suz", "jerm"]},
|
||||
}
|
||||
],
|
||||
response_metadata={"stop_reason": "max_tokens"},
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
_ = parser.invoke(message)
|
||||
assert any(
|
||||
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
|
||||
for record, msg in zip(caplog.records, caplog.messages)
|
||||
)
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -8,6 +8,8 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import convert_from_v1_message
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from typing_extensions import override
|
||||
|
||||
@ -83,10 +85,12 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
|
||||
@override
|
||||
def parse_result(
|
||||
self,
|
||||
result: list[Generation],
|
||||
result: Union[list[Generation], AIMessageV1],
|
||||
*,
|
||||
partial: bool = False,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
if isinstance(result, AIMessageV1):
|
||||
result = [ChatGeneration(message=convert_from_v1_message(result))]
|
||||
if not isinstance(result[0], ChatGeneration):
|
||||
msg = "This output parser only works on ChatGeneration output"
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
@ -2,6 +2,8 @@ from typing import Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.utils import convert_from_v1_message
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from typing_extensions import override
|
||||
|
||||
@ -57,10 +59,12 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
|
||||
@override
|
||||
def parse_result(
|
||||
self,
|
||||
result: list[Generation],
|
||||
result: Union[list[Generation], AIMessageV1],
|
||||
*,
|
||||
partial: bool = False,
|
||||
) -> Union[list[AgentAction], AgentFinish]:
|
||||
if isinstance(result, AIMessageV1):
|
||||
result = [ChatGeneration(message=convert_from_v1_message(result))]
|
||||
if not isinstance(result[0], ChatGeneration):
|
||||
msg = "This output parser only works on ChatGeneration output"
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
@ -9,6 +9,8 @@ from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
ToolCall,
|
||||
)
|
||||
from langchain_core.messages.utils import convert_from_v1_message
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from typing_extensions import override
|
||||
|
||||
@ -101,10 +103,12 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
|
||||
@override
|
||||
def parse_result(
|
||||
self,
|
||||
result: list[Generation],
|
||||
result: Union[list[Generation], AIMessageV1],
|
||||
*,
|
||||
partial: bool = False,
|
||||
) -> Union[list[AgentAction], AgentFinish]:
|
||||
if isinstance(result, AIMessageV1):
|
||||
result = [ChatGeneration(message=convert_from_v1_message(result))]
|
||||
if not isinstance(result[0], ChatGeneration):
|
||||
msg = "This output parser only works on ChatGeneration output"
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
Loading…
Reference in New Issue
Block a user