This commit is contained in:
Bagatur 2025-01-06 18:07:09 -05:00
parent 47b386d28f
commit 87d8012ef6
2 changed files with 48 additions and 25 deletions

View File

@ -9,6 +9,7 @@ from typing import (
Optional, Optional,
TypeVar, TypeVar,
Union, Union,
cast,
) )
from typing_extensions import override from typing_extensions import override
@ -65,6 +66,8 @@ class BaseGenerationOutputParser(
): ):
"""Base class to parse the output of an LLM call.""" """Base class to parse the output of an LLM call."""
return_message: bool = False
@property @property
@override @override
def InputType(self) -> Any: def InputType(self) -> Any:
@ -75,9 +78,12 @@ class BaseGenerationOutputParser(
@override @override
def OutputType(self) -> type[T]: def OutputType(self) -> type[T]:
"""Return the output type for the parser.""" """Return the output type for the parser."""
# even though mypy complains this isn't valid, if self.return_message:
# it is good enough for pydantic to build the schema from return cast(type[T], AnyMessage)
return T # type: ignore[misc] else:
# even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from
return T # type: ignore[misc]
def invoke( def invoke(
self, self,
@ -86,7 +92,7 @@ class BaseGenerationOutputParser(
**kwargs: Any, **kwargs: Any,
) -> T: ) -> T:
if isinstance(input, BaseMessage): if isinstance(input, BaseMessage):
return self._call_with_config( parsed = self._call_with_config(
lambda inner_input: self.parse_result( lambda inner_input: self.parse_result(
[ChatGeneration(message=inner_input)] [ChatGeneration(message=inner_input)]
), ),
@ -94,6 +100,10 @@ class BaseGenerationOutputParser(
config, config,
run_type="parser", run_type="parser",
) )
if self.return_message:
return cast(T, input.model_copy(update={"parsed": parsed}))
else:
return parsed
else: else:
return self._call_with_config( return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]), lambda inner_input: self.parse_result([Generation(text=inner_input)]),
@ -109,7 +119,7 @@ class BaseGenerationOutputParser(
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> T: ) -> T:
if isinstance(input, BaseMessage): if isinstance(input, BaseMessage):
return await self._acall_with_config( parsed = await self._acall_with_config(
lambda inner_input: self.aparse_result( lambda inner_input: self.aparse_result(
[ChatGeneration(message=inner_input)] [ChatGeneration(message=inner_input)]
), ),
@ -117,6 +127,10 @@ class BaseGenerationOutputParser(
config, config,
run_type="parser", run_type="parser",
) )
if self.return_message:
return cast(T, input.model_copy(update={"parsed": parsed}))
else:
return parsed
else: else:
return await self._acall_with_config( return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]), lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
@ -155,6 +169,8 @@ class BaseOutputParser(
return "boolean_output_parser" return "boolean_output_parser"
""" # noqa: E501 """ # noqa: E501
return_message: bool = False
@property @property
@override @override
def InputType(self) -> Any: def InputType(self) -> Any:
@ -171,6 +187,9 @@ class BaseOutputParser(
Raises: Raises:
TypeError: If the class doesn't have an inferable OutputType. TypeError: If the class doesn't have an inferable OutputType.
""" """
if self.return_message:
return cast(type[T], AnyMessage)
for base in self.__class__.mro(): for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"): if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__ metadata = base.__pydantic_generic_metadata__
@ -190,7 +209,7 @@ class BaseOutputParser(
**kwargs: Any, **kwargs: Any,
) -> T: ) -> T:
if isinstance(input, BaseMessage): if isinstance(input, BaseMessage):
return self._call_with_config( parsed = self._call_with_config(
lambda inner_input: self.parse_result( lambda inner_input: self.parse_result(
[ChatGeneration(message=inner_input)] [ChatGeneration(message=inner_input)]
), ),
@ -198,6 +217,10 @@ class BaseOutputParser(
config, config,
run_type="parser", run_type="parser",
) )
if self.return_message:
return cast(T, input.model_copy(update={"parsed": parsed}))
else:
return parsed
else: else:
return self._call_with_config( return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]), lambda inner_input: self.parse_result([Generation(text=inner_input)]),
@ -213,7 +236,7 @@ class BaseOutputParser(
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> T: ) -> T:
if isinstance(input, BaseMessage): if isinstance(input, BaseMessage):
return await self._acall_with_config( parsed = await self._acall_with_config(
lambda inner_input: self.aparse_result( lambda inner_input: self.aparse_result(
[ChatGeneration(message=inner_input)] [ChatGeneration(message=inner_input)]
), ),
@ -221,6 +244,10 @@ class BaseOutputParser(
config, config,
run_type="parser", run_type="parser",
) )
if self.return_message:
return cast(T, input.model_copy(update={"parsed": parsed}))
else:
return parsed
else: else:
return await self._acall_with_config( return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]), lambda inner_input: self.aparse_result([Generation(text=inner_input)]),

View File

@ -10,7 +10,6 @@ import sys
import warnings import warnings
from io import BytesIO from io import BytesIO
from math import ceil from math import ceil
from operator import itemgetter
from typing import ( from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
@ -85,11 +84,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function, convert_to_openai_function,
convert_to_openai_tool, convert_to_openai_tool,
) )
from langchain_core.utils.pydantic import ( from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
PydanticBaseModel,
TypeBaseModel,
is_basemodel_subclass,
)
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self from typing_extensions import Self
@ -777,7 +772,7 @@ class BaseChatOpenAI(BaseChatModel):
): ):
message = response.choices[0].message # type: ignore[attr-defined] message = response.choices[0].message # type: ignore[attr-defined]
if hasattr(message, "parsed"): if hasattr(message, "parsed"):
generations[0].message.parsed = message.parsed cast(AIMessage, generations[0].message).parsed = message.parsed
# For backwards compatibility. # For backwards compatibility.
generations[0].message.additional_kwargs["parsed"] = message.parsed generations[0].message.additional_kwargs["parsed"] = message.parsed
if hasattr(message, "refusal"): if hasattr(message, "refusal"):
@ -1474,17 +1469,18 @@ class BaseChatOpenAI(BaseChatModel):
output_parser: Runnable = PydanticToolsParser( output_parser: Runnable = PydanticToolsParser(
tools=[schema], # type: ignore[list-item] tools=[schema], # type: ignore[list-item]
first_tool_only=True, # type: ignore[list-item] first_tool_only=True, # type: ignore[list-item]
return_message=True,
) )
else: else:
output_parser = JsonOutputKeyToolsParser( output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True key_name=tool_name, first_tool_only=True, return_message=True
) )
elif method == "json_mode": elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"}) llm = self.bind(response_format={"type": "json_object"})
output_parser = ( output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] PydanticOutputParser(pydantic_object=schema, return_message=True) # type: ignore[arg-type]
if is_pydantic_schema if is_pydantic_schema
else JsonOutputParser() else JsonOutputParser(return_message=True)
) )
elif method == "json_schema": elif method == "json_schema":
if schema is None: if schema is None:
@ -1496,10 +1492,10 @@ class BaseChatOpenAI(BaseChatModel):
llm = self.bind(response_format=response_format) llm = self.bind(response_format=response_format)
if is_pydantic_schema: if is_pydantic_schema:
output_parser = _oai_structured_outputs_parser.with_types( output_parser = _oai_structured_outputs_parser.with_types(
output_type=cast(type, schema) output_type=AIMessage
) )
else: else:
output_parser = JsonOutputParser() output_parser = JsonOutputParser(return_message=True)
else: else:
raise ValueError( raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or " f"Unrecognized method argument. Expected one of 'function_calling' or "
@ -1507,8 +1503,8 @@ class BaseChatOpenAI(BaseChatModel):
) )
if include_raw: if include_raw:
parser_assign = RunnablePassthrough.assign( parser_assign = RunnablePassthrough.assign(raw=output_parser).assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None parsed=lambda x: x["raw"].parsed, parsing_error=lambda _: None
) )
parser_none = RunnablePassthrough.assign(parsed=lambda _: None) parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks( parser_with_fallback = parser_assign.with_fallbacks(
@ -2231,15 +2227,15 @@ def _convert_to_openai_response_format(
@chain @chain
def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel: def _oai_structured_outputs_parser(ai_msg: AIMessage) -> AIMessage:
if ai_msg.additional_kwargs.get("parsed"): if ai_msg.parsed is not None:
return ai_msg.additional_kwargs["parsed"] return ai_msg
elif ai_msg.additional_kwargs.get("refusal"): elif ai_msg.additional_kwargs.get("refusal"):
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"]) raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
else: else:
raise ValueError( raise ValueError(
"Structured Output response does not have a 'parsed' field nor a 'refusal' " "Structured Output response does not have a 'parsed' field nor a 'refusal' "
"field. Received message:\n\n{ai_msg}" f"field. Received message:\n\n{ai_msg}"
) )