anthropic: update pydantic (#26000)

Migrated with gritql: https://github.com/eyurtsev/migrate-pydantic
This commit is contained in:
ccurme
2024-09-04 13:35:51 -04:00
committed by GitHub
parent 7cee7fbfad
commit 163d6fe8ef
12 changed files with 84 additions and 110 deletions

View File

@@ -49,12 +49,6 @@ from langchain_core.output_parsers import (
) )
from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import ( from langchain_core.runnables import (
Runnable, Runnable,
RunnableMap, RunnableMap,
@@ -69,7 +63,15 @@ from langchain_core.utils import (
) )
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_subclass
from typing_extensions import NotRequired from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import NotRequired, Self
from langchain_anthropic.output_parsers import extract_tool_calls from langchain_anthropic.output_parsers import extract_tool_calls
@@ -114,7 +116,7 @@ def _merge_messages(
"""Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501 """Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501
merged: list = [] merged: list = []
for curr in messages: for curr in messages:
curr = curr.copy(deep=True) curr = curr.model_copy(deep=True)
if isinstance(curr, ToolMessage): if isinstance(curr, ToolMessage):
if isinstance(curr.content, list) and all( if isinstance(curr.content, list) and all(
isinstance(block, dict) and block.get("type") == "tool_result" isinstance(block, dict) and block.get("type") == "tool_result"
@@ -383,7 +385,7 @@ class ChatAnthropic(BaseChatModel):
Tool calling: Tool calling:
.. code-block:: python .. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field from pydantic import BaseModel, Field
class GetWeather(BaseModel): class GetWeather(BaseModel):
'''Get the current weather in a given location''' '''Get the current weather in a given location'''
@@ -421,7 +423,7 @@ class ChatAnthropic(BaseChatModel):
from typing import Optional from typing import Optional
from langchain_core.pydantic_v1 import BaseModel, Field from pydantic import BaseModel, Field
class Joke(BaseModel): class Joke(BaseModel):
'''Joke to tell user.''' '''Joke to tell user.'''
@@ -508,13 +510,12 @@ class ChatAnthropic(BaseChatModel):
""" # noqa: E501 """ # noqa: E501
class Config: model_config = ConfigDict(
"""Configuration for this pydantic object.""" populate_by_name=True,
)
allow_population_by_field_name = True _client: anthropic.Client = PrivateAttr(default=None)
_async_client: anthropic.AsyncClient = PrivateAttr(default=None)
_client: anthropic.Client = Field(default=None)
_async_client: anthropic.AsyncClient = Field(default=None)
model: str = Field(alias="model_name") model: str = Field(alias="model_name")
"""Model name to use.""" """Model name to use."""
@@ -626,8 +627,9 @@ class ChatAnthropic(BaseChatModel):
ls_params["ls_stop"] = ls_stop ls_params["ls_stop"] = ls_stop
return ls_params return ls_params
@root_validator(pre=True) @model_validator(mode="before")
def build_extra(cls, values: Dict) -> Dict: @classmethod
def build_extra(cls, values: Dict) -> Any:
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs( values["model_kwargs"] = build_extra_kwargs(
@@ -635,28 +637,25 @@ class ChatAnthropic(BaseChatModel):
) )
return values return values
@root_validator(pre=False, skip_on_failure=True) @model_validator(mode="after")
def post_init(cls, values: Dict) -> Dict: def post_init(self) -> Self:
api_key = values["anthropic_api_key"].get_secret_value() api_key = self.anthropic_api_key.get_secret_value()
api_url = values["anthropic_api_url"] api_url = self.anthropic_api_url
client_params = { client_params: Dict[str, Any] = {
"api_key": api_key, "api_key": api_key,
"base_url": api_url, "base_url": api_url,
"max_retries": values["max_retries"], "max_retries": self.max_retries,
"default_headers": values.get("default_headers"), "default_headers": (self.default_headers or None),
} }
# value <= 0 indicates the param should be ignored. None is a meaningful value # value <= 0 indicates the param should be ignored. None is a meaningful value
# for Anthropic client and treated differently than not specifying the param at # for Anthropic client and treated differently than not specifying the param at
# all. # all.
if ( if self.default_request_timeout is None or self.default_request_timeout > 0:
values["default_request_timeout"] is None client_params["timeout"] = self.default_request_timeout
or values["default_request_timeout"] > 0
):
client_params["timeout"] = values["default_request_timeout"]
values["_client"] = anthropic.Client(**client_params) self._client = anthropic.Client(**client_params)
values["_async_client"] = anthropic.AsyncClient(**client_params) self._async_client = anthropic.AsyncClient(**client_params)
return values return self
def _get_request_payload( def _get_request_payload(
self, self,
@@ -825,7 +824,7 @@ class ChatAnthropic(BaseChatModel):
.. code-block:: python .. code-block:: python
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel, Field from pydantic import BaseModel, Field
class GetWeather(BaseModel): class GetWeather(BaseModel):
'''Get the current weather in a given location''' '''Get the current weather in a given location'''
@@ -854,7 +853,7 @@ class ChatAnthropic(BaseChatModel):
.. code-block:: python .. code-block:: python
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel, Field from pydantic import BaseModel, Field
class GetWeather(BaseModel): class GetWeather(BaseModel):
'''Get the current weather in a given location''' '''Get the current weather in a given location'''
@@ -876,7 +875,7 @@ class ChatAnthropic(BaseChatModel):
.. code-block:: python .. code-block:: python
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel, Field from pydantic import BaseModel, Field
class GetWeather(BaseModel): class GetWeather(BaseModel):
'''Get the current weather in a given location''' '''Get the current weather in a given location'''
@@ -897,7 +896,7 @@ class ChatAnthropic(BaseChatModel):
.. code-block:: python .. code-block:: python
from langchain_anthropic import ChatAnthropic, convert_to_anthropic_tool from langchain_anthropic import ChatAnthropic, convert_to_anthropic_tool
from langchain_core.pydantic_v1 import BaseModel, Field from pydantic import BaseModel, Field
class GetWeather(BaseModel): class GetWeather(BaseModel):
'''Get the current weather in a given location''' '''Get the current weather in a given location'''
@@ -1007,7 +1006,7 @@ class ChatAnthropic(BaseChatModel):
.. code-block:: python .. code-block:: python
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel from pydantic import BaseModel
class AnswerWithJustification(BaseModel): class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.''' '''An answer to the user question along with justification for the answer.'''
@@ -1028,7 +1027,7 @@ class ChatAnthropic(BaseChatModel):
.. code-block:: python .. code-block:: python
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel from pydantic import BaseModel
class AnswerWithJustification(BaseModel): class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.''' '''An answer to the user question along with justification for the answer.'''

View File

@@ -7,7 +7,7 @@ from typing import (
) )
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.pydantic_v1 import Field from pydantic import PrivateAttr
from langchain_anthropic.chat_models import ChatAnthropic from langchain_anthropic.chat_models import ChatAnthropic
@@ -156,4 +156,4 @@ def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
class ChatAnthropicTools(ChatAnthropic): class ChatAnthropicTools(ChatAnthropic):
"""Chat model for interacting with Anthropic functions.""" """Chat model for interacting with Anthropic functions."""
_xmllib: Any = Field(default=None) _xmllib: Any = PrivateAttr(default=None)

View File

@@ -21,7 +21,6 @@ from langchain_core.language_models import BaseLanguageModel, LangSmithParams
from langchain_core.language_models.llms import LLM from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk from langchain_core.outputs import GenerationChunk
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import ( from langchain_core.utils import (
get_pydantic_field_names, get_pydantic_field_names,
) )
@@ -30,6 +29,8 @@ from langchain_core.utils.utils import (
from_env, from_env,
secret_from_env, secret_from_env,
) )
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
class _AnthropicCommon(BaseLanguageModel): class _AnthropicCommon(BaseLanguageModel):
@@ -84,8 +85,9 @@ class _AnthropicCommon(BaseLanguageModel):
count_tokens: Optional[Callable[[str], int]] = None count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@root_validator(pre=True) @model_validator(mode="before")
def build_extra(cls, values: Dict) -> Dict: @classmethod
def build_extra(cls, values: Dict) -> Any:
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs( values["model_kwargs"] = build_extra_kwargs(
@@ -93,25 +95,25 @@ class _AnthropicCommon(BaseLanguageModel):
) )
return values return values
@root_validator(pre=False, skip_on_failure=True) @model_validator(mode="after")
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["client"] = anthropic.Anthropic( self.client = anthropic.Anthropic(
base_url=values["anthropic_api_url"], base_url=self.anthropic_api_url,
api_key=values["anthropic_api_key"].get_secret_value(), api_key=self.anthropic_api_key.get_secret_value(),
timeout=values["default_request_timeout"], timeout=self.default_request_timeout,
max_retries=values["max_retries"], max_retries=self.max_retries,
) )
values["async_client"] = anthropic.AsyncAnthropic( self.async_client = anthropic.AsyncAnthropic(
base_url=values["anthropic_api_url"], base_url=self.anthropic_api_url,
api_key=values["anthropic_api_key"].get_secret_value(), api_key=self.anthropic_api_key.get_secret_value(),
timeout=values["default_request_timeout"], timeout=self.default_request_timeout,
max_retries=values["max_retries"], max_retries=self.max_retries,
) )
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT self.HUMAN_PROMPT = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT self.AI_PROMPT = anthropic.AI_PROMPT
values["count_tokens"] = values["client"].count_tokens self.count_tokens = self.client.count_tokens
return values return self
@property @property
def _default_params(self) -> Mapping[str, Any]: def _default_params(self) -> Mapping[str, Any]:
@@ -160,14 +162,14 @@ class AnthropicLLM(LLM, _AnthropicCommon):
model = AnthropicLLM() model = AnthropicLLM()
""" """
class Config: model_config = ConfigDict(
"""Configuration for this pydantic object.""" populate_by_name=True,
arbitrary_types_allowed=True,
)
allow_population_by_field_name = True @model_validator(mode="before")
arbitrary_types_allowed = True @classmethod
def raise_warning(cls, values: Dict) -> Any:
@root_validator(pre=True)
def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated.""" """Raise warning that this class is deprecated."""
warnings.warn( warnings.warn(
"This Anthropic LLM is deprecated. " "This Anthropic LLM is deprecated. "

View File

@@ -4,7 +4,7 @@ from langchain_core.messages import AIMessage, ToolCall
from langchain_core.messages.tool import tool_call from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import BaseGenerationOutputParser from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel from pydantic import BaseModel, ConfigDict
class ToolsOutputParser(BaseGenerationOutputParser): class ToolsOutputParser(BaseGenerationOutputParser):
@@ -17,8 +17,9 @@ class ToolsOutputParser(BaseGenerationOutputParser):
pydantic_schemas: Optional[List[Type[BaseModel]]] = None pydantic_schemas: Optional[List[Type[BaseModel]]] = None
"""Pydantic schemas to parse tool calls into.""" """Pydantic schemas to parse tool calls into."""
class Config: model_config = ConfigDict(
extra = "forbid" extra="forbid",
)
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse a list of candidate model Generations into a specific format. """Parse a list of candidate model Generations into a specific format.

View File

@@ -11,9 +11,6 @@ files = [
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
] ]
[package.dependencies]
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
[[package]] [[package]]
name = "anthropic" name = "anthropic"
version = "0.33.1" version = "0.33.1"
@@ -513,7 +510,7 @@ files = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.2.30" version = "0.2.38"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@@ -1318,5 +1315,5 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.9,<4.0"
content-hash = "8e57ed7d6701a7ae5881fb1e5be89fc53ba81b98f04a9b54ffb0032d33589960" content-hash = "5263e046aef60dc42d7b0af41ec35ce48282f88d3ae8ce71003a54382fe4a1e6"

View File

@@ -19,9 +19,10 @@ disallow_untyped_defs = "True"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-anthropic%3D%3D0%22&expanded=true" "Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-anthropic%3D%3D0%22&expanded=true"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.9,<4.0"
anthropic = ">=0.30.0,<1" anthropic = ">=0.30.0,<1"
langchain-core = "^0.2.26" langchain-core = "^0.2.26"
pydantic = ">=2,<3"
[tool.ruff.lint] [tool.ruff.lint]
select = [ "E", "F", "I", "T201",] select = [ "E", "F", "I", "T201",]

View File

@@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

View File

@@ -16,8 +16,8 @@ from langchain_core.messages import (
) )
from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool from langchain_core.tools import tool
from pydantic import BaseModel, Field
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
from tests.unit_tests._utils import FakeCallbackHandler from tests.unit_tests._utils import FakeCallbackHandler

View File

@@ -4,7 +4,7 @@ from enum import Enum
from typing import List, Optional from typing import List, Optional
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field from pydantic import BaseModel, Field
from langchain_anthropic.experimental import ChatAnthropicTools from langchain_anthropic.experimental import ChatAnthropicTools

View File

@@ -3,7 +3,7 @@
from typing import Any, Union from typing import Any, Union
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.pydantic_v1 import BaseModel from pydantic import BaseModel
class BaseFakeCallbackHandler(BaseModel): class BaseFakeCallbackHandler(BaseModel):
@@ -251,5 +251,6 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_retriever_error_common() self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
return self return self

View File

@@ -7,9 +7,9 @@ import pytest
from anthropic.types import Message, TextBlock, Usage from anthropic.types import Message, TextBlock, Usage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import RunnableBinding from langchain_core.runnables import RunnableBinding
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, SecretStr
from pytest import CaptureFixture, MonkeyPatch from pytest import CaptureFixture, MonkeyPatch
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic

View File

@@ -2,7 +2,7 @@ from typing import Any, List, Literal
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel from pydantic import BaseModel
from langchain_anthropic.output_parsers import ToolsOutputParser from langchain_anthropic.output_parsers import ToolsOutputParser