core,anthropic[patch]: fix with_structured_output typing (#28950)

This commit is contained in:
Bagatur 2024-12-28 15:46:51 -05:00 committed by GitHub
parent ccf69368b4
commit edbe7d5f5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 10 additions and 11 deletions

View File

@ -233,7 +233,7 @@ class BaseLanguageModel(
""" """
def with_structured_output( def with_structured_output(
self, schema: Union[dict, type[BaseModel]], **kwargs: Any self, schema: Union[dict, type], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Not implemented on this class.""" """Not implemented on this class."""
# Implement this on child class if there is a way of steering the model to # Implement this on child class if there is a way of steering the model to

View File

@ -1128,7 +1128,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
The output schema. Can be passed in as: The output schema. Can be passed in as:
- an OpenAI function/tool schema, - an OpenAI function/tool schema,
- a JSON Schema, - a JSON Schema,
- a TypedDict class (support added in 0.2.26), - a TypedDict class,
- or a Pydantic class. - or a Pydantic class.
If ``schema`` is a Pydantic class then the model output will be a If ``schema`` is a Pydantic class then the model output will be a
Pydantic instance of that class, and the model-generated fields will be Pydantic instance of that class, and the model-generated fields will be
@ -1137,10 +1137,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
for more on how to properly specify types and descriptions of for more on how to properly specify types and descriptions of
schema fields when specifying a Pydantic or TypedDict class. schema fields when specifying a Pydantic or TypedDict class.
.. versionchanged:: 0.2.26
Added support for TypedDict class.
include_raw: include_raw:
If False then only the parsed structured output is returned. If If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True an error occurs during model output parsing it will be raised. If True
@ -1222,6 +1218,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# 'answer': 'They weigh the same', # 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# } # }
.. versionchanged:: 0.2.26
Added support for TypedDict class.
""" # noqa: E501 """ # noqa: E501
if kwargs: if kwargs:
msg = f"Received unsupported arguments {kwargs}" msg = f"Received unsupported arguments {kwargs}"

View File

@ -28,7 +28,7 @@ from langchain_core.utils import get_pydantic_field_names
class StructuredPrompt(ChatPromptTemplate): class StructuredPrompt(ChatPromptTemplate):
"""Structured prompt template for a language model.""" """Structured prompt template for a language model."""
schema_: Union[dict, type[BaseModel]] schema_: Union[dict, type]
"""Schema for the structured prompt.""" """Schema for the structured prompt."""
structured_output_kwargs: dict[str, Any] = Field(default_factory=dict) structured_output_kwargs: dict[str, Any] = Field(default_factory=dict)
@ -66,7 +66,7 @@ class StructuredPrompt(ChatPromptTemplate):
def from_messages_and_schema( def from_messages_and_schema(
cls, cls,
messages: Sequence[MessageLikeRepresentation], messages: Sequence[MessageLikeRepresentation],
schema: Union[dict, type[BaseModel]], schema: Union[dict, type],
**kwargs: Any, **kwargs: Any,
) -> ChatPromptTemplate: ) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats. """Create a chat prompt template from a variety of message formats.

View File

@ -16,7 +16,6 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
Type, Type,
TypedDict,
Union, Union,
cast, cast,
) )
@ -72,7 +71,7 @@ from pydantic import (
SecretStr, SecretStr,
model_validator, model_validator,
) )
from typing_extensions import NotRequired from typing_extensions import NotRequired, TypedDict
from langchain_anthropic.output_parsers import extract_tool_calls from langchain_anthropic.output_parsers import extract_tool_calls
@ -973,7 +972,7 @@ class ChatAnthropic(BaseChatModel):
def with_structured_output( def with_structured_output(
self, self,
schema: Union[Dict, Type[BaseModel]], schema: Union[Dict, type],
*, *,
include_raw: bool = False, include_raw: bool = False,
**kwargs: Any, **kwargs: Any,