rfc: AIMessage.parsed and with_structured_output(..., tools=[])

Bagatur 2024-10-29 19:35:51 -07:00
parent 5111063af2
commit d85ece9fe3
3 changed files with 94 additions and 23 deletions

File 1 of 3: AIMessage and add_ai_message_chunks (langchain_core messages)

@@ -2,7 +2,7 @@ import json
 import operator
 from typing import Any, Literal, Optional, Union, cast

-from pydantic import model_validator
+from pydantic import BaseModel, model_validator
 from typing_extensions import NotRequired, Self, TypedDict

 from langchain_core.messages.base import (
@@ -166,6 +166,7 @@ class AIMessage(BaseMessage):
     type: Literal["ai"] = "ai"
     """The type of the message (used for deserialization). Defaults to "ai"."""
+    parsed: Optional[Union[dict, BaseModel]] = None

     def __init__(
         self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
@@ -440,6 +441,17 @@ def add_ai_message_chunks(
     else:
         usage_metadata = None

+    has_parsed = [m for m in ([left, *others]) if m.parsed]
+    if len(has_parsed) >= 2:
+        msg = (
+            "Cannot concatenate two AIMessageChunks with non-null 'parsed' attributes."
+        )
+        raise ValueError(msg)
+    elif len(has_parsed) == 1:
+        parsed = has_parsed[0].parsed
+    else:
+        parsed = None
+
     return left.__class__(
         example=left.example,
         content=content,
@@ -448,6 +460,7 @@
         response_metadata=response_metadata,
         usage_metadata=usage_metadata,
         id=left.id,
+        parsed=parsed,
     )
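Illustrative usage of the chunk-merge rule above, as a minimal sketch assuming this RFC is applied; the chunk contents and parsed payloads are invented:

```python
from langchain_core.messages import AIMessageChunk

# At most one chunk in a concatenation may carry a non-null `parsed`.
left = AIMessageChunk(content="partial ")
right = AIMessageChunk(content="answer", parsed={"answer": 42})

merged = left + right  # dispatches to add_ai_message_chunks
assert merged.parsed == {"answer": 42}  # the lone non-null `parsed` survives

try:
    right + AIMessageChunk(content="", parsed={"answer": 43})
except ValueError:
    # two non-null `parsed` attributes cannot be concatenated
    pass
```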

File 2 of 3: BaseGenerationOutputParser and BaseOutputParser (langchain_core output parsers)

@@ -61,10 +61,12 @@ class BaseLLMOutputParser(Generic[T], ABC):
 class BaseGenerationOutputParser(
-    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
+    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, Union[AnyMessage, T]]
 ):
     """Base class to parse the output of an LLM call."""

+    return_message: bool = False
+
     @property
     @override
     def InputType(self) -> Any:
@@ -73,11 +75,14 @@ class BaseGenerationOutputParser(
     @property
     @override
-    def OutputType(self) -> type[T]:
+    def OutputType(self) -> Union[type[AnyMessage], type[T]]:
         """Return the output type for the parser."""
-        # even though mypy complains this isn't valid,
-        # it is good enough for pydantic to build the schema from
-        return T  # type: ignore[misc]
+        if self.return_message:
+            return AnyMessage
+        else:
+            # even though mypy complains this isn't valid,
+            # it is good enough for pydantic to build the schema from
+            return T  # type: ignore[misc]

     def invoke(
         self,
@@ -86,7 +91,7 @@ class BaseGenerationOutputParser(
         **kwargs: Any,
     ) -> T:
         if isinstance(input, BaseMessage):
-            return self._call_with_config(
+            parsed = self._call_with_config(
                 lambda inner_input: self.parse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -94,6 +99,10 @@ class BaseGenerationOutputParser(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return self._call_with_config(
                 lambda inner_input: self.parse_result([Generation(text=inner_input)]),
@@ -109,7 +116,7 @@ class BaseGenerationOutputParser(
         **kwargs: Optional[Any],
     ) -> T:
         if isinstance(input, BaseMessage):
-            return await self._acall_with_config(
+            parsed = await self._acall_with_config(
                 lambda inner_input: self.aparse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -117,6 +124,10 @@ class BaseGenerationOutputParser(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return await self._acall_with_config(
                 lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
@@ -127,7 +136,7 @@
 class BaseOutputParser(
-    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
+    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, Union[AnyMessage, T]]
 ):
     """Base class to parse the output of an LLM call.
@@ -155,6 +164,8 @@ class BaseOutputParser(
             return "boolean_output_parser"
     """  # noqa: E501

+    return_message: bool = False
+
     @property
     @override
     def InputType(self) -> Any:
@@ -163,7 +174,7 @@ class BaseOutputParser(
     @property
     @override
-    def OutputType(self) -> type[T]:
+    def OutputType(self) -> Union[type[AnyMessage], type[T]]:
         """Return the output type for the parser.

         This property is inferred from the first type argument of the class.
@@ -171,6 +182,9 @@ class BaseOutputParser(
         Raises:
             TypeError: If the class doesn't have an inferable OutputType.
         """
+        if self.return_message:
+            return AnyMessage
+
         for base in self.__class__.mro():
             if hasattr(base, "__pydantic_generic_metadata__"):
                 metadata = base.__pydantic_generic_metadata__
@@ -190,7 +204,7 @@ class BaseOutputParser(
         **kwargs: Any,
     ) -> T:
         if isinstance(input, BaseMessage):
-            return self._call_with_config(
+            parsed = self._call_with_config(
                 lambda inner_input: self.parse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -198,6 +212,10 @@ class BaseOutputParser(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return self._call_with_config(
                 lambda inner_input: self.parse_result([Generation(text=inner_input)]),
@@ -213,7 +231,7 @@ class BaseOutputParser(
         **kwargs: Optional[Any],
     ) -> T:
         if isinstance(input, BaseMessage):
-            return await self._acall_with_config(
+            parsed = await self._acall_with_config(
                 lambda inner_input: self.aparse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -221,6 +239,10 @@ class BaseOutputParser(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return await self._acall_with_config(
                 lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
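For context, a sketch of how `return_message` changes parser output, assuming these base-class changes land; `JsonOutputParser` (a `BaseOutputParser` subclass) stands in as the concrete parser:

```python
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import JsonOutputParser

parser = JsonOutputParser(return_message=True)
msg = AIMessage(content='{"answer": 42}')

out = parser.invoke(msg)
assert isinstance(out, AIMessage)    # the message is returned, not the bare dict
assert out.parsed == {"answer": 42}  # parse result attached via model_copy

# Default behavior is unchanged:
assert JsonOutputParser().invoke(msg) == {"answer": 42}
```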

File 3 of 3: BaseChatOpenAI.with_structured_output (langchain_openai)

@@ -767,6 +767,7 @@ class BaseChatOpenAI(BaseChatModel):
         message = response.choices[0].message  # type: ignore[attr-defined]
         if hasattr(message, "parsed"):
             generations[0].message.additional_kwargs["parsed"] = message.parsed
+            cast(AIMessage, generations[0].message).parsed = message.parsed
         if hasattr(message, "refusal"):
             generations[0].message.additional_kwargs["refusal"] = message.refusal
@@ -1144,10 +1145,18 @@ class BaseChatOpenAI(BaseChatModel):
         method: Literal[
             "function_calling", "json_mode", "json_schema"
         ] = "function_calling",
-        include_raw: bool = False,
+        include_raw: Union[
+            bool, Literal["raw_only", "parsed_only", "raw_and_parsed"]
+        ] = False,
         strict: Optional[bool] = None,
+        tools: Optional[
+            Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+        ] = None,
+        tool_choice: Optional[
+            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+        ] = None,
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
+    ) -> Runnable[LanguageModelInput, Union[_DictOrPydantic, BaseMessage]]:
         """Model wrapper that returns outputs formatted to match the given schema.

         Args:
@@ -1432,12 +1441,19 @@ class BaseChatOpenAI(BaseChatModel):
                     "schema must be specified when method is not 'json_mode'. "
                     "Received None."
                 )
-            tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            bind_kwargs = self._filter_disabled_params(
-                tool_choice=tool_name, parallel_tool_calls=False, strict=strict
-            )
-            llm = self.bind_tools([schema], **bind_kwargs)
+            if not tools:
+                tool_name = convert_to_openai_tool(schema)["function"]["name"]
+                bind_kwargs = self._filter_disabled_params(
+                    tool_choice=tool_name, parallel_tool_calls=False, strict=strict
+                )
+                llm = self.bind_tools([schema], **bind_kwargs)
+            else:
+                bind_kwargs = self._filter_disabled_params(
+                    strict=strict, tool_choice=tool_choice
+                )
+                llm = self.bind_tools([schema, *tools], **bind_kwargs)
             if is_pydantic_schema:
                 output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -1448,7 +1464,15 @@ class BaseChatOpenAI(BaseChatModel):
                     key_name=tool_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            if not tools:
+                llm = self.bind(response_format={"type": "json_object"})
+            else:
+                bind_kwargs = self._filter_disabled_params(
+                    strict=strict,
+                    tool_choice=tool_choice,
+                    response_format={"type": "json_object"},
+                )
+                llm = self.bind_tools(tools, **bind_kwargs)
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
@@ -1461,7 +1485,15 @@ class BaseChatOpenAI(BaseChatModel):
                     "Received None."
                 )
             response_format = _convert_to_openai_response_format(schema, strict=strict)
-            llm = self.bind(response_format=response_format)
+            if not tools:
+                llm = self.bind(response_format=response_format)
+            else:
+                bind_kwargs = self._filter_disabled_params(
+                    strict=strict,
+                    tool_choice=tool_choice,
+                    response_format=response_format,
+                )
+                llm = self.bind_tools(tools, **bind_kwargs)
             if is_pydantic_schema:
                 output_parser = _oai_structured_outputs_parser.with_types(
                     output_type=cast(type, schema)
@@ -1474,7 +1506,7 @@ class BaseChatOpenAI(BaseChatModel):
                 f"'json_mode'. Received: '{method}'"
             )

-        if include_raw:
+        if include_raw is True or include_raw == "raw_and_parsed":
             parser_assign = RunnablePassthrough.assign(
                 parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
             )
@@ -1483,6 +1515,8 @@ class BaseChatOpenAI(BaseChatModel):
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
+        elif include_raw == "raw_only":
+            return llm
        else:
            return llm | output_parser
@@ -2174,7 +2208,9 @@ def _convert_to_openai_response_format(
 @chain
 def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
-    if ai_msg.additional_kwargs.get("parsed"):
+    if ai_msg.parsed:
+        return cast(PydanticBaseModel, ai_msg.parsed)
+    elif ai_msg.additional_kwargs.get("parsed"):
         return ai_msg.additional_kwargs["parsed"]
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
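Putting the pieces together, a hedged usage sketch of the extended signature under this RFC's semantics; `Answer`, `lookup_population`, and the model name are invented for illustration:

```python
from pydantic import BaseModel
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

class Answer(BaseModel):
    answer: str
    justification: str

@tool
def lookup_population(city: str) -> int:
    """Look up a city's current population."""
    raise NotImplementedError

llm = ChatOpenAI(model="gpt-4o-mini")
structured_llm = llm.with_structured_output(
    Answer,
    method="json_schema",
    tools=[lookup_population],  # model may call this instead of answering
    tool_choice="auto",
    include_raw="raw_only",     # return the raw AIMessage, skip the parser
)

msg = structured_llm.invoke("How many people live in Paris?")
# Under this RFC: either msg.tool_calls is non-empty (the model chose the
# tool) or msg.parsed holds the structured Answer payload.
```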