mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 20:58:25 +00:00
fmt
This commit is contained in:
parent
47b386d28f
commit
87d8012ef6
@ -9,6 +9,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
@ -65,6 +66,8 @@ class BaseGenerationOutputParser(
|
|||||||
):
|
):
|
||||||
"""Base class to parse the output of an LLM call."""
|
"""Base class to parse the output of an LLM call."""
|
||||||
|
|
||||||
|
return_message: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@override
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
@ -75,9 +78,12 @@ class BaseGenerationOutputParser(
|
|||||||
@override
|
@override
|
||||||
def OutputType(self) -> type[T]:
|
def OutputType(self) -> type[T]:
|
||||||
"""Return the output type for the parser."""
|
"""Return the output type for the parser."""
|
||||||
# even though mypy complains this isn't valid,
|
if self.return_message:
|
||||||
# it is good enough for pydantic to build the schema from
|
return cast(type[T], AnyMessage)
|
||||||
return T # type: ignore[misc]
|
else:
|
||||||
|
# even though mypy complains this isn't valid,
|
||||||
|
# it is good enough for pydantic to build the schema from
|
||||||
|
return T # type: ignore[misc]
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
@ -86,7 +92,7 @@ class BaseGenerationOutputParser(
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> T:
|
) -> T:
|
||||||
if isinstance(input, BaseMessage):
|
if isinstance(input, BaseMessage):
|
||||||
return self._call_with_config(
|
parsed = self._call_with_config(
|
||||||
lambda inner_input: self.parse_result(
|
lambda inner_input: self.parse_result(
|
||||||
[ChatGeneration(message=inner_input)]
|
[ChatGeneration(message=inner_input)]
|
||||||
),
|
),
|
||||||
@ -94,6 +100,10 @@ class BaseGenerationOutputParser(
|
|||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
if self.return_message:
|
||||||
|
return cast(T, input.model_copy(update={"parsed": parsed}))
|
||||||
|
else:
|
||||||
|
return parsed
|
||||||
else:
|
else:
|
||||||
return self._call_with_config(
|
return self._call_with_config(
|
||||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||||
@ -109,7 +119,7 @@ class BaseGenerationOutputParser(
|
|||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> T:
|
) -> T:
|
||||||
if isinstance(input, BaseMessage):
|
if isinstance(input, BaseMessage):
|
||||||
return await self._acall_with_config(
|
parsed = await self._acall_with_config(
|
||||||
lambda inner_input: self.aparse_result(
|
lambda inner_input: self.aparse_result(
|
||||||
[ChatGeneration(message=inner_input)]
|
[ChatGeneration(message=inner_input)]
|
||||||
),
|
),
|
||||||
@ -117,6 +127,10 @@ class BaseGenerationOutputParser(
|
|||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
if self.return_message:
|
||||||
|
return cast(T, input.model_copy(update={"parsed": parsed}))
|
||||||
|
else:
|
||||||
|
return parsed
|
||||||
else:
|
else:
|
||||||
return await self._acall_with_config(
|
return await self._acall_with_config(
|
||||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||||
@ -155,6 +169,8 @@ class BaseOutputParser(
|
|||||||
return "boolean_output_parser"
|
return "boolean_output_parser"
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
|
return_message: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@override
|
@override
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
@ -171,6 +187,9 @@ class BaseOutputParser(
|
|||||||
Raises:
|
Raises:
|
||||||
TypeError: If the class doesn't have an inferable OutputType.
|
TypeError: If the class doesn't have an inferable OutputType.
|
||||||
"""
|
"""
|
||||||
|
if self.return_message:
|
||||||
|
return cast(type[T], AnyMessage)
|
||||||
|
|
||||||
for base in self.__class__.mro():
|
for base in self.__class__.mro():
|
||||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||||
metadata = base.__pydantic_generic_metadata__
|
metadata = base.__pydantic_generic_metadata__
|
||||||
@ -190,7 +209,7 @@ class BaseOutputParser(
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> T:
|
) -> T:
|
||||||
if isinstance(input, BaseMessage):
|
if isinstance(input, BaseMessage):
|
||||||
return self._call_with_config(
|
parsed = self._call_with_config(
|
||||||
lambda inner_input: self.parse_result(
|
lambda inner_input: self.parse_result(
|
||||||
[ChatGeneration(message=inner_input)]
|
[ChatGeneration(message=inner_input)]
|
||||||
),
|
),
|
||||||
@ -198,6 +217,10 @@ class BaseOutputParser(
|
|||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
if self.return_message:
|
||||||
|
return cast(T, input.model_copy(update={"parsed": parsed}))
|
||||||
|
else:
|
||||||
|
return parsed
|
||||||
else:
|
else:
|
||||||
return self._call_with_config(
|
return self._call_with_config(
|
||||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||||
@ -213,7 +236,7 @@ class BaseOutputParser(
|
|||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> T:
|
) -> T:
|
||||||
if isinstance(input, BaseMessage):
|
if isinstance(input, BaseMessage):
|
||||||
return await self._acall_with_config(
|
parsed = await self._acall_with_config(
|
||||||
lambda inner_input: self.aparse_result(
|
lambda inner_input: self.aparse_result(
|
||||||
[ChatGeneration(message=inner_input)]
|
[ChatGeneration(message=inner_input)]
|
||||||
),
|
),
|
||||||
@ -221,6 +244,10 @@ class BaseOutputParser(
|
|||||||
config,
|
config,
|
||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
if self.return_message:
|
||||||
|
return cast(T, input.model_copy(update={"parsed": parsed}))
|
||||||
|
else:
|
||||||
|
return parsed
|
||||||
else:
|
else:
|
||||||
return await self._acall_with_config(
|
return await self._acall_with_config(
|
||||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||||
|
@ -10,7 +10,6 @@ import sys
|
|||||||
import warnings
|
import warnings
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from operator import itemgetter
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
@ -85,11 +84,7 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_function,
|
convert_to_openai_function,
|
||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||||
PydanticBaseModel,
|
|
||||||
TypeBaseModel,
|
|
||||||
is_basemodel_subclass,
|
|
||||||
)
|
|
||||||
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
|
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
|
||||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
@ -777,7 +772,7 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
):
|
):
|
||||||
message = response.choices[0].message # type: ignore[attr-defined]
|
message = response.choices[0].message # type: ignore[attr-defined]
|
||||||
if hasattr(message, "parsed"):
|
if hasattr(message, "parsed"):
|
||||||
generations[0].message.parsed = message.parsed
|
cast(AIMessage, generations[0].message).parsed = message.parsed
|
||||||
# For backwards compatibility.
|
# For backwards compatibility.
|
||||||
generations[0].message.additional_kwargs["parsed"] = message.parsed
|
generations[0].message.additional_kwargs["parsed"] = message.parsed
|
||||||
if hasattr(message, "refusal"):
|
if hasattr(message, "refusal"):
|
||||||
@ -1474,17 +1469,18 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
output_parser: Runnable = PydanticToolsParser(
|
output_parser: Runnable = PydanticToolsParser(
|
||||||
tools=[schema], # type: ignore[list-item]
|
tools=[schema], # type: ignore[list-item]
|
||||||
first_tool_only=True, # type: ignore[list-item]
|
first_tool_only=True, # type: ignore[list-item]
|
||||||
|
return_message=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_parser = JsonOutputKeyToolsParser(
|
output_parser = JsonOutputKeyToolsParser(
|
||||||
key_name=tool_name, first_tool_only=True
|
key_name=tool_name, first_tool_only=True, return_message=True
|
||||||
)
|
)
|
||||||
elif method == "json_mode":
|
elif method == "json_mode":
|
||||||
llm = self.bind(response_format={"type": "json_object"})
|
llm = self.bind(response_format={"type": "json_object"})
|
||||||
output_parser = (
|
output_parser = (
|
||||||
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
|
PydanticOutputParser(pydantic_object=schema, return_message=True) # type: ignore[arg-type]
|
||||||
if is_pydantic_schema
|
if is_pydantic_schema
|
||||||
else JsonOutputParser()
|
else JsonOutputParser(return_message=True)
|
||||||
)
|
)
|
||||||
elif method == "json_schema":
|
elif method == "json_schema":
|
||||||
if schema is None:
|
if schema is None:
|
||||||
@ -1496,10 +1492,10 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
llm = self.bind(response_format=response_format)
|
llm = self.bind(response_format=response_format)
|
||||||
if is_pydantic_schema:
|
if is_pydantic_schema:
|
||||||
output_parser = _oai_structured_outputs_parser.with_types(
|
output_parser = _oai_structured_outputs_parser.with_types(
|
||||||
output_type=cast(type, schema)
|
output_type=AIMessage
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_parser = JsonOutputParser()
|
output_parser = JsonOutputParser(return_message=True)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||||
@ -1507,8 +1503,8 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if include_raw:
|
if include_raw:
|
||||||
parser_assign = RunnablePassthrough.assign(
|
parser_assign = RunnablePassthrough.assign(raw=output_parser).assign(
|
||||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
parsed=lambda x: x["raw"].parsed, parsing_error=lambda _: None
|
||||||
)
|
)
|
||||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||||
parser_with_fallback = parser_assign.with_fallbacks(
|
parser_with_fallback = parser_assign.with_fallbacks(
|
||||||
@ -2231,15 +2227,15 @@ def _convert_to_openai_response_format(
|
|||||||
|
|
||||||
|
|
||||||
@chain
|
@chain
|
||||||
def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
|
def _oai_structured_outputs_parser(ai_msg: AIMessage) -> AIMessage:
|
||||||
if ai_msg.additional_kwargs.get("parsed"):
|
if ai_msg.parsed is not None:
|
||||||
return ai_msg.additional_kwargs["parsed"]
|
return ai_msg
|
||||||
elif ai_msg.additional_kwargs.get("refusal"):
|
elif ai_msg.additional_kwargs.get("refusal"):
|
||||||
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
|
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
|
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
|
||||||
"field. Received message:\n\n{ai_msg}"
|
f"field. Received message:\n\n{ai_msg}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user