diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 9d080cef300..0760c1cf7da 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -9,6 +9,7 @@ from typing import ( Optional, TypeVar, Union, + cast, ) from typing_extensions import override @@ -65,6 +66,8 @@ class BaseGenerationOutputParser( ): """Base class to parse the output of an LLM call.""" + return_message: bool = False + @property @override def InputType(self) -> Any: @@ -75,9 +78,12 @@ class BaseGenerationOutputParser( @override def OutputType(self) -> type[T]: """Return the output type for the parser.""" - # even though mypy complains this isn't valid, - # it is good enough for pydantic to build the schema from - return T # type: ignore[misc] + if self.return_message: + return cast(type[T], AnyMessage) + else: + # even though mypy complains this isn't valid, + # it is good enough for pydantic to build the schema from + return T # type: ignore[misc] def invoke( self, @@ -86,7 +92,7 @@ class BaseGenerationOutputParser( **kwargs: Any, ) -> T: if isinstance(input, BaseMessage): - return self._call_with_config( + parsed = self._call_with_config( lambda inner_input: self.parse_result( [ChatGeneration(message=inner_input)] ), @@ -94,6 +100,10 @@ class BaseGenerationOutputParser( config, run_type="parser", ) + if self.return_message: + return cast(T, input.model_copy(update={"parsed": parsed})) + else: + return parsed else: return self._call_with_config( lambda inner_input: self.parse_result([Generation(text=inner_input)]), @@ -109,7 +119,7 @@ class BaseGenerationOutputParser( **kwargs: Optional[Any], ) -> T: if isinstance(input, BaseMessage): - return await self._acall_with_config( + parsed = await self._acall_with_config( lambda inner_input: self.aparse_result( [ChatGeneration(message=inner_input)] ), @@ -117,6 +127,10 @@ class BaseGenerationOutputParser( config, run_type="parser", ) + if self.return_message: + return cast(T, input.model_copy(update={"parsed": parsed})) + else: + return parsed else: return await self._acall_with_config( lambda inner_input: self.aparse_result([Generation(text=inner_input)]), @@ -155,6 +169,8 @@ class BaseOutputParser( return "boolean_output_parser" """ # noqa: E501 + return_message: bool = False + @property @override def InputType(self) -> Any: @@ -171,6 +187,9 @@ class BaseOutputParser( Raises: TypeError: If the class doesn't have an inferable OutputType. """ + if self.return_message: + return cast(type[T], AnyMessage) + for base in self.__class__.mro(): if hasattr(base, "__pydantic_generic_metadata__"): metadata = base.__pydantic_generic_metadata__ @@ -190,7 +209,7 @@ class BaseOutputParser( **kwargs: Any, ) -> T: if isinstance(input, BaseMessage): - return self._call_with_config( + parsed = self._call_with_config( lambda inner_input: self.parse_result( [ChatGeneration(message=inner_input)] ), @@ -198,6 +217,10 @@ class BaseOutputParser( config, run_type="parser", ) + if self.return_message: + return cast(T, input.model_copy(update={"parsed": parsed})) + else: + return parsed else: return self._call_with_config( lambda inner_input: self.parse_result([Generation(text=inner_input)]), @@ -213,7 +236,7 @@ class BaseOutputParser( **kwargs: Optional[Any], ) -> T: if isinstance(input, BaseMessage): - return await self._acall_with_config( + parsed = await self._acall_with_config( lambda inner_input: self.aparse_result( [ChatGeneration(message=inner_input)] ), @@ -221,6 +244,10 @@ class BaseOutputParser( config, run_type="parser", ) + if self.return_message: + return cast(T, input.model_copy(update={"parsed": parsed})) + else: + return parsed else: return await self._acall_with_config( lambda inner_input: self.aparse_result([Generation(text=inner_input)]), diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index c600b5c3fd8..119ec24e32e 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -10,7 +10,6 @@ import sys import warnings from io import BytesIO from math import ceil -from operator import itemgetter from typing import ( Any, AsyncIterator, @@ -85,11 +84,7 @@ from langchain_core.utils.function_calling import ( convert_to_openai_function, convert_to_openai_tool, ) -from langchain_core.utils.pydantic import ( - PydanticBaseModel, - TypeBaseModel, - is_basemodel_subclass, -) +from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self @@ -777,7 +772,7 @@ class BaseChatOpenAI(BaseChatModel): ): message = response.choices[0].message # type: ignore[attr-defined] if hasattr(message, "parsed"): - generations[0].message.parsed = message.parsed + cast(AIMessage, generations[0].message).parsed = message.parsed # For backwards compatibility. generations[0].message.additional_kwargs["parsed"] = message.parsed if hasattr(message, "refusal"): @@ -1474,17 +1469,18 @@ class BaseChatOpenAI(BaseChatModel): output_parser: Runnable = PydanticToolsParser( tools=[schema], # type: ignore[list-item] first_tool_only=True, # type: ignore[list-item] + return_message=True, ) else: output_parser = JsonOutputKeyToolsParser( - key_name=tool_name, first_tool_only=True + key_name=tool_name, first_tool_only=True, return_message=True ) elif method == "json_mode": llm = self.bind(response_format={"type": "json_object"}) output_parser = ( - PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] + PydanticOutputParser(pydantic_object=schema, return_message=True) # type: ignore[arg-type] if is_pydantic_schema - else JsonOutputParser() + else JsonOutputParser(return_message=True) ) elif method == "json_schema": if schema is None: @@ -1496,10 +1492,10 @@ class BaseChatOpenAI(BaseChatModel): llm = self.bind(response_format=response_format) if is_pydantic_schema: output_parser = _oai_structured_outputs_parser.with_types( - output_type=cast(type, schema) + output_type=AIMessage ) else: - output_parser = JsonOutputParser() + output_parser = JsonOutputParser(return_message=True) else: raise ValueError( f"Unrecognized method argument. Expected one of 'function_calling' or " @@ -1507,8 +1503,8 @@ class BaseChatOpenAI(BaseChatModel): ) if include_raw: - parser_assign = RunnablePassthrough.assign( - parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + parser_assign = RunnablePassthrough.assign(raw=output_parser).assign( + parsed=lambda x: x["raw"].parsed, parsing_error=lambda _: None ) parser_none = RunnablePassthrough.assign(parsed=lambda _: None) parser_with_fallback = parser_assign.with_fallbacks( @@ -2231,15 +2227,15 @@ def _convert_to_openai_response_format( @chain -def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel: - if ai_msg.additional_kwargs.get("parsed"): - return ai_msg.additional_kwargs["parsed"] +def _oai_structured_outputs_parser(ai_msg: AIMessage) -> AIMessage: + if ai_msg.parsed is not None: + return ai_msg elif ai_msg.additional_kwargs.get("refusal"): raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"]) else: raise ValueError( "Structured Output response does not have a 'parsed' field nor a 'refusal' " - "field. Received message:\n\n{ai_msg}" + f"field. Received message:\n\n{ai_msg}" )