rfc: AIMessage.parsed and with_structured_output(..., tools=[])
commit d85ece9fe3 (parent 5111063af2)
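For context, a sketch of the API surface this RFC proposes. It is illustrative, not part of the commit: the model name, prompt, and `Weather` schema are assumptions.

```python
from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Weather(BaseModel):
    city: str
    temperature_f: float


llm = ChatOpenAI(model="gpt-4o")  # illustrative model name

# Proposed: include_raw grows string modes. "raw_only" skips the output parser
# and returns the AIMessage itself; the parsed object rides along on the new
# AIMessage.parsed field (populated from OpenAI's parse API under json_schema).
structured_llm = llm.with_structured_output(
    Weather, method="json_schema", include_raw="raw_only"
)

msg = structured_llm.invoke("What is the weather in SF?")
print(msg.parsed)  # Weather(city=..., temperature_f=...) or None
```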
libs/core/langchain_core/messages/ai.py

@@ -2,7 +2,7 @@ import json
 import operator
 from typing import Any, Literal, Optional, Union, cast
 
-from pydantic import model_validator
+from pydantic import BaseModel, model_validator
 from typing_extensions import NotRequired, Self, TypedDict
 
 from langchain_core.messages.base import (
@@ -166,6 +166,8 @@ class AIMessage(BaseMessage):
 
     type: Literal["ai"] = "ai"
     """The type of the message (used for deserialization). Defaults to "ai"."""
+    parsed: Optional[Union[dict, BaseModel]] = None
+    """Structured output parsed from the model response, if any. Defaults to None."""
 
     def __init__(
         self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
@@ -440,6 +441,17 @@ def add_ai_message_chunks(
     else:
         usage_metadata = None
 
+    has_parsed = [m for m in [left, *others] if m.parsed]
+    if len(has_parsed) >= 2:
+        msg = (
+            "Cannot concatenate two AIMessageChunks with non-null 'parsed' attributes."
+        )
+        raise ValueError(msg)
+    elif len(has_parsed) == 1:
+        parsed = has_parsed[0].parsed
+    else:
+        parsed = None
+
     return left.__class__(
         example=left.example,
         content=content,
@@ -448,6 +460,7 @@ def add_ai_message_chunks(
         response_metadata=response_metadata,
         usage_metadata=usage_metadata,
         id=left.id,
+        parsed=parsed,
     )
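A sketch of the merge rule added above, assuming `AIMessageChunk` inherits the new `parsed` field from `AIMessage` (it does, via subclassing):

```python
from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(content="The answer ")
right = AIMessageChunk(content="is 42.", parsed={"answer": 42})

# `+` on chunks delegates to add_ai_message_chunks: at most one chunk
# may carry a non-null `parsed`, and that value wins.
merged = left + right
assert merged.parsed == {"answer": 42}

# Two non-null `parsed` values raise, since there is no sensible merge:
# AIMessageChunk(content="a", parsed={"x": 1}) \
#     + AIMessageChunk(content="b", parsed={"y": 2})  # ValueError
```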
libs/core/langchain_core/output_parsers/base.py

@@ -61,10 +61,12 @@ class BaseLLMOutputParser(Generic[T], ABC):
 
 
 class BaseGenerationOutputParser(
-    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
+    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, Union[AnyMessage, T]]
 ):
     """Base class to parse the output of an LLM call."""
 
+    return_message: bool = False
+
     @property
     @override
     def InputType(self) -> Any:
@@ -73,11 +75,14 @@ class BaseGenerationOutputParser(
 
     @property
     @override
-    def OutputType(self) -> type[T]:
+    def OutputType(self) -> Union[type[AnyMessage], type[T]]:
         """Return the output type for the parser."""
-        # even though mypy complains this isn't valid,
-        # it is good enough for pydantic to build the schema from
-        return T  # type: ignore[misc]
+        if self.return_message:
+            return AnyMessage
+        else:
+            # even though mypy complains this isn't valid,
+            # it is good enough for pydantic to build the schema from
+            return T  # type: ignore[misc]
 
     def invoke(
         self,
@@ -86,7 +91,7 @@ class BaseGenerationOutputParser(
         **kwargs: Any,
     ) -> T:
         if isinstance(input, BaseMessage):
-            return self._call_with_config(
+            parsed = self._call_with_config(
                 lambda inner_input: self.parse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -94,6 +99,10 @@ class BaseGenerationOutputParser(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return self._call_with_config(
                 lambda inner_input: self.parse_result([Generation(text=inner_input)]),
@@ -109,7 +116,7 @@ class BaseGenerationOutputParser(
         **kwargs: Optional[Any],
     ) -> T:
         if isinstance(input, BaseMessage):
-            return await self._acall_with_config(
+            parsed = await self._acall_with_config(
                 lambda inner_input: self.aparse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -117,6 +124,10 @@ class BaseGenerationOutputParser(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return await self._acall_with_config(
                 lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
@@ -127,7 +136,7 @@ class BaseGenerationOutputParser(
 
 
 class BaseOutputParser(
-    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
+    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, Union[AnyMessage, T]]
 ):
     """Base class to parse the output of an LLM call.
 
@@ -155,6 +164,8 @@ class BaseOutputParser(
                 return "boolean_output_parser"
     """  # noqa: E501
 
+    return_message: bool = False
+
     @property
     @override
     def InputType(self) -> Any:
@@ -163,7 +174,7 @@ class BaseOutputParser(
 
     @property
     @override
-    def OutputType(self) -> type[T]:
+    def OutputType(self) -> Union[type[AnyMessage], type[T]]:
         """Return the output type for the parser.
 
         This property is inferred from the first type argument of the class.
@@ -171,6 +182,9 @@ class BaseOutputParser(
         Raises:
             TypeError: If the class doesn't have an inferable OutputType.
         """
+        if self.return_message:
+            return AnyMessage
+
         for base in self.__class__.mro():
             if hasattr(base, "__pydantic_generic_metadata__"):
                 metadata = base.__pydantic_generic_metadata__
@@ -190,7 +204,7 @@ class BaseOutputParser(
         **kwargs: Any,
     ) -> T:
         if isinstance(input, BaseMessage):
-            return self._call_with_config(
+            parsed = self._call_with_config(
                 lambda inner_input: self.parse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -198,6 +212,10 @@ class BaseOutputParser(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return self._call_with_config(
                 lambda inner_input: self.parse_result([Generation(text=inner_input)]),
@@ -213,7 +231,7 @@ class BaseOutputParser(
         **kwargs: Optional[Any],
     ) -> T:
         if isinstance(input, BaseMessage):
-            return await self._acall_with_config(
+            parsed = await self._acall_with_config(
                 lambda inner_input: self.aparse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -221,6 +239,10 @@ class BaseOutputParser(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return await self._acall_with_config(
                 lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
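To see what `return_message` buys, a sketch using `StrOutputParser` as a stand-in `BaseOutputParser` subclass; the flag and the `.parsed` attachment are what the hunks above add:

```python
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser

# Default behavior is unchanged: invoke() returns the parsed value.
assert StrOutputParser().invoke(AIMessage(content="hi")) == "hi"

# With return_message=True (proposed), invoke() returns a copy of the input
# message with the parse result attached at .parsed, so response metadata,
# usage, ids, etc. survive the parsing step in a chain.
parser = StrOutputParser(return_message=True)
msg = parser.invoke(AIMessage(content="hi"))
assert isinstance(msg, AIMessage)
assert msg.parsed == "hi"
```

Threading `Union[AnyMessage, T]` through the base classes and `OutputType` is what lets `llm | parser` be typed as yielding a message rather than a bare `T` when the flag is set.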
libs/partners/openai/langchain_openai/chat_models/base.py

@@ -767,6 +767,7 @@ class BaseChatOpenAI(BaseChatModel):
         message = response.choices[0].message  # type: ignore[attr-defined]
         if hasattr(message, "parsed"):
             generations[0].message.additional_kwargs["parsed"] = message.parsed
+            cast(AIMessage, generations[0].message).parsed = message.parsed
         if hasattr(message, "refusal"):
             generations[0].message.additional_kwargs["refusal"] = message.refusal
 
@@ -1144,10 +1145,18 @@ class BaseChatOpenAI(BaseChatModel):
         method: Literal[
             "function_calling", "json_mode", "json_schema"
         ] = "function_calling",
-        include_raw: bool = False,
+        include_raw: Union[
+            bool, Literal["raw_only", "parsed_only", "raw_and_parsed"]
+        ] = False,
         strict: Optional[bool] = None,
+        tools: Optional[
+            Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+        ] = None,
+        tool_choice: Optional[
+            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+        ] = None,
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
+    ) -> Runnable[LanguageModelInput, Union[_DictOrPydantic, BaseMessage]]:
         """Model wrapper that returns outputs formatted to match the given schema.
 
         Args:
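A sketch of the widened signature in use. `lookup_population` is an invented tool for illustration; the point is that the structured-output schema no longer has to be the only tool bound to the model:

```python
from pydantic import BaseModel

from langchain_core.tools import tool
from langchain_openai import ChatOpenAI


class Answer(BaseModel):
    answer: str
    justification: str


@tool
def lookup_population(city: str) -> int:
    """Illustrative tool; name and body are assumptions."""
    return 873_965


llm = ChatOpenAI(model="gpt-4o")  # illustrative model name

# Proposed: extra tools bound alongside the schema. With tool_choice="auto"
# the model may call lookup_population on some turns and emit an Answer on
# others; include_raw="raw_and_parsed" behaves like include_raw=True today.
structured_llm = llm.with_structured_output(
    Answer,
    method="function_calling",
    tools=[lookup_population],
    tool_choice="auto",
    include_raw="raw_and_parsed",
)
```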
@@ -1432,12 +1441,19 @@ class BaseChatOpenAI(BaseChatModel):
                     "schema must be specified when method is not 'json_mode'. "
                     "Received None."
                 )
             tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            bind_kwargs = self._filter_disabled_params(
-                tool_choice=tool_name, parallel_tool_calls=False, strict=strict
-            )
+            if not tools:
+                bind_kwargs = self._filter_disabled_params(
+                    tool_choice=tool_name, parallel_tool_calls=False, strict=strict
+                )
 
-            llm = self.bind_tools([schema], **bind_kwargs)
+                llm = self.bind_tools([schema], **bind_kwargs)
+            else:
+                # tool_name stays in scope: JsonOutputKeyToolsParser below keys on it
+                bind_kwargs = self._filter_disabled_params(
+                    strict=strict, tool_choice=tool_choice
+                )
+                llm = self.bind_tools([schema, *tools], **bind_kwargs)
             if is_pydantic_schema:
                 output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -1448,7 +1464,15 @@ class BaseChatOpenAI(BaseChatModel):
                 key_name=tool_name, first_tool_only=True
             )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            if not tools:
+                llm = self.bind(response_format={"type": "json_object"})
+            else:
+                bind_kwargs = self._filter_disabled_params(
+                    strict=strict,
+                    tool_choice=tool_choice,
+                    response_format={"type": "json_object"},
+                )
+                llm = self.bind_tools(tools, **bind_kwargs)
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
@@ -1461,7 +1485,15 @@ class BaseChatOpenAI(BaseChatModel):
                     "Received None."
                 )
             response_format = _convert_to_openai_response_format(schema, strict=strict)
-            llm = self.bind(response_format=response_format)
+            if not tools:
+                llm = self.bind(response_format=response_format)
+            else:
+                bind_kwargs = self._filter_disabled_params(
+                    strict=strict,
+                    tool_choice=tool_choice,
+                    response_format=response_format,
+                )
+                llm = self.bind_tools(tools, **bind_kwargs)
             if is_pydantic_schema:
                 output_parser = _oai_structured_outputs_parser.with_types(
                     output_type=cast(type, schema)
@@ -1474,7 +1506,7 @@ class BaseChatOpenAI(BaseChatModel):
                 f"'json_mode'. Received: '{method}'"
             )
 
-        if include_raw:
+        if include_raw is True or include_raw == "raw_and_parsed":
             parser_assign = RunnablePassthrough.assign(
                 parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
             )
@@ -1483,6 +1515,8 @@ class BaseChatOpenAI(BaseChatModel):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
+        elif include_raw == "raw_only":
+            return llm
         else:
             return llm | output_parser
 
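Summing up the branching above, the return shape of the runnable now depends on `include_raw` (as proposed; note `"parsed_only"` falls through to the final `else` and behaves like `False`):

```python
# include_raw=True or "raw_and_parsed":
#     {"raw": AIMessage, "parsed": _DictOrPydantic | None, "parsing_error": Exception | None}
# include_raw="raw_only":
#     AIMessage (the bound model, no output parser applied)
# include_raw=False or "parsed_only":
#     _DictOrPydantic (llm | output_parser, today's default)
```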
@@ -2174,7 +2208,9 @@ def _convert_to_openai_response_format(
 
 @chain
 def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
-    if ai_msg.additional_kwargs.get("parsed"):
+    if ai_msg.parsed:
+        return cast(PydanticBaseModel, ai_msg.parsed)
+    elif ai_msg.additional_kwargs.get("parsed"):
         return ai_msg.additional_kwargs["parsed"]
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
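The parser now prefers the first-class field and keeps `additional_kwargs["parsed"]` as a fallback for messages produced before this change; refusals still raise. A sketch of the precedence:

```python
from langchain_core.messages import AIMessage

msg = AIMessage(content="", parsed={"answer": "42"})
# _oai_structured_outputs_parser checks, in order:
#   1. msg.parsed                        (new first-class field)
#   2. msg.additional_kwargs["parsed"]   (backward compatibility)
#   3. msg.additional_kwargs["refusal"]  -> raises OpenAIRefusalError
```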