core[minor], anthropic[patch]: Upgrade @root_validator usage to be consistent with pydantic 2 (#25457)

anthropic: Upgrade `@root_validator` usage to be consistent with
pydantic 2
core: support looking up multiple keys from env in from_env factory
This commit is contained in:
Eugene Yurtsev 2024-08-15 16:09:34 -04:00 committed by GitHub
parent 34da8be60b
commit e18511bb22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 44 deletions

View File

@ -7,7 +7,7 @@ import importlib
import os import os
import warnings import warnings
from importlib.metadata import version from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union, overload from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload
from packaging.version import parse from packaging.version import parse
from requests import HTTPError, Response from requests import HTTPError, Response
@ -280,13 +280,17 @@ def from_env(key: str, /) -> Callable[[], str]: ...
def from_env(key: str, /, *, default: str) -> Callable[[], str]: ... def from_env(key: str, /, *, default: str) -> Callable[[], str]: ...
@overload
def from_env(key: Sequence[str], /, *, default: str) -> Callable[[], str]: ...
@overload @overload
def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ... def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ...
@overload @overload
def from_env( def from_env(
key: str, /, *, default: str, error_message: Optional[str] key: Union[str, Sequence[str]], /, *, default: str, error_message: Optional[str]
) -> Callable[[], str]: ... ) -> Callable[[], str]: ...
@ -301,7 +305,7 @@ def from_env(key: str, /, *, default: None) -> Callable[[], Optional[str]]: ...
def from_env( def from_env(
key: str, key: Union[str, Sequence[str]],
/, /,
*, *,
default: Union[str, _NoDefaultType, None] = _NoDefault, default: Union[str, _NoDefaultType, None] = _NoDefault,
@ -310,7 +314,10 @@ def from_env(
"""Create a factory method that gets a value from an environment variable. """Create a factory method that gets a value from an environment variable.
Args: Args:
key: The environment variable to look up. key: The environment variable to look up. If a list of keys is provided,
the first key found in the environment will be used.
If no key is found, the default value will be used if set,
otherwise an error will be raised.
default: The default value to return if the environment variable is not set. default: The default value to return if the environment variable is not set.
error_message: the error message which will be raised if the key is not found error_message: the error message which will be raised if the key is not found
and no default value is provided. and no default value is provided.
@ -319,9 +326,15 @@ def from_env(
def get_from_env_fn() -> Optional[str]: def get_from_env_fn() -> Optional[str]:
"""Get a value from an environment variable.""" """Get a value from an environment variable."""
if key in os.environ: if isinstance(key, (list, tuple)):
return os.environ[key] for k in key:
elif isinstance(default, (str, type(None))): if k in os.environ:
return os.environ[k]
if isinstance(key, str):
if key in os.environ:
return os.environ[key]
if isinstance(default, (str, type(None))):
return default return default
else: else:
if error_message: if error_message:

View File

@ -1,4 +1,3 @@
import os
import re import re
import warnings import warnings
from operator import itemgetter from operator import itemgetter
@ -64,8 +63,9 @@ from langchain_core.runnables import (
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import ( from langchain_core.utils import (
build_extra_kwargs, build_extra_kwargs,
convert_to_secret_str, from_env,
get_pydantic_field_names, get_pydantic_field_names,
secret_from_env,
) )
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_subclass
@ -541,14 +541,26 @@ class ChatAnthropic(BaseChatModel):
stop_sequences: Optional[List[str]] = Field(None, alias="stop") stop_sequences: Optional[List[str]] = Field(None, alias="stop")
"""Default stop sequences.""" """Default stop sequences."""
anthropic_api_url: Optional[str] = Field(None, alias="base_url") anthropic_api_url: Optional[str] = Field(
alias="base_url",
default_factory=from_env(
["ANTHROPIC_API_URL", "ANTHROPIC_BASE_URL"],
default="https://api.anthropic.com",
),
)
"""Base URL for API requests. Only specify if using a proxy or service emulator. """Base URL for API requests. Only specify if using a proxy or service emulator.
If a value isn't passed in and environment variable ANTHROPIC_BASE_URL is set, value If a value isn't passed in, will attempt to read the value first from
will be read from there. ANTHROPIC_API_URL and if that is not set, ANTHROPIC_BASE_URL.
If neither are set, the default value of 'https://api.anthropic.com' will
be used.
""" """
anthropic_api_key: Optional[SecretStr] = Field(None, alias="api_key") anthropic_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""),
)
"""Automatically read from env var `ANTHROPIC_API_KEY` if not provided.""" """Automatically read from env var `ANTHROPIC_API_KEY` if not provided."""
default_headers: Optional[Mapping[str, str]] = None default_headers: Optional[Mapping[str, str]] = None
@ -623,20 +635,10 @@ class ChatAnthropic(BaseChatModel):
) )
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def post_init(cls, values: Dict) -> Dict:
anthropic_api_key = convert_to_secret_str( api_key = values["anthropic_api_key"].get_secret_value()
values.get("anthropic_api_key") or os.environ.get("ANTHROPIC_API_KEY") or "" api_url = values["anthropic_api_url"]
)
values["anthropic_api_key"] = anthropic_api_key
api_key = anthropic_api_key.get_secret_value()
api_url = (
values.get("anthropic_api_url")
or os.environ.get("ANTHROPIC_API_URL")
or os.environ.get("ANTHROPIC_BASE_URL")
or "https://api.anthropic.com"
)
values["anthropic_api_url"] = api_url
client_params = { client_params = {
"api_key": api_key, "api_key": api_key,
"base_url": api_url, "base_url": api_url,

View File

@ -23,10 +23,13 @@ from langchain_core.outputs import GenerationChunk
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import ( from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names, get_pydantic_field_names,
) )
from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str from langchain_core.utils.utils import (
build_extra_kwargs,
from_env,
secret_from_env,
)
class _AnthropicCommon(BaseLanguageModel): class _AnthropicCommon(BaseLanguageModel):
@ -56,9 +59,25 @@ class _AnthropicCommon(BaseLanguageModel):
max_retries: int = 2 max_retries: int = 2
"""Number of retries allowed for requests sent to the Anthropic Completion API.""" """Number of retries allowed for requests sent to the Anthropic Completion API."""
anthropic_api_url: Optional[str] = None anthropic_api_url: Optional[str] = Field(
alias="base_url",
default_factory=from_env(
"ANTHROPIC_API_URL",
default="https://api.anthropic.com",
),
)
"""Base URL for API requests. Only specify if using a proxy or service emulator.
anthropic_api_key: Optional[SecretStr] = None If a value isn't passed in, will attempt to read the value from
ANTHROPIC_API_URL. If not set, the default value of 'https://api.anthropic.com' will
be used.
"""
anthropic_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""),
)
"""Automatically read from env var `ANTHROPIC_API_KEY` if not provided."""
HUMAN_PROMPT: Optional[str] = None HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None AI_PROMPT: Optional[str] = None
@ -74,20 +93,9 @@ class _AnthropicCommon(BaseLanguageModel):
) )
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
)
# Get custom api url from environment.
values["anthropic_api_url"] = get_from_dict_or_env(
values,
"anthropic_api_url",
"ANTHROPIC_API_URL",
default="https://api.anthropic.com",
)
values["client"] = anthropic.Anthropic( values["client"] = anthropic.Anthropic(
base_url=values["anthropic_api_url"], base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"].get_secret_value(), api_key=values["anthropic_api_key"].get_secret_value(),
@ -158,7 +166,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
allow_population_by_field_name = True allow_population_by_field_name = True
arbitrary_types_allowed = True arbitrary_types_allowed = True
@root_validator() @root_validator(pre=True)
def raise_warning(cls, values: Dict) -> Dict: def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated.""" """Raise warning that this class is deprecated."""
warnings.warn( warnings.warn(