mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-16 16:11:02 +00:00
Merge branch 'standard_outputs_copy' into mdrxy/ollama_v1
This commit is contained in:
commit
588fe46601
@ -3,7 +3,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncIterator, Iterator
|
from collections.abc import AsyncIterator, Iterable, Iterator
|
||||||
from typing import Any, Optional, Union, cast
|
from typing import Any, Optional, Union, cast
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
@ -13,7 +13,11 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
||||||
|
from langchain_core.language_models.v1.chat_models import BaseChatModelV1
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
|
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||||
|
from langchain_core.messages.v1 import MessageV1
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
@ -367,3 +371,69 @@ class ParrotFakeChatModel(BaseChatModel):
|
|||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "parrot-fake-chat-model"
|
return "parrot-fake-chat-model"
|
||||||
|
|
||||||
|
|
||||||
|
class GenericFakeChatModelV1(BaseChatModelV1):
|
||||||
|
"""Generic fake chat model that can be used to test the chat model interface."""
|
||||||
|
|
||||||
|
messages: Optional[Iterator[Union[AIMessageV1, str]]] = None
|
||||||
|
message_chunks: Optional[Iterable[Union[AIMessageChunkV1, str]]] = None
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AIMessageV1:
|
||||||
|
"""Top Level call."""
|
||||||
|
if self.messages is None:
|
||||||
|
error_msg = "Messages iterator is not set."
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
message = next(self.messages)
|
||||||
|
return AIMessageV1(content=message) if isinstance(message, str) else message
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[AIMessageChunkV1]:
|
||||||
|
"""Top Level call."""
|
||||||
|
if self.message_chunks is None:
|
||||||
|
error_msg = "Message chunks iterator is not set."
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
for chunk in self.message_chunks:
|
||||||
|
if isinstance(chunk, str):
|
||||||
|
yield AIMessageChunkV1(chunk)
|
||||||
|
else:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "generic-fake-chat-model"
|
||||||
|
|
||||||
|
|
||||||
|
class ParrotFakeChatModelV1(BaseChatModelV1):
|
||||||
|
"""Generic fake chat model that can be used to test the chat model interface.
|
||||||
|
|
||||||
|
* Chat model should be usable in both sync and async tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AIMessageV1:
|
||||||
|
"""Top Level call."""
|
||||||
|
if isinstance(messages[-1], AIMessageV1):
|
||||||
|
return messages[-1]
|
||||||
|
return AIMessageV1(content=messages[-1].content)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "parrot-fake-chat-model"
|
||||||
|
@ -3,9 +3,10 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args
|
from typing import Any, Literal, Optional, Union, cast, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
import langchain_core.messages.content_blocks as types
|
import langchain_core.messages.content_blocks as types
|
||||||
from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
|
from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
|
||||||
@ -329,7 +330,11 @@ class AIMessageChunk(AIMessage):
|
|||||||
@property
|
@property
|
||||||
def reasoning(self) -> Optional[str]:
|
def reasoning(self) -> Optional[str]:
|
||||||
"""Extract all reasoning text from the AI message as a string."""
|
"""Extract all reasoning text from the AI message as a string."""
|
||||||
text_blocks = [block for block in self.content if block["type"] == "reasoning"]
|
text_blocks = [
|
||||||
|
block
|
||||||
|
for block in self.content
|
||||||
|
if block["type"] == "reasoning" and "reasoning" in block
|
||||||
|
]
|
||||||
if text_blocks:
|
if text_blocks:
|
||||||
return "".join(block["reasoning"] for block in text_blocks)
|
return "".join(block["reasoning"] for block in text_blocks)
|
||||||
return None
|
return None
|
||||||
|
@ -11,12 +11,14 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.language_models import LanguageModelOutput
|
from langchain_core.language_models import LanguageModelOutput
|
||||||
from langchain_core.messages import AnyMessage, BaseMessage
|
from langchain_core.messages import AnyMessage, BaseMessage
|
||||||
|
from langchain_core.messages.v1 import AIMessage, MessageV1, MessageV1Types
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
||||||
from langchain_core.runnables.config import run_in_executor
|
from langchain_core.runnables.config import run_in_executor
|
||||||
@ -25,14 +27,16 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
OutputParserLike = Runnable[LanguageModelOutput, T]
|
OutputParserLike = Runnable[Union[LanguageModelOutput, AIMessage], T]
|
||||||
|
|
||||||
|
|
||||||
class BaseLLMOutputParser(ABC, Generic[T]):
|
class BaseLLMOutputParser(ABC, Generic[T]):
|
||||||
"""Abstract base class for parsing the outputs of a model."""
|
"""Abstract base class for parsing the outputs of a model."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
|
) -> T:
|
||||||
"""Parse a list of candidate model Generations into a specific format.
|
"""Parse a list of candidate model Generations into a specific format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -46,7 +50,7 @@ class BaseLLMOutputParser(ABC, Generic[T]):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def aparse_result(
|
async def aparse_result(
|
||||||
self, result: list[Generation], *, partial: bool = False
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Async parse a list of candidate model Generations into a specific format.
|
"""Async parse a list of candidate model Generations into a specific format.
|
||||||
|
|
||||||
@ -71,7 +75,7 @@ class BaseGenerationOutputParser(
|
|||||||
@override
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
"""Return the input type for the parser."""
|
"""Return the input type for the parser."""
|
||||||
return Union[str, AnyMessage]
|
return Union[str, AnyMessage, MessageV1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@override
|
@override
|
||||||
@ -84,7 +88,7 @@ class BaseGenerationOutputParser(
|
|||||||
@override
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, BaseMessage],
|
input: Union[str, BaseMessage, MessageV1],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> T:
|
) -> T:
|
||||||
@ -97,9 +101,16 @@ class BaseGenerationOutputParser(
|
|||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
if isinstance(input, MessageV1Types):
|
||||||
|
return self._call_with_config(
|
||||||
|
lambda inner_input: self.parse_result(inner_input),
|
||||||
|
input,
|
||||||
|
config,
|
||||||
|
run_type="parser",
|
||||||
|
)
|
||||||
return self._call_with_config(
|
return self._call_with_config(
|
||||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||||
input,
|
cast("str", input),
|
||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
@ -120,6 +131,13 @@ class BaseGenerationOutputParser(
|
|||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
if isinstance(input, MessageV1Types):
|
||||||
|
return await self._acall_with_config(
|
||||||
|
lambda inner_input: self.aparse_result(inner_input),
|
||||||
|
input,
|
||||||
|
config,
|
||||||
|
run_type="parser",
|
||||||
|
)
|
||||||
return await self._acall_with_config(
|
return await self._acall_with_config(
|
||||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||||
input,
|
input,
|
||||||
@ -129,7 +147,7 @@ class BaseGenerationOutputParser(
|
|||||||
|
|
||||||
|
|
||||||
class BaseOutputParser(
|
class BaseOutputParser(
|
||||||
BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
|
BaseLLMOutputParser, RunnableSerializable[Union[LanguageModelOutput, AIMessage], T]
|
||||||
):
|
):
|
||||||
"""Base class to parse the output of an LLM call.
|
"""Base class to parse the output of an LLM call.
|
||||||
|
|
||||||
@ -162,7 +180,7 @@ class BaseOutputParser(
|
|||||||
@override
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
"""Return the input type for the parser."""
|
"""Return the input type for the parser."""
|
||||||
return Union[str, AnyMessage]
|
return Union[str, AnyMessage, MessageV1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@override
|
@override
|
||||||
@ -189,7 +207,7 @@ class BaseOutputParser(
|
|||||||
@override
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, BaseMessage],
|
input: Union[str, BaseMessage, MessageV1],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> T:
|
) -> T:
|
||||||
@ -202,9 +220,16 @@ class BaseOutputParser(
|
|||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
if isinstance(input, MessageV1Types):
|
||||||
|
return self._call_with_config(
|
||||||
|
lambda inner_input: self.parse_result(inner_input),
|
||||||
|
input,
|
||||||
|
config,
|
||||||
|
run_type="parser",
|
||||||
|
)
|
||||||
return self._call_with_config(
|
return self._call_with_config(
|
||||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||||
input,
|
cast("str", input),
|
||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
@ -212,7 +237,7 @@ class BaseOutputParser(
|
|||||||
@override
|
@override
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, BaseMessage],
|
input: Union[str, BaseMessage, MessageV1],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> T:
|
) -> T:
|
||||||
@ -225,15 +250,24 @@ class BaseOutputParser(
|
|||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
if isinstance(input, MessageV1Types):
|
||||||
|
return await self._acall_with_config(
|
||||||
|
lambda inner_input: self.aparse_result(inner_input),
|
||||||
|
input,
|
||||||
|
config,
|
||||||
|
run_type="parser",
|
||||||
|
)
|
||||||
return await self._acall_with_config(
|
return await self._acall_with_config(
|
||||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||||
input,
|
cast("str", input),
|
||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
|
) -> T:
|
||||||
"""Parse a list of candidate model Generations into a specific format.
|
"""Parse a list of candidate model Generations into a specific format.
|
||||||
|
|
||||||
The return value is parsed from only the first Generation in the result, which
|
The return value is parsed from only the first Generation in the result, which
|
||||||
@ -248,6 +282,8 @@ class BaseOutputParser(
|
|||||||
Returns:
|
Returns:
|
||||||
Structured output.
|
Structured output.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(result, AIMessage):
|
||||||
|
return self.parse(result.text or "")
|
||||||
return self.parse(result[0].text)
|
return self.parse(result[0].text)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -262,7 +298,7 @@ class BaseOutputParser(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def aparse_result(
|
async def aparse_result(
|
||||||
self, result: list[Generation], *, partial: bool = False
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Async parse a list of candidate model Generations into a specific format.
|
"""Async parse a list of candidate model Generations into a specific format.
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ from pydantic.v1 import BaseModel
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.messages.v1 import AIMessage
|
||||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||||
from langchain_core.outputs import Generation
|
from langchain_core.outputs import Generation
|
||||||
@ -53,7 +54,9 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
return pydantic_object.schema()
|
return pydantic_object.schema()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
|
) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -70,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
Raises:
|
Raises:
|
||||||
OutputParserException: If the output is not valid JSON.
|
OutputParserException: If the output is not valid JSON.
|
||||||
"""
|
"""
|
||||||
text = result[0].text
|
text = result.text or "" if isinstance(result, AIMessage) else result[0].text
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
if partial:
|
if partial:
|
||||||
try:
|
try:
|
||||||
|
@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, TypeVar, Union
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.messages.v1 import AIMessage
|
||||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -71,7 +72,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def _transform(
|
def _transform(
|
||||||
self, input: Iterator[Union[str, BaseMessage]]
|
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
||||||
) -> Iterator[list[str]]:
|
) -> Iterator[list[str]]:
|
||||||
buffer = ""
|
buffer = ""
|
||||||
for chunk in input:
|
for chunk in input:
|
||||||
@ -81,6 +82,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|||||||
if not isinstance(chunk_content, str):
|
if not isinstance(chunk_content, str):
|
||||||
continue
|
continue
|
||||||
buffer += chunk_content
|
buffer += chunk_content
|
||||||
|
elif isinstance(chunk, AIMessage):
|
||||||
|
buffer += chunk.text or ""
|
||||||
else:
|
else:
|
||||||
# add current chunk to buffer
|
# add current chunk to buffer
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
@ -105,7 +108,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
||||||
) -> AsyncIterator[list[str]]:
|
) -> AsyncIterator[list[str]]:
|
||||||
buffer = ""
|
buffer = ""
|
||||||
async for chunk in input:
|
async for chunk in input:
|
||||||
@ -115,6 +118,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|||||||
if not isinstance(chunk_content, str):
|
if not isinstance(chunk_content, str):
|
||||||
continue
|
continue
|
||||||
buffer += chunk_content
|
buffer += chunk_content
|
||||||
|
elif isinstance(chunk, AIMessage):
|
||||||
|
buffer += chunk.text or ""
|
||||||
else:
|
else:
|
||||||
# add current chunk to buffer
|
# add current chunk to buffer
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
|
@ -11,6 +11,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.messages.v1 import AIMessage
|
||||||
from langchain_core.output_parsers import (
|
from langchain_core.output_parsers import (
|
||||||
BaseCumulativeTransformOutputParser,
|
BaseCumulativeTransformOutputParser,
|
||||||
BaseGenerationOutputParser,
|
BaseGenerationOutputParser,
|
||||||
@ -26,7 +27,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|||||||
"""Whether to only return the arguments to the function call."""
|
"""Whether to only return the arguments to the function call."""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
|
) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -39,6 +42,12 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|||||||
Raises:
|
Raises:
|
||||||
OutputParserException: If the output is not valid JSON.
|
OutputParserException: If the output is not valid JSON.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(result, AIMessage):
|
||||||
|
msg = (
|
||||||
|
"This output parser does not support v1 AIMessages. Use "
|
||||||
|
"JsonOutputToolsParser instead."
|
||||||
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
generation = result[0]
|
generation = result[0]
|
||||||
if not isinstance(generation, ChatGeneration):
|
if not isinstance(generation, ChatGeneration):
|
||||||
msg = "This output parser can only be used with a chat generation."
|
msg = "This output parser can only be used with a chat generation."
|
||||||
@ -77,7 +86,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||||
return jsonpatch.make_patch(prev, next).patch
|
return jsonpatch.make_patch(prev, next).patch
|
||||||
|
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
|
) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -90,6 +101,12 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
Raises:
|
Raises:
|
||||||
OutputParserException: If the output is not valid JSON.
|
OutputParserException: If the output is not valid JSON.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(result, AIMessage):
|
||||||
|
msg = (
|
||||||
|
"This output parser does not support v1 AIMessages. Use "
|
||||||
|
"JsonOutputToolsParser instead."
|
||||||
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
if len(result) != 1:
|
if len(result) != 1:
|
||||||
msg = f"Expected exactly one result, but got {len(result)}"
|
msg = f"Expected exactly one result, but got {len(result)}"
|
||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
@ -160,7 +177,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
|||||||
key_name: str
|
key_name: str
|
||||||
"""The name of the key to return."""
|
"""The name of the key to return."""
|
||||||
|
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
|
) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -254,7 +273,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
|
) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -294,7 +315,9 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
|||||||
"""The name of the attribute to return."""
|
"""The name of the attribute to return."""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
|
) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -4,7 +4,7 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Annotated, Any, Optional
|
from typing import Annotated, Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import SkipValidation, ValidationError
|
from pydantic import SkipValidation, ValidationError
|
||||||
|
|
||||||
@ -12,6 +12,7 @@ from langchain_core.exceptions import OutputParserException
|
|||||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||||
from langchain_core.messages.tool import invalid_tool_call
|
from langchain_core.messages.tool import invalid_tool_call
|
||||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from langchain_core.utils.json import parse_partial_json
|
from langchain_core.utils.json import parse_partial_json
|
||||||
@ -156,7 +157,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
If no tool calls are found, None will be returned.
|
If no tool calls are found, None will be returned.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||||
|
) -> Any:
|
||||||
"""Parse the result of an LLM call to a list of tool calls.
|
"""Parse the result of an LLM call to a list of tool calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -173,9 +176,13 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
Raises:
|
Raises:
|
||||||
OutputParserException: If the output is not valid JSON.
|
OutputParserException: If the output is not valid JSON.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(result, list):
|
||||||
generation = result[0]
|
generation = result[0]
|
||||||
if not isinstance(generation, ChatGeneration):
|
if not isinstance(generation, ChatGeneration):
|
||||||
msg = "This output parser can only be used with a chat generation."
|
msg = (
|
||||||
|
"This output parser can only be used with a chat generation or "
|
||||||
|
"v1 AIMessage."
|
||||||
|
)
|
||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
message = generation.message
|
message = generation.message
|
||||||
if isinstance(message, AIMessage) and message.tool_calls:
|
if isinstance(message, AIMessage) and message.tool_calls:
|
||||||
@ -185,7 +192,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
_ = tool_call.pop("id")
|
_ = tool_call.pop("id")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
|
raw_tool_calls = copy.deepcopy(
|
||||||
|
message.additional_kwargs["tool_calls"]
|
||||||
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return []
|
return []
|
||||||
tool_calls = parse_tool_calls(
|
tool_calls = parse_tool_calls(
|
||||||
@ -194,10 +203,18 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
strict=self.strict,
|
strict=self.strict,
|
||||||
return_id=self.return_id,
|
return_id=self.return_id,
|
||||||
)
|
)
|
||||||
|
elif result.tool_calls:
|
||||||
|
# v1 message
|
||||||
|
tool_calls = [dict(tc) for tc in result.tool_calls]
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
if not self.return_id:
|
||||||
|
_ = tool_call.pop("id")
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
# for backwards compatibility
|
# for backwards compatibility
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
tc["type"] = tc.pop("name")
|
tc["type"] = tc.pop("name")
|
||||||
|
|
||||||
if self.first_tool_only:
|
if self.first_tool_only:
|
||||||
return tool_calls[0] if tool_calls else None
|
return tool_calls[0] if tool_calls else None
|
||||||
return tool_calls
|
return tool_calls
|
||||||
@ -220,7 +237,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|||||||
key_name: str
|
key_name: str
|
||||||
"""The type of tools to return."""
|
"""The type of tools to return."""
|
||||||
|
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||||
|
) -> Any:
|
||||||
"""Parse the result of an LLM call to a list of tool calls.
|
"""Parse the result of an LLM call to a list of tool calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -234,6 +253,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|||||||
Returns:
|
Returns:
|
||||||
The parsed tool calls.
|
The parsed tool calls.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(result, list):
|
||||||
generation = result[0]
|
generation = result[0]
|
||||||
if not isinstance(generation, ChatGeneration):
|
if not isinstance(generation, ChatGeneration):
|
||||||
msg = "This output parser can only be used with a chat generation."
|
msg = "This output parser can only be used with a chat generation."
|
||||||
@ -246,7 +266,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|||||||
_ = tool_call.pop("id")
|
_ = tool_call.pop("id")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
|
raw_tool_calls = copy.deepcopy(
|
||||||
|
message.additional_kwargs["tool_calls"]
|
||||||
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if self.first_tool_only:
|
if self.first_tool_only:
|
||||||
return None
|
return None
|
||||||
@ -257,9 +279,21 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|||||||
strict=self.strict,
|
strict=self.strict,
|
||||||
return_id=self.return_id,
|
return_id=self.return_id,
|
||||||
)
|
)
|
||||||
|
elif result.tool_calls:
|
||||||
|
# v1 message
|
||||||
|
parsed_tool_calls = [dict(tc) for tc in result.tool_calls]
|
||||||
|
for tool_call in parsed_tool_calls:
|
||||||
|
if not self.return_id:
|
||||||
|
_ = tool_call.pop("id")
|
||||||
|
else:
|
||||||
|
if self.first_tool_only:
|
||||||
|
return None
|
||||||
|
return []
|
||||||
|
|
||||||
# For backwards compatibility
|
# For backwards compatibility
|
||||||
for tc in parsed_tool_calls:
|
for tc in parsed_tool_calls:
|
||||||
tc["type"] = tc.pop("name")
|
tc["type"] = tc.pop("name")
|
||||||
|
|
||||||
if self.first_tool_only:
|
if self.first_tool_only:
|
||||||
parsed_result = list(
|
parsed_result = list(
|
||||||
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
|
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
|
||||||
@ -299,7 +333,9 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
|||||||
|
|
||||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||||
# Pydantic object fields are present.
|
# Pydantic object fields are present.
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(
|
||||||
|
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||||
|
) -> Any:
|
||||||
"""Parse the result of an LLM call to a list of Pydantic objects.
|
"""Parse the result of an LLM call to a list of Pydantic objects.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -337,12 +373,19 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
|||||||
except (ValidationError, ValueError):
|
except (ValidationError, ValueError):
|
||||||
if partial:
|
if partial:
|
||||||
continue
|
continue
|
||||||
|
has_max_tokens_stop_reason = False
|
||||||
|
if isinstance(result, list):
|
||||||
has_max_tokens_stop_reason = any(
|
has_max_tokens_stop_reason = any(
|
||||||
generation.message.response_metadata.get("stop_reason")
|
generation.message.response_metadata.get("stop_reason")
|
||||||
== "max_tokens"
|
== "max_tokens"
|
||||||
for generation in result
|
for generation in result
|
||||||
if isinstance(generation, ChatGeneration)
|
if isinstance(generation, ChatGeneration)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# v1 message
|
||||||
|
has_max_tokens_stop_reason = (
|
||||||
|
result.response_metadata.get("stop_reason") == "max_tokens"
|
||||||
|
)
|
||||||
if has_max_tokens_stop_reason:
|
if has_max_tokens_stop_reason:
|
||||||
logger.exception(_MAX_TOKENS_ERROR)
|
logger.exception(_MAX_TOKENS_ERROR)
|
||||||
raise
|
raise
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
"""Output parsers using Pydantic."""
|
"""Output parsers using Pydantic."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Annotated, Generic, Optional
|
from typing import Annotated, Generic, Optional, Union
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import SkipValidation
|
from pydantic import SkipValidation
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.messages.v1 import AIMessage
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.outputs import Generation
|
from langchain_core.outputs import Generation
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
@ -43,7 +44,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
return OutputParserException(msg, llm_output=json_string)
|
return OutputParserException(msg, llm_output=json_string)
|
||||||
|
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self, result: list[Generation], *, partial: bool = False
|
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||||
) -> Optional[TBaseModel]:
|
) -> Optional[TBaseModel]:
|
||||||
"""Parse the result of an LLM call to a pydantic object.
|
"""Parse the result of an LLM call to a pydantic object.
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ from typing import (
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||||
|
from langchain_core.messages.v1 import AIMessage, AIMessageChunk
|
||||||
from langchain_core.output_parsers.base import BaseOutputParser, T
|
from langchain_core.output_parsers.base import BaseOutputParser, T
|
||||||
from langchain_core.outputs import (
|
from langchain_core.outputs import (
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
@ -32,23 +33,27 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
|
|
||||||
def _transform(
|
def _transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Union[str, BaseMessage]],
|
input: Iterator[Union[str, BaseMessage, AIMessage]],
|
||||||
) -> Iterator[T]:
|
) -> Iterator[T]:
|
||||||
for chunk in input:
|
for chunk in input:
|
||||||
if isinstance(chunk, BaseMessage):
|
if isinstance(chunk, BaseMessage):
|
||||||
yield self.parse_result([ChatGeneration(message=chunk)])
|
yield self.parse_result([ChatGeneration(message=chunk)])
|
||||||
|
elif isinstance(chunk, AIMessage):
|
||||||
|
yield self.parse_result(chunk)
|
||||||
else:
|
else:
|
||||||
yield self.parse_result([Generation(text=chunk)])
|
yield self.parse_result([Generation(text=chunk)])
|
||||||
|
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Union[str, BaseMessage]],
|
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
async for chunk in input:
|
async for chunk in input:
|
||||||
if isinstance(chunk, BaseMessage):
|
if isinstance(chunk, BaseMessage):
|
||||||
yield await run_in_executor(
|
yield await run_in_executor(
|
||||||
None, self.parse_result, [ChatGeneration(message=chunk)]
|
None, self.parse_result, [ChatGeneration(message=chunk)]
|
||||||
)
|
)
|
||||||
|
elif isinstance(chunk, AIMessage):
|
||||||
|
yield await run_in_executor(None, self.parse_result, chunk)
|
||||||
else:
|
else:
|
||||||
yield await run_in_executor(
|
yield await run_in_executor(
|
||||||
None, self.parse_result, [Generation(text=chunk)]
|
None, self.parse_result, [Generation(text=chunk)]
|
||||||
@ -57,7 +62,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
@override
|
@override
|
||||||
def transform(
|
def transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Union[str, BaseMessage]],
|
input: Iterator[Union[str, BaseMessage, AIMessage]],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[T]:
|
) -> Iterator[T]:
|
||||||
@ -78,7 +83,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
@override
|
@override
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Union[str, BaseMessage]],
|
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
@ -125,22 +130,41 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
def _transform(
|
||||||
|
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
||||||
|
) -> Iterator[Any]:
|
||||||
prev_parsed = None
|
prev_parsed = None
|
||||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
|
||||||
|
None
|
||||||
|
)
|
||||||
for chunk in input:
|
for chunk in input:
|
||||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||||
if isinstance(chunk, BaseMessageChunk):
|
if isinstance(chunk, BaseMessageChunk):
|
||||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.model_dump())
|
message=BaseMessageChunk(**chunk.model_dump())
|
||||||
)
|
)
|
||||||
|
elif isinstance(chunk, AIMessageChunk):
|
||||||
|
chunk_gen = chunk
|
||||||
|
elif isinstance(chunk, AIMessage):
|
||||||
|
chunk_gen = AIMessageChunk(
|
||||||
|
content=chunk.content,
|
||||||
|
id=chunk.id,
|
||||||
|
name=chunk.name,
|
||||||
|
lc_version=chunk.lc_version,
|
||||||
|
response_metadata=chunk.response_metadata,
|
||||||
|
usage_metadata=chunk.usage_metadata,
|
||||||
|
parsed=chunk.parsed,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
chunk_gen = GenerationChunk(text=chunk)
|
chunk_gen = GenerationChunk(text=chunk)
|
||||||
|
|
||||||
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||||
|
|
||||||
|
if isinstance(acc_gen, AIMessageChunk):
|
||||||
|
parsed = self.parse_result(acc_gen, partial=True)
|
||||||
|
else:
|
||||||
parsed = self.parse_result([acc_gen], partial=True)
|
parsed = self.parse_result([acc_gen], partial=True)
|
||||||
if parsed is not None and parsed != prev_parsed:
|
if parsed is not None and parsed != prev_parsed:
|
||||||
if self.diff:
|
if self.diff:
|
||||||
@ -151,23 +175,40 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
prev_parsed = None
|
prev_parsed = None
|
||||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
|
||||||
|
None
|
||||||
|
)
|
||||||
async for chunk in input:
|
async for chunk in input:
|
||||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||||
if isinstance(chunk, BaseMessageChunk):
|
if isinstance(chunk, BaseMessageChunk):
|
||||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.model_dump())
|
message=BaseMessageChunk(**chunk.model_dump())
|
||||||
)
|
)
|
||||||
|
elif isinstance(chunk, AIMessageChunk):
|
||||||
|
chunk_gen = chunk
|
||||||
|
elif isinstance(chunk, AIMessage):
|
||||||
|
chunk_gen = AIMessageChunk(
|
||||||
|
content=chunk.content,
|
||||||
|
id=chunk.id,
|
||||||
|
name=chunk.name,
|
||||||
|
lc_version=chunk.lc_version,
|
||||||
|
response_metadata=chunk.response_metadata,
|
||||||
|
usage_metadata=chunk.usage_metadata,
|
||||||
|
parsed=chunk.parsed,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
chunk_gen = GenerationChunk(text=chunk)
|
chunk_gen = GenerationChunk(text=chunk)
|
||||||
|
|
||||||
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||||
|
|
||||||
|
if isinstance(acc_gen, AIMessageChunk):
|
||||||
|
parsed = await self.aparse_result(acc_gen, partial=True)
|
||||||
|
else:
|
||||||
parsed = await self.aparse_result([acc_gen], partial=True)
|
parsed = await self.aparse_result([acc_gen], partial=True)
|
||||||
if parsed is not None and parsed != prev_parsed:
|
if parsed is not None and parsed != prev_parsed:
|
||||||
if self.diff:
|
if self.diff:
|
||||||
|
@ -12,6 +12,8 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.messages.utils import convert_from_v1_message
|
||||||
|
from langchain_core.messages.v1 import AIMessage
|
||||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||||
from langchain_core.runnables.utils import AddableDict
|
from langchain_core.runnables.utils import AddableDict
|
||||||
|
|
||||||
@ -240,19 +242,26 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def _transform(
|
def _transform(
|
||||||
self, input: Iterator[Union[str, BaseMessage]]
|
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
||||||
) -> Iterator[AddableDict]:
|
) -> Iterator[AddableDict]:
|
||||||
streaming_parser = _StreamingParser(self.parser)
|
streaming_parser = _StreamingParser(self.parser)
|
||||||
for chunk in input:
|
for chunk in input:
|
||||||
|
if isinstance(chunk, AIMessage):
|
||||||
|
yield from streaming_parser.parse(convert_from_v1_message(chunk))
|
||||||
|
else:
|
||||||
yield from streaming_parser.parse(chunk)
|
yield from streaming_parser.parse(chunk)
|
||||||
streaming_parser.close()
|
streaming_parser.close()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
||||||
) -> AsyncIterator[AddableDict]:
|
) -> AsyncIterator[AddableDict]:
|
||||||
streaming_parser = _StreamingParser(self.parser)
|
streaming_parser = _StreamingParser(self.parser)
|
||||||
async for chunk in input:
|
async for chunk in input:
|
||||||
|
if isinstance(chunk, AIMessage):
|
||||||
|
for output in streaming_parser.parse(convert_from_v1_message(chunk)):
|
||||||
|
yield output
|
||||||
|
else:
|
||||||
for output in streaming_parser.parse(chunk):
|
for output in streaming_parser.parse(chunk):
|
||||||
yield output
|
yield output
|
||||||
streaming_parser.close()
|
streaming_parser.close()
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
"""Module to test base parser implementations."""
|
"""Module to test base parser implementations."""
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.language_models import GenericFakeChatModel
|
from langchain_core.language_models import GenericFakeChatModel
|
||||||
|
from langchain_core.language_models.fake_chat_models import GenericFakeChatModelV1
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
from langchain_core.output_parsers import (
|
from langchain_core.output_parsers import (
|
||||||
BaseGenerationOutputParser,
|
BaseGenerationOutputParser,
|
||||||
BaseTransformOutputParser,
|
BaseTransformOutputParser,
|
||||||
@ -20,7 +24,7 @@ def test_base_generation_parser() -> None:
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self, result: list[Generation], *, partial: bool = False
|
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Parse a list of model Generations into a specific format.
|
"""Parse a list of model Generations into a specific format.
|
||||||
|
|
||||||
@ -32,16 +36,22 @@ def test_base_generation_parser() -> None:
|
|||||||
partial: Whether to allow partial results. This is used for parsers
|
partial: Whether to allow partial results. This is used for parsers
|
||||||
that support streaming
|
that support streaming
|
||||||
"""
|
"""
|
||||||
|
if isinstance(result, AIMessageV1):
|
||||||
|
content = result.text or ""
|
||||||
|
else:
|
||||||
if len(result) != 1:
|
if len(result) != 1:
|
||||||
msg = "This output parser can only be used with a single generation."
|
msg = (
|
||||||
|
"This output parser can only be used with a single generation."
|
||||||
|
)
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
generation = result[0]
|
generation = result[0]
|
||||||
if not isinstance(generation, ChatGeneration):
|
if not isinstance(generation, ChatGeneration):
|
||||||
# Say that this one only works with chat generations
|
# Say that this one only works with chat generations
|
||||||
msg = "This output parser can only be used with a chat generation."
|
msg = "This output parser can only be used with a chat generation."
|
||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
|
assert isinstance(generation.message.content, str)
|
||||||
content = generation.message.content
|
content = generation.message.content
|
||||||
|
|
||||||
assert isinstance(content, str)
|
assert isinstance(content, str)
|
||||||
return content.swapcase()
|
return content.swapcase()
|
||||||
|
|
||||||
@ -49,6 +59,10 @@ def test_base_generation_parser() -> None:
|
|||||||
chain = model | StrInvertCase()
|
chain = model | StrInvertCase()
|
||||||
assert chain.invoke("") == "HeLLO"
|
assert chain.invoke("") == "HeLLO"
|
||||||
|
|
||||||
|
model_v1 = GenericFakeChatModelV1(messages=iter([AIMessageV1("hEllo")]))
|
||||||
|
chain_v1 = model_v1 | StrInvertCase()
|
||||||
|
assert chain_v1.invoke("") == "HeLLO"
|
||||||
|
|
||||||
|
|
||||||
def test_base_transform_output_parser() -> None:
|
def test_base_transform_output_parser() -> None:
|
||||||
"""Test base transform output parser."""
|
"""Test base transform output parser."""
|
||||||
@ -62,7 +76,7 @@ def test_base_transform_output_parser() -> None:
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self, result: list[Generation], *, partial: bool = False
|
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Parse a list of model Generations into a specific format.
|
"""Parse a list of model Generations into a specific format.
|
||||||
|
|
||||||
@ -74,15 +88,22 @@ def test_base_transform_output_parser() -> None:
|
|||||||
partial: Whether to allow partial results. This is used for parsers
|
partial: Whether to allow partial results. This is used for parsers
|
||||||
that support streaming
|
that support streaming
|
||||||
"""
|
"""
|
||||||
|
if isinstance(result, AIMessageV1):
|
||||||
|
content = result.text or ""
|
||||||
|
else:
|
||||||
if len(result) != 1:
|
if len(result) != 1:
|
||||||
msg = "This output parser can only be used with a single generation."
|
msg = (
|
||||||
|
"This output parser can only be used with a single generation."
|
||||||
|
)
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
generation = result[0]
|
generation = result[0]
|
||||||
if not isinstance(generation, ChatGeneration):
|
if not isinstance(generation, ChatGeneration):
|
||||||
# Say that this one only works with chat generations
|
# Say that this one only works with chat generations
|
||||||
msg = "This output parser can only be used with a chat generation."
|
msg = "This output parser can only be used with a chat generation."
|
||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
|
assert isinstance(generation.message.content, str)
|
||||||
content = generation.message.content
|
content = generation.message.content
|
||||||
|
|
||||||
assert isinstance(content, str)
|
assert isinstance(content, str)
|
||||||
return content.swapcase()
|
return content.swapcase()
|
||||||
|
|
||||||
@ -91,3 +112,8 @@ def test_base_transform_output_parser() -> None:
|
|||||||
# inputs to models are ignored, response is hard-coded in model definition
|
# inputs to models are ignored, response is hard-coded in model definition
|
||||||
chunks = list(chain.stream(""))
|
chunks = list(chain.stream(""))
|
||||||
assert chunks == ["HELLO", " ", "WORLD"]
|
assert chunks == ["HELLO", " ", "WORLD"]
|
||||||
|
|
||||||
|
model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
|
||||||
|
chain_v1 = model_v1 | StrInvertCase()
|
||||||
|
chunks = list(chain_v1.stream(""))
|
||||||
|
assert chunks == ["HELLO", " ", "WORLD"]
|
||||||
|
@ -10,6 +10,8 @@ from langchain_core.messages import (
|
|||||||
BaseMessage,
|
BaseMessage,
|
||||||
ToolCallChunk,
|
ToolCallChunk,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
|
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||||
from langchain_core.output_parsers.openai_tools import (
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
JsonOutputKeyToolsParser,
|
JsonOutputKeyToolsParser,
|
||||||
JsonOutputToolsParser,
|
JsonOutputToolsParser,
|
||||||
@ -331,6 +333,14 @@ for message in STREAMED_MESSAGES:
|
|||||||
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)
|
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)
|
||||||
|
|
||||||
|
|
||||||
|
STREAMED_MESSAGES_V1 = [
|
||||||
|
AIMessageChunkV1(
|
||||||
|
content=[],
|
||||||
|
tool_call_chunks=chunk.tool_call_chunks,
|
||||||
|
)
|
||||||
|
for chunk in STREAMED_MESSAGES_WITH_TOOL_CALLS
|
||||||
|
]
|
||||||
|
|
||||||
EXPECTED_STREAMED_JSON = [
|
EXPECTED_STREAMED_JSON = [
|
||||||
{},
|
{},
|
||||||
{"names": ["suz"]},
|
{"names": ["suz"]},
|
||||||
@ -398,6 +408,19 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_json_output_parser_v1() -> None:
|
||||||
|
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||||
|
yield from STREAMED_MESSAGES_V1
|
||||||
|
|
||||||
|
chain = input_iter | JsonOutputToolsParser()
|
||||||
|
|
||||||
|
actual = list(chain.stream(None))
|
||||||
|
expected: list = [[]] + [
|
||||||
|
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
|
||||||
|
]
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None:
|
async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None:
|
||||||
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
|
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
|
||||||
@ -410,6 +433,20 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
async def test_partial_json_output_parser_async_v1() -> None:
|
||||||
|
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
|
||||||
|
for msg in STREAMED_MESSAGES_V1:
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
chain = input_iter | JsonOutputToolsParser()
|
||||||
|
|
||||||
|
actual = [p async for p in chain.astream(None)]
|
||||||
|
expected: list = [[]] + [
|
||||||
|
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
|
||||||
|
]
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
|
def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
|
||||||
input_iter = _get_iter(use_tool_calls=use_tool_calls)
|
input_iter = _get_iter(use_tool_calls=use_tool_calls)
|
||||||
@ -429,6 +466,26 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_json_output_parser_return_id_v1() -> None:
|
||||||
|
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||||
|
yield from STREAMED_MESSAGES_V1
|
||||||
|
|
||||||
|
chain = input_iter | JsonOutputToolsParser(return_id=True)
|
||||||
|
|
||||||
|
actual = list(chain.stream(None))
|
||||||
|
expected: list = [[]] + [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "NameCollector",
|
||||||
|
"args": chunk,
|
||||||
|
"id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
for chunk in EXPECTED_STREAMED_JSON
|
||||||
|
]
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
|
def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
|
||||||
input_iter = _get_iter(use_tool_calls=use_tool_calls)
|
input_iter = _get_iter(use_tool_calls=use_tool_calls)
|
||||||
@ -439,6 +496,17 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_json_output_key_parser_v1() -> None:
|
||||||
|
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||||
|
yield from STREAMED_MESSAGES_V1
|
||||||
|
|
||||||
|
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
|
||||||
|
|
||||||
|
actual = list(chain.stream(None))
|
||||||
|
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> None:
|
async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> None:
|
||||||
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
|
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
|
||||||
@ -450,6 +518,18 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) ->
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
async def test_partial_json_output_parser_key_async_v1() -> None:
|
||||||
|
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
|
||||||
|
for msg in STREAMED_MESSAGES_V1:
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
|
||||||
|
|
||||||
|
actual = [p async for p in chain.astream(None)]
|
||||||
|
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> None:
|
def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> None:
|
||||||
input_iter = _get_iter(use_tool_calls=use_tool_calls)
|
input_iter = _get_iter(use_tool_calls=use_tool_calls)
|
||||||
@ -461,6 +541,17 @@ def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> N
|
|||||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_json_output_key_parser_first_only_v1() -> None:
|
||||||
|
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||||
|
yield from STREAMED_MESSAGES_V1
|
||||||
|
|
||||||
|
chain = input_iter | JsonOutputKeyToolsParser(
|
||||||
|
key_name="NameCollector", first_tool_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
async def test_partial_json_output_parser_key_async_first_only(
|
async def test_partial_json_output_parser_key_async_first_only(
|
||||||
*,
|
*,
|
||||||
@ -475,6 +566,18 @@ async def test_partial_json_output_parser_key_async_first_only(
|
|||||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||||
|
|
||||||
|
|
||||||
|
async def test_partial_json_output_parser_key_async_first_only_v1() -> None:
|
||||||
|
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
|
||||||
|
for msg in STREAMED_MESSAGES_V1:
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
chain = input_iter | JsonOutputKeyToolsParser(
|
||||||
|
key_name="NameCollector", first_tool_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
def test_json_output_key_tools_parser_multiple_tools_first_only(
|
def test_json_output_key_tools_parser_multiple_tools_first_only(
|
||||||
*, use_tool_calls: bool
|
*, use_tool_calls: bool
|
||||||
@ -531,6 +634,42 @@ def test_json_output_key_tools_parser_multiple_tools_first_only(
|
|||||||
assert output_no_id == {"a": 1}
|
assert output_no_id == {"a": 1}
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_output_key_tools_parser_multiple_tools_first_only_v1() -> None:
|
||||||
|
message = AIMessageV1(
|
||||||
|
content=[],
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": "call_other",
|
||||||
|
"name": "other",
|
||||||
|
"args": {"b": 2},
|
||||||
|
},
|
||||||
|
{"type": "tool_call", "id": "call_func", "name": "func", "args": {"a": 1}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with return_id=True
|
||||||
|
parser = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output = parser.parse_result(message)
|
||||||
|
|
||||||
|
# Should return the func tool call, not None
|
||||||
|
assert output is not None
|
||||||
|
assert output["type"] == "func"
|
||||||
|
assert output["args"] == {"a": 1}
|
||||||
|
assert "id" in output
|
||||||
|
|
||||||
|
# Test with return_id=False
|
||||||
|
parser_no_id = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=False
|
||||||
|
)
|
||||||
|
output_no_id = parser_no_id.parse_result(message)
|
||||||
|
|
||||||
|
# Should return just the args
|
||||||
|
assert output_no_id == {"a": 1}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
def test_json_output_key_tools_parser_multiple_tools_no_match(
|
def test_json_output_key_tools_parser_multiple_tools_no_match(
|
||||||
*, use_tool_calls: bool
|
*, use_tool_calls: bool
|
||||||
@ -583,6 +722,44 @@ def test_json_output_key_tools_parser_multiple_tools_no_match(
|
|||||||
assert output_no_id is None
|
assert output_no_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_output_key_tools_parser_multiple_tools_no_match_v1() -> None:
|
||||||
|
message = AIMessageV1(
|
||||||
|
content=[],
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": "call_other",
|
||||||
|
"name": "other",
|
||||||
|
"args": {"b": 2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": "call_another",
|
||||||
|
"name": "another",
|
||||||
|
"args": {"c": 3},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with return_id=True, first_tool_only=True
|
||||||
|
parser = JsonOutputKeyToolsParser(
|
||||||
|
key_name="nonexistent", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output = parser.parse_result(message)
|
||||||
|
|
||||||
|
# Should return None when no matches
|
||||||
|
assert output is None
|
||||||
|
|
||||||
|
# Test with return_id=False, first_tool_only=True
|
||||||
|
parser_no_id = JsonOutputKeyToolsParser(
|
||||||
|
key_name="nonexistent", first_tool_only=True, return_id=False
|
||||||
|
)
|
||||||
|
output_no_id = parser_no_id.parse_result(message)
|
||||||
|
|
||||||
|
# Should return None when no matches
|
||||||
|
assert output_no_id is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
def test_json_output_key_tools_parser_multiple_matching_tools(
|
def test_json_output_key_tools_parser_multiple_matching_tools(
|
||||||
*, use_tool_calls: bool
|
*, use_tool_calls: bool
|
||||||
@ -643,6 +820,42 @@ def test_json_output_key_tools_parser_multiple_matching_tools(
|
|||||||
assert output_all[1]["args"] == {"a": 3}
|
assert output_all[1]["args"] == {"a": 3}
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_output_key_tools_parser_multiple_matching_tools_v1() -> None:
|
||||||
|
message = AIMessageV1(
|
||||||
|
content=[],
|
||||||
|
tool_calls=[
|
||||||
|
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": "call_other",
|
||||||
|
"name": "other",
|
||||||
|
"args": {"b": 2},
|
||||||
|
},
|
||||||
|
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with first_tool_only=True - should return first matching
|
||||||
|
parser = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output = parser.parse_result(message)
|
||||||
|
|
||||||
|
assert output is not None
|
||||||
|
assert output["type"] == "func"
|
||||||
|
assert output["args"] == {"a": 1} # First matching tool call
|
||||||
|
|
||||||
|
# Test with first_tool_only=False - should return all matching
|
||||||
|
parser_all = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=False, return_id=True
|
||||||
|
)
|
||||||
|
output_all = parser_all.parse_result(message)
|
||||||
|
|
||||||
|
assert len(output_all) == 2
|
||||||
|
assert output_all[0]["args"] == {"a": 1}
|
||||||
|
assert output_all[1]["args"] == {"a": 3}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
|
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
|
||||||
def create_message() -> AIMessage:
|
def create_message() -> AIMessage:
|
||||||
@ -671,6 +884,35 @@ def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) ->
|
|||||||
assert output_all == []
|
assert output_all == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"empty_message",
|
||||||
|
[
|
||||||
|
AIMessageV1(content=[], tool_calls=[]),
|
||||||
|
AIMessageV1(content="", tool_calls=[]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_json_output_key_tools_parser_empty_results_v1(
|
||||||
|
empty_message: AIMessageV1,
|
||||||
|
) -> None:
|
||||||
|
# Test with first_tool_only=True
|
||||||
|
parser = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output = parser.parse_result(empty_message)
|
||||||
|
|
||||||
|
# Should return None for empty results
|
||||||
|
assert output is None
|
||||||
|
|
||||||
|
# Test with first_tool_only=False
|
||||||
|
parser_all = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=False, return_id=True
|
||||||
|
)
|
||||||
|
output_all = parser_all.parse_result(empty_message)
|
||||||
|
|
||||||
|
# Should return empty list for empty results
|
||||||
|
assert output_all == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
def test_json_output_key_tools_parser_parameter_combinations(
|
def test_json_output_key_tools_parser_parameter_combinations(
|
||||||
*, use_tool_calls: bool
|
*, use_tool_calls: bool
|
||||||
@ -746,6 +988,56 @@ def test_json_output_key_tools_parser_parameter_combinations(
|
|||||||
assert output4 == [{"a": 1}, {"a": 3}]
|
assert output4 == [{"a": 1}, {"a": 3}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_output_key_tools_parser_parameter_combinations_v1() -> None:
|
||||||
|
"""Test all parameter combinations of JsonOutputKeyToolsParser."""
|
||||||
|
result = AIMessageV1(
|
||||||
|
content=[],
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": "call_other",
|
||||||
|
"name": "other",
|
||||||
|
"args": {"b": 2},
|
||||||
|
},
|
||||||
|
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
|
||||||
|
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test: first_tool_only=True, return_id=True
|
||||||
|
parser1 = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output1 = parser1.parse_result(result)
|
||||||
|
assert output1["type"] == "func"
|
||||||
|
assert output1["args"] == {"a": 1}
|
||||||
|
assert "id" in output1
|
||||||
|
|
||||||
|
# Test: first_tool_only=True, return_id=False
|
||||||
|
parser2 = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=False
|
||||||
|
)
|
||||||
|
output2 = parser2.parse_result(result)
|
||||||
|
assert output2 == {"a": 1}
|
||||||
|
|
||||||
|
# Test: first_tool_only=False, return_id=True
|
||||||
|
parser3 = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=False, return_id=True
|
||||||
|
)
|
||||||
|
output3 = parser3.parse_result(result)
|
||||||
|
assert len(output3) == 2
|
||||||
|
assert all("id" in item for item in output3)
|
||||||
|
assert output3[0]["args"] == {"a": 1}
|
||||||
|
assert output3[1]["args"] == {"a": 3}
|
||||||
|
|
||||||
|
# Test: first_tool_only=False, return_id=False
|
||||||
|
parser4 = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=False, return_id=False
|
||||||
|
)
|
||||||
|
output4 = parser4.parse_result(result)
|
||||||
|
assert output4 == [{"a": 1}, {"a": 3}]
|
||||||
|
|
||||||
|
|
||||||
class Person(BaseModel):
|
class Person(BaseModel):
|
||||||
age: int
|
age: int
|
||||||
hair_color: str
|
hair_color: str
|
||||||
@ -788,6 +1080,18 @@ def test_partial_pydantic_output_parser() -> None:
|
|||||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_pydantic_output_parser_v1() -> None:
|
||||||
|
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
|
||||||
|
yield from STREAMED_MESSAGES_V1
|
||||||
|
|
||||||
|
chain = input_iter | PydanticToolsParser(
|
||||||
|
tools=[NameCollector], first_tool_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
actual = list(chain.stream(None))
|
||||||
|
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||||
|
|
||||||
|
|
||||||
async def test_partial_pydantic_output_parser_async() -> None:
|
async def test_partial_pydantic_output_parser_async() -> None:
|
||||||
for use_tool_calls in [False, True]:
|
for use_tool_calls in [False, True]:
|
||||||
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
|
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
|
||||||
@ -800,6 +1104,19 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
|||||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||||
|
|
||||||
|
|
||||||
|
async def test_partial_pydantic_output_parser_async_v1() -> None:
|
||||||
|
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
|
||||||
|
for msg in STREAMED_MESSAGES_V1:
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
chain = input_iter | PydanticToolsParser(
|
||||||
|
tools=[NameCollector], first_tool_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
actual = [p async for p in chain.astream(None)]
|
||||||
|
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||||
|
|
||||||
|
|
||||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||||
import pydantic
|
import pydantic
|
||||||
@ -870,10 +1187,12 @@ def test_parse_with_different_pydantic_2_proper() -> None:
|
|||||||
|
|
||||||
def test_max_tokens_error(caplog: Any) -> None:
|
def test_max_tokens_error(caplog: Any) -> None:
|
||||||
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
||||||
message = AIMessage(
|
for msg_class in [AIMessage, AIMessageV1]:
|
||||||
|
message = msg_class(
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
{
|
{
|
||||||
|
"type": "tool_call",
|
||||||
"id": "call_OwL7f5PE",
|
"id": "call_OwL7f5PE",
|
||||||
"name": "NameCollector",
|
"name": "NameCollector",
|
||||||
"args": {"names": ["suz", "jerm"]},
|
"args": {"names": ["suz", "jerm"]},
|
||||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -8,6 +8,8 @@ from langchain_core.messages import (
|
|||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages.utils import convert_from_v1_message
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@ -83,10 +85,12 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
|
|||||||
@override
|
@override
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self,
|
self,
|
||||||
result: list[Generation],
|
result: Union[list[Generation], AIMessageV1],
|
||||||
*,
|
*,
|
||||||
partial: bool = False,
|
partial: bool = False,
|
||||||
) -> Union[AgentAction, AgentFinish]:
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
|
if isinstance(result, AIMessageV1):
|
||||||
|
result = [ChatGeneration(message=convert_from_v1_message(result))]
|
||||||
if not isinstance(result[0], ChatGeneration):
|
if not isinstance(result[0], ChatGeneration):
|
||||||
msg = "This output parser only works on ChatGeneration output"
|
msg = "This output parser only works on ChatGeneration output"
|
||||||
raise ValueError(msg) # noqa: TRY004
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
@ -2,6 +2,8 @@ from typing import Union
|
|||||||
|
|
||||||
from langchain_core.agents import AgentAction, AgentFinish
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.messages.utils import convert_from_v1_message
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@ -57,10 +59,12 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
|
|||||||
@override
|
@override
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self,
|
self,
|
||||||
result: list[Generation],
|
result: Union[list[Generation], AIMessageV1],
|
||||||
*,
|
*,
|
||||||
partial: bool = False,
|
partial: bool = False,
|
||||||
) -> Union[list[AgentAction], AgentFinish]:
|
) -> Union[list[AgentAction], AgentFinish]:
|
||||||
|
if isinstance(result, AIMessageV1):
|
||||||
|
result = [ChatGeneration(message=convert_from_v1_message(result))]
|
||||||
if not isinstance(result[0], ChatGeneration):
|
if not isinstance(result[0], ChatGeneration):
|
||||||
msg = "This output parser only works on ChatGeneration output"
|
msg = "This output parser only works on ChatGeneration output"
|
||||||
raise ValueError(msg) # noqa: TRY004
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
@ -9,6 +9,8 @@ from langchain_core.messages import (
|
|||||||
BaseMessage,
|
BaseMessage,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages.utils import convert_from_v1_message
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@ -101,10 +103,12 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
|
|||||||
@override
|
@override
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self,
|
self,
|
||||||
result: list[Generation],
|
result: Union[list[Generation], AIMessageV1],
|
||||||
*,
|
*,
|
||||||
partial: bool = False,
|
partial: bool = False,
|
||||||
) -> Union[list[AgentAction], AgentFinish]:
|
) -> Union[list[AgentAction], AgentFinish]:
|
||||||
|
if isinstance(result, AIMessageV1):
|
||||||
|
result = [ChatGeneration(message=convert_from_v1_message(result))]
|
||||||
if not isinstance(result[0], ChatGeneration):
|
if not isinstance(result[0], ChatGeneration):
|
||||||
msg = "This output parser only works on ChatGeneration output"
|
msg = "This output parser only works on ChatGeneration output"
|
||||||
raise ValueError(msg) # noqa: TRY004
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
Loading…
Reference in New Issue
Block a user