diff --git a/libs/community/langchain_community/document_loaders/parsers/generic.py b/libs/community/langchain_community/document_loaders/parsers/generic.py index 615578846d2..75861ab346d 100644 --- a/libs/community/langchain_community/document_loaders/parsers/generic.py +++ b/libs/community/langchain_community/document_loaders/parsers/generic.py @@ -23,14 +23,14 @@ class MimeTypeBasedParser(BaseBlobParser): .. code-block:: python - from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser + from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser - parser = MimeTypeBasedParser( - handlers={ - "application/pdf": ..., - }, - fallback_parser=..., - ) + parser = MimeTypeBasedParser( + handlers={ + "application/pdf": ..., + }, + fallback_parser=..., + ) """ # noqa: E501 def __init__( diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 89b7f1d7a7f..dd92d0a947b 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -7,7 +7,7 @@ import importlib import os import warnings from importlib.metadata import version -from typing import Any, Callable, Dict, Optional, Set, Tuple, Union, overload +from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload from packaging.version import parse from requests import HTTPError, Response @@ -280,13 +280,17 @@ def from_env(key: str, /) -> Callable[[], str]: ... def from_env(key: str, /, *, default: str) -> Callable[[], str]: ... +@overload +def from_env(key: Sequence[str], /, *, default: str) -> Callable[[], str]: ... + + @overload def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ... @overload def from_env( - key: str, /, *, default: str, error_message: Optional[str] + key: Union[str, Sequence[str]], /, *, default: str, error_message: Optional[str] ) -> Callable[[], str]: ... @@ -301,7 +305,7 @@ def from_env(key: str, /, *, default: None) -> Callable[[], Optional[str]]: ... def from_env( - key: str, + key: Union[str, Sequence[str]], /, *, default: Union[str, _NoDefaultType, None] = _NoDefault, @@ -310,7 +314,10 @@ def from_env( """Create a factory method that gets a value from an environment variable. Args: - key: The environment variable to look up. + key: The environment variable to look up. If a list of keys is provided, + the first key found in the environment will be used. + If no key is found, the default value will be used if set, + otherwise an error will be raised. default: The default value to return if the environment variable is not set. error_message: the error message which will be raised if the key is not found and no default value is provided. @@ -319,9 +326,15 @@ def from_env( def get_from_env_fn() -> Optional[str]: """Get a value from an environment variable.""" - if key in os.environ: - return os.environ[key] - elif isinstance(default, (str, type(None))): + if isinstance(key, (list, tuple)): + for k in key: + if k in os.environ: + return os.environ[k] + if isinstance(key, str): + if key in os.environ: + return os.environ[key] + + if isinstance(default, (str, type(None))): return default else: if error_message: diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index a6958e1ff81..ff931efb7c4 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -1,4 +1,3 @@ -import os import re import warnings from operator import itemgetter @@ -64,8 +63,9 @@ from langchain_core.runnables import ( from langchain_core.tools import BaseTool from langchain_core.utils import ( build_extra_kwargs, - convert_to_secret_str, + from_env, get_pydantic_field_names, + secret_from_env, ) from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.pydantic import is_basemodel_subclass @@ -541,14 +541,26 @@ class ChatAnthropic(BaseChatModel): stop_sequences: Optional[List[str]] = Field(None, alias="stop") """Default stop sequences.""" - anthropic_api_url: Optional[str] = Field(None, alias="base_url") + anthropic_api_url: Optional[str] = Field( + alias="base_url", + default_factory=from_env( + ["ANTHROPIC_API_URL", "ANTHROPIC_BASE_URL"], + default="https://api.anthropic.com", + ), + ) """Base URL for API requests. Only specify if using a proxy or service emulator. - If a value isn't passed in and environment variable ANTHROPIC_BASE_URL is set, value - will be read from there. + If a value isn't passed in, will attempt to read the value first from + ANTHROPIC_API_URL and if that is not set, ANTHROPIC_BASE_URL. + If neither are set, the default value of 'https://api.anthropic.com' will + be used. """ - anthropic_api_key: Optional[SecretStr] = Field(None, alias="api_key") + anthropic_api_key: SecretStr = Field( + alias="api_key", + default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""), + ) + """Automatically read from env var `ANTHROPIC_API_KEY` if not provided.""" default_headers: Optional[Mapping[str, str]] = None @@ -623,20 +635,10 @@ class ChatAnthropic(BaseChatModel): ) return values - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - anthropic_api_key = convert_to_secret_str( - values.get("anthropic_api_key") or os.environ.get("ANTHROPIC_API_KEY") or "" - ) - values["anthropic_api_key"] = anthropic_api_key - api_key = anthropic_api_key.get_secret_value() - api_url = ( - values.get("anthropic_api_url") - or os.environ.get("ANTHROPIC_API_URL") - or os.environ.get("ANTHROPIC_BASE_URL") - or "https://api.anthropic.com" - ) - values["anthropic_api_url"] = api_url + @root_validator(pre=False, skip_on_failure=True) + def post_init(cls, values: Dict) -> Dict: + api_key = values["anthropic_api_key"].get_secret_value() + api_url = values["anthropic_api_url"] client_params = { "api_key": api_key, "base_url": api_url, diff --git a/libs/partners/anthropic/langchain_anthropic/llms.py b/libs/partners/anthropic/langchain_anthropic/llms.py index 47ac9f4eb53..97bb85af5c2 100644 --- a/libs/partners/anthropic/langchain_anthropic/llms.py +++ b/libs/partners/anthropic/langchain_anthropic/llms.py @@ -23,10 +23,13 @@ from langchain_core.outputs import GenerationChunk from langchain_core.prompt_values import PromptValue from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import ( - get_from_dict_or_env, get_pydantic_field_names, ) -from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str +from langchain_core.utils.utils import ( + build_extra_kwargs, + from_env, + secret_from_env, +) class _AnthropicCommon(BaseLanguageModel): @@ -56,9 +59,25 @@ class _AnthropicCommon(BaseLanguageModel): max_retries: int = 2 """Number of retries allowed for requests sent to the Anthropic Completion API.""" - anthropic_api_url: Optional[str] = None + anthropic_api_url: Optional[str] = Field( + alias="base_url", + default_factory=from_env( + "ANTHROPIC_API_URL", + default="https://api.anthropic.com", + ), + ) + """Base URL for API requests. Only specify if using a proxy or service emulator. - anthropic_api_key: Optional[SecretStr] = None + If a value isn't passed in, will attempt to read the value from + ANTHROPIC_API_URL. If not set, the default value of 'https://api.anthropic.com' will + be used. + """ + + anthropic_api_key: SecretStr = Field( + alias="api_key", + default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""), + ) + """Automatically read from env var `ANTHROPIC_API_KEY` if not provided.""" HUMAN_PROMPT: Optional[str] = None AI_PROMPT: Optional[str] = None @@ -74,20 +93,9 @@ class _AnthropicCommon(BaseLanguageModel): ) return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["anthropic_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") - ) - # Get custom api url from environment. - values["anthropic_api_url"] = get_from_dict_or_env( - values, - "anthropic_api_url", - "ANTHROPIC_API_URL", - default="https://api.anthropic.com", - ) - values["client"] = anthropic.Anthropic( base_url=values["anthropic_api_url"], api_key=values["anthropic_api_key"].get_secret_value(), @@ -158,7 +166,7 @@ class AnthropicLLM(LLM, _AnthropicCommon): allow_population_by_field_name = True arbitrary_types_allowed = True - @root_validator() + @root_validator(pre=True) def raise_warning(cls, values: Dict) -> Dict: """Raise warning that this class is deprecated.""" warnings.warn( diff --git a/libs/partners/pinecone/langchain_pinecone/embeddings.py b/libs/partners/pinecone/langchain_pinecone/embeddings.py index cad6b18d495..c0adecd1e86 100644 --- a/libs/partners/pinecone/langchain_pinecone/embeddings.py +++ b/libs/partners/pinecone/langchain_pinecone/embeddings.py @@ -1,5 +1,4 @@ import logging -import os from typing import Dict, Iterable, List, Optional import aiohttp @@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import ( SecretStr, root_validator, ) -from langchain_core.utils import convert_to_secret_str +from langchain_core.utils import secret_from_env from pinecone import Pinecone as PineconeClient # type: ignore logger = logging.getLogger(__name__) @@ -45,10 +44,21 @@ class PineconeEmbeddings(BaseModel, Embeddings): dimension: Optional[int] = None # show_progress_bar: bool = False - pinecone_api_key: Optional[SecretStr] = None + pinecone_api_key: Optional[SecretStr] = Field( + default_factory=secret_from_env( + "PINECONE_API_KEY", + error_message="Pinecone API key not found. Please set the PINECONE_API_KEY " + "environment variable or pass it via `pinecone_api_key`.", + ), + alias="api_key", + ) + """Pinecone API key. + + If not provided, will look for the PINECONE_API_KEY environment variable.""" class Config: extra = "forbid" + allow_population_by_field_name = True @root_validator(pre=True) def set_default_config(cls, values: dict) -> dict: @@ -69,25 +79,10 @@ class PineconeEmbeddings(BaseModel, Embeddings): values[key] = value return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: dict) -> dict: """Validate that Pinecone version and credentials exist in environment.""" - - pinecone_api_key = values.get("pinecone_api_key") or os.getenv( - "PINECONE_API_KEY", None - ) - if pinecone_api_key: - api_key_secretstr = convert_to_secret_str(pinecone_api_key) - values["pinecone_api_key"] = api_key_secretstr - - api_key_str = api_key_secretstr.get_secret_value() - else: - api_key_str = None - if api_key_str is None: - raise ValueError( - "Pinecone API key not found. Please set the PINECONE_API_KEY " - "environment variable or pass it via `pinecone_api_key`." - ) + api_key_str = values["pinecone_api_key"].get_secret_value() client = PineconeClient(api_key=api_key_str, source_tag="langchain") values["_client"] = client diff --git a/libs/partners/pinecone/tests/unit_tests/test_embeddings.py b/libs/partners/pinecone/tests/unit_tests/test_embeddings.py index 23d7b1df4c7..924b4e79662 100644 --- a/libs/partners/pinecone/tests/unit_tests/test_embeddings.py +++ b/libs/partners/pinecone/tests/unit_tests/test_embeddings.py @@ -7,10 +7,22 @@ MODEL_NAME = "multilingual-e5-large" def test_default_config() -> None: - e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME) + e = PineconeEmbeddings( + pinecone_api_key=API_KEY, # type: ignore[call-arg] + model=MODEL_NAME, + ) + assert e.batch_size == 96 + + +def test_default_config_with_api_key() -> None: + e = PineconeEmbeddings(api_key=API_KEY, model=MODEL_NAME) assert e.batch_size == 96 def test_custom_config() -> None: - e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME, batch_size=128) + e = PineconeEmbeddings( + pinecone_api_key=API_KEY, # type: ignore[call-arg] + model=MODEL_NAME, + batch_size=128, + ) assert e.batch_size == 128