From 163d6fe8efa16f6aee810291a2ce96454360cc40 Mon Sep 17 00:00:00 2001 From: ccurme Date: Wed, 4 Sep 2024 13:35:51 -0400 Subject: [PATCH] anthropic: update pydantic (#26000) Migrated with gritql: https://github.com/eyurtsev/migrate-pydantic --- .../langchain_anthropic/chat_models.py | 77 +++++++++---------- .../langchain_anthropic/experimental.py | 4 +- .../anthropic/langchain_anthropic/llms.py | 54 ++++++------- .../langchain_anthropic/output_parsers.py | 7 +- libs/partners/anthropic/poetry.lock | 9 +-- libs/partners/anthropic/pyproject.toml | 3 +- .../anthropic/scripts/check_pydantic.sh | 27 ------- .../integration_tests/test_chat_models.py | 2 +- .../integration_tests/test_experimental.py | 2 +- .../anthropic/tests/unit_tests/_utils.py | 5 +- .../tests/unit_tests/test_chat_models.py | 2 +- .../tests/unit_tests/test_output_parsers.py | 2 +- 12 files changed, 84 insertions(+), 110 deletions(-) delete mode 100755 libs/partners/anthropic/scripts/check_pydantic.sh diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index ff931efb7c4..af61eba1002 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -49,12 +49,6 @@ from langchain_core.output_parsers import ( ) from langchain_core.output_parsers.base import OutputParserLike from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import ( - BaseModel, - Field, - SecretStr, - root_validator, -) from langchain_core.runnables import ( Runnable, RunnableMap, @@ -69,7 +63,15 @@ from langchain_core.utils import ( ) from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.pydantic import is_basemodel_subclass -from typing_extensions import NotRequired +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + SecretStr, + model_validator, +) +from typing_extensions import NotRequired, Self from langchain_anthropic.output_parsers import extract_tool_calls @@ -114,7 +116,7 @@ def _merge_messages( """Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501 merged: list = [] for curr in messages: - curr = curr.copy(deep=True) + curr = curr.model_copy(deep=True) if isinstance(curr, ToolMessage): if isinstance(curr.content, list) and all( isinstance(block, dict) and block.get("type") == "tool_result" @@ -383,7 +385,7 @@ class ChatAnthropic(BaseChatModel): Tool calling: .. code-block:: python - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class GetWeather(BaseModel): '''Get the current weather in a given location''' @@ -421,7 +423,7 @@ class ChatAnthropic(BaseChatModel): from typing import Optional - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class Joke(BaseModel): '''Joke to tell user.''' @@ -508,13 +510,12 @@ class ChatAnthropic(BaseChatModel): """ # noqa: E501 - class Config: - """Configuration for this pydantic object.""" + model_config = ConfigDict( + populate_by_name=True, + ) - allow_population_by_field_name = True - - _client: anthropic.Client = Field(default=None) - _async_client: anthropic.AsyncClient = Field(default=None) + _client: anthropic.Client = PrivateAttr(default=None) + _async_client: anthropic.AsyncClient = PrivateAttr(default=None) model: str = Field(alias="model_name") """Model name to use.""" @@ -626,8 +627,9 @@ class ChatAnthropic(BaseChatModel): ls_params["ls_stop"] = ls_stop return ls_params - @root_validator(pre=True) - def build_extra(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: Dict) -> Any: extra = values.get("model_kwargs", {}) all_required_field_names = get_pydantic_field_names(cls) values["model_kwargs"] = build_extra_kwargs( @@ -635,28 +637,25 @@ class ChatAnthropic(BaseChatModel): ) return values - @root_validator(pre=False, skip_on_failure=True) - def post_init(cls, values: Dict) -> Dict: - api_key = values["anthropic_api_key"].get_secret_value() - api_url = values["anthropic_api_url"] - client_params = { + @model_validator(mode="after") + def post_init(self) -> Self: + api_key = self.anthropic_api_key.get_secret_value() + api_url = self.anthropic_api_url + client_params: Dict[str, Any] = { "api_key": api_key, "base_url": api_url, - "max_retries": values["max_retries"], - "default_headers": values.get("default_headers"), + "max_retries": self.max_retries, + "default_headers": (self.default_headers or None), } # value <= 0 indicates the param should be ignored. None is a meaningful value # for Anthropic client and treated differently than not specifying the param at # all. - if ( - values["default_request_timeout"] is None - or values["default_request_timeout"] > 0 - ): - client_params["timeout"] = values["default_request_timeout"] + if self.default_request_timeout is None or self.default_request_timeout > 0: + client_params["timeout"] = self.default_request_timeout - values["_client"] = anthropic.Client(**client_params) - values["_async_client"] = anthropic.AsyncClient(**client_params) - return values + self._client = anthropic.Client(**client_params) + self._async_client = anthropic.AsyncClient(**client_params) + return self def _get_request_payload( self, @@ -825,7 +824,7 @@ class ChatAnthropic(BaseChatModel): .. code-block:: python from langchain_anthropic import ChatAnthropic - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class GetWeather(BaseModel): '''Get the current weather in a given location''' @@ -854,7 +853,7 @@ class ChatAnthropic(BaseChatModel): .. code-block:: python from langchain_anthropic import ChatAnthropic - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class GetWeather(BaseModel): '''Get the current weather in a given location''' @@ -876,7 +875,7 @@ class ChatAnthropic(BaseChatModel): .. code-block:: python from langchain_anthropic import ChatAnthropic - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class GetWeather(BaseModel): '''Get the current weather in a given location''' @@ -897,7 +896,7 @@ class ChatAnthropic(BaseChatModel): .. code-block:: python from langchain_anthropic import ChatAnthropic, convert_to_anthropic_tool - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class GetWeather(BaseModel): '''Get the current weather in a given location''' @@ -1007,7 +1006,7 @@ class ChatAnthropic(BaseChatModel): .. code-block:: python from langchain_anthropic import ChatAnthropic - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel class AnswerWithJustification(BaseModel): '''An answer to the user question along with justification for the answer.''' @@ -1028,7 +1027,7 @@ class ChatAnthropic(BaseChatModel): .. code-block:: python from langchain_anthropic import ChatAnthropic - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel class AnswerWithJustification(BaseModel): '''An answer to the user question along with justification for the answer.''' diff --git a/libs/partners/anthropic/langchain_anthropic/experimental.py b/libs/partners/anthropic/langchain_anthropic/experimental.py index 529ee2a7454..eae408862da 100644 --- a/libs/partners/anthropic/langchain_anthropic/experimental.py +++ b/libs/partners/anthropic/langchain_anthropic/experimental.py @@ -7,7 +7,7 @@ from typing import ( ) from langchain_core._api import deprecated -from langchain_core.pydantic_v1 import Field +from pydantic import PrivateAttr from langchain_anthropic.chat_models import ChatAnthropic @@ -156,4 +156,4 @@ def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]: class ChatAnthropicTools(ChatAnthropic): """Chat model for interacting with Anthropic functions.""" - _xmllib: Any = Field(default=None) + _xmllib: Any = PrivateAttr(default=None) diff --git a/libs/partners/anthropic/langchain_anthropic/llms.py b/libs/partners/anthropic/langchain_anthropic/llms.py index e5a821ffde0..99e7df965fb 100644 --- a/libs/partners/anthropic/langchain_anthropic/llms.py +++ b/libs/partners/anthropic/langchain_anthropic/llms.py @@ -21,7 +21,6 @@ from langchain_core.language_models import BaseLanguageModel, LangSmithParams from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk from langchain_core.prompt_values import PromptValue -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import ( get_pydantic_field_names, ) @@ -30,6 +29,8 @@ from langchain_core.utils.utils import ( from_env, secret_from_env, ) +from pydantic import ConfigDict, Field, SecretStr, model_validator +from typing_extensions import Self class _AnthropicCommon(BaseLanguageModel): @@ -84,8 +85,9 @@ class _AnthropicCommon(BaseLanguageModel): count_tokens: Optional[Callable[[str], int]] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) - @root_validator(pre=True) - def build_extra(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: Dict) -> Any: extra = values.get("model_kwargs", {}) all_required_field_names = get_pydantic_field_names(cls) values["model_kwargs"] = build_extra_kwargs( @@ -93,25 +95,25 @@ class _AnthropicCommon(BaseLanguageModel): ) return values - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate that api key and python package exists in environment.""" - values["client"] = anthropic.Anthropic( - base_url=values["anthropic_api_url"], - api_key=values["anthropic_api_key"].get_secret_value(), - timeout=values["default_request_timeout"], - max_retries=values["max_retries"], + self.client = anthropic.Anthropic( + base_url=self.anthropic_api_url, + api_key=self.anthropic_api_key.get_secret_value(), + timeout=self.default_request_timeout, + max_retries=self.max_retries, ) - values["async_client"] = anthropic.AsyncAnthropic( - base_url=values["anthropic_api_url"], - api_key=values["anthropic_api_key"].get_secret_value(), - timeout=values["default_request_timeout"], - max_retries=values["max_retries"], + self.async_client = anthropic.AsyncAnthropic( + base_url=self.anthropic_api_url, + api_key=self.anthropic_api_key.get_secret_value(), + timeout=self.default_request_timeout, + max_retries=self.max_retries, ) - values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT - values["AI_PROMPT"] = anthropic.AI_PROMPT - values["count_tokens"] = values["client"].count_tokens - return values + self.HUMAN_PROMPT = anthropic.HUMAN_PROMPT + self.AI_PROMPT = anthropic.AI_PROMPT + self.count_tokens = self.client.count_tokens + return self @property def _default_params(self) -> Mapping[str, Any]: @@ -160,14 +162,14 @@ class AnthropicLLM(LLM, _AnthropicCommon): model = AnthropicLLM() """ - class Config: - """Configuration for this pydantic object.""" + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + ) - allow_population_by_field_name = True - arbitrary_types_allowed = True - - @root_validator(pre=True) - def raise_warning(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_warning(cls, values: Dict) -> Any: """Raise warning that this class is deprecated.""" warnings.warn( "This Anthropic LLM is deprecated. " diff --git a/libs/partners/anthropic/langchain_anthropic/output_parsers.py b/libs/partners/anthropic/langchain_anthropic/output_parsers.py index cd9f5308ddc..c30f7ebc665 100644 --- a/libs/partners/anthropic/langchain_anthropic/output_parsers.py +++ b/libs/partners/anthropic/langchain_anthropic/output_parsers.py @@ -4,7 +4,7 @@ from langchain_core.messages import AIMessage, ToolCall from langchain_core.messages.tool import tool_call from langchain_core.output_parsers import BaseGenerationOutputParser from langchain_core.outputs import ChatGeneration, Generation -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel, ConfigDict class ToolsOutputParser(BaseGenerationOutputParser): @@ -17,8 +17,9 @@ class ToolsOutputParser(BaseGenerationOutputParser): pydantic_schemas: Optional[List[Type[BaseModel]]] = None """Pydantic schemas to parse tool calls into.""" - class Config: - extra = "forbid" + model_config = ConfigDict( + extra="forbid", + ) def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: """Parse a list of candidate model Generations into a specific format. diff --git a/libs/partners/anthropic/poetry.lock b/libs/partners/anthropic/poetry.lock index 3c8af4e4d6a..a11a4808872 100644 --- a/libs/partners/anthropic/poetry.lock +++ b/libs/partners/anthropic/poetry.lock @@ -11,9 +11,6 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} - [[package]] name = "anthropic" version = "0.33.1" @@ -513,7 +510,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.2.30" +version = "0.2.38" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -1318,5 +1315,5 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" -python-versions = ">=3.8.1,<4.0" -content-hash = "8e57ed7d6701a7ae5881fb1e5be89fc53ba81b98f04a9b54ffb0032d33589960" +python-versions = ">=3.9,<4.0" +content-hash = "5263e046aef60dc42d7b0af41ec35ce48282f88d3ae8ce71003a54382fe4a1e6" diff --git a/libs/partners/anthropic/pyproject.toml b/libs/partners/anthropic/pyproject.toml index 5bf457adfcc..cb115aab92d 100644 --- a/libs/partners/anthropic/pyproject.toml +++ b/libs/partners/anthropic/pyproject.toml @@ -19,9 +19,10 @@ disallow_untyped_defs = "True" "Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-anthropic%3D%3D0%22&expanded=true" [tool.poetry.dependencies] -python = ">=3.8.1,<4.0" +python = ">=3.9,<4.0" anthropic = ">=0.30.0,<1" langchain-core = "^0.2.26" +pydantic = ">=2,<3" [tool.ruff.lint] select = [ "E", "F", "I", "T201",] diff --git a/libs/partners/anthropic/scripts/check_pydantic.sh b/libs/partners/anthropic/scripts/check_pydantic.sh deleted file mode 100755 index 06b5bb81ae2..00000000000 --- a/libs/partners/anthropic/scripts/check_pydantic.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# -# This script searches for lines starting with "import pydantic" or "from pydantic" -# in tracked files within a Git repository. -# -# Usage: ./scripts/check_pydantic.sh /path/to/repository - -# Check if a path argument is provided -if [ $# -ne 1 ]; then - echo "Usage: $0 /path/to/repository" - exit 1 -fi - -repository_path="$1" - -# Search for lines matching the pattern within the specified repository -result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') - -# Check if any matching lines were found -if [ -n "$result" ]; then - echo "ERROR: The following lines need to be updated:" - echo "$result" - echo "Please replace the code with an import from langchain_core.pydantic_v1." - echo "For example, replace 'from pydantic import BaseModel'" - echo "with 'from langchain_core.pydantic_v1 import BaseModel'" - exit 1 -fi diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index de6cdc0d139..037caf718da 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -16,8 +16,8 @@ from langchain_core.messages import ( ) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool +from pydantic import BaseModel, Field from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages from tests.unit_tests._utils import FakeCallbackHandler diff --git a/libs/partners/anthropic/tests/integration_tests/test_experimental.py b/libs/partners/anthropic/tests/integration_tests/test_experimental.py index 54cb5378757..175fa60abe3 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_experimental.py +++ b/libs/partners/anthropic/tests/integration_tests/test_experimental.py @@ -4,7 +4,7 @@ from enum import Enum from typing import List, Optional from langchain_core.prompts import ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from langchain_anthropic.experimental import ChatAnthropicTools diff --git a/libs/partners/anthropic/tests/unit_tests/_utils.py b/libs/partners/anthropic/tests/unit_tests/_utils.py index a39f31fc0f1..2d10ef80f51 100644 --- a/libs/partners/anthropic/tests/unit_tests/_utils.py +++ b/libs/partners/anthropic/tests/unit_tests/_utils.py @@ -3,7 +3,7 @@ from typing import Any, Union from langchain_core.callbacks import BaseCallbackHandler -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel class BaseFakeCallbackHandler(BaseModel): @@ -251,5 +251,6 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_retriever_error_common() - def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": + # Overriding since BaseModel has __deepcopy__ method as well + def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore return self diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index ae0c69e3d5b..9c563a31dcc 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -7,9 +7,9 @@ import pytest from anthropic.types import Message, TextBlock, Usage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.outputs import ChatGeneration, ChatResult -from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr from langchain_core.runnables import RunnableBinding from langchain_core.tools import BaseTool +from pydantic import BaseModel, Field, SecretStr from pytest import CaptureFixture, MonkeyPatch from langchain_anthropic import ChatAnthropic diff --git a/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py b/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py index 84e2e7506f8..af560eac0fa 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py +++ b/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py @@ -2,7 +2,7 @@ from typing import Any, List, Literal from langchain_core.messages import AIMessage from langchain_core.outputs import ChatGeneration -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel from langchain_anthropic.output_parsers import ToolsOutputParser