Merge branch 'master' into eugene/openai_root_validators

This commit is contained in:
Eugene Yurtsev
2024-08-16 10:50:01 -04:00
6 changed files with 103 additions and 73 deletions

View File

@@ -23,14 +23,14 @@ class MimeTypeBasedParser(BaseBlobParser):
.. code-block:: python
from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser
from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser
parser = MimeTypeBasedParser(
handlers={
"application/pdf": ...,
},
fallback_parser=...,
)
parser = MimeTypeBasedParser(
handlers={
"application/pdf": ...,
},
fallback_parser=...,
)
""" # noqa: E501
def __init__(

View File

@@ -7,7 +7,7 @@ import importlib
import os
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union, overload
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload
from packaging.version import parse
from requests import HTTPError, Response
@@ -280,13 +280,17 @@ def from_env(key: str, /) -> Callable[[], str]: ...
def from_env(key: str, /, *, default: str) -> Callable[[], str]: ...
@overload
def from_env(key: Sequence[str], /, *, default: str) -> Callable[[], str]: ...
@overload
def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ...
@overload
def from_env(
key: str, /, *, default: str, error_message: Optional[str]
key: Union[str, Sequence[str]], /, *, default: str, error_message: Optional[str]
) -> Callable[[], str]: ...
@@ -301,7 +305,7 @@ def from_env(key: str, /, *, default: None) -> Callable[[], Optional[str]]: ...
def from_env(
key: str,
key: Union[str, Sequence[str]],
/,
*,
default: Union[str, _NoDefaultType, None] = _NoDefault,
@@ -310,7 +314,10 @@ def from_env(
"""Create a factory method that gets a value from an environment variable.
Args:
key: The environment variable to look up.
key: The environment variable to look up. If a list of keys is provided,
the first key found in the environment will be used.
If no key is found, the default value will be used if set,
otherwise an error will be raised.
default: The default value to return if the environment variable is not set.
error_message: The error message which will be raised if the key is not found
and no default value is provided.
@@ -319,9 +326,15 @@ def from_env(
def get_from_env_fn() -> Optional[str]:
"""Get a value from an environment variable."""
if key in os.environ:
return os.environ[key]
elif isinstance(default, (str, type(None))):
if isinstance(key, (list, tuple)):
for k in key:
if k in os.environ:
return os.environ[k]
if isinstance(key, str):
if key in os.environ:
return os.environ[key]
if isinstance(default, (str, type(None))):
return default
else:
if error_message:

View File

@@ -1,4 +1,3 @@
import os
import re
import warnings
from operator import itemgetter
@@ -64,8 +63,9 @@ from langchain_core.runnables import (
from langchain_core.tools import BaseTool
from langchain_core.utils import (
build_extra_kwargs,
convert_to_secret_str,
from_env,
get_pydantic_field_names,
secret_from_env,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -541,14 +541,26 @@ class ChatAnthropic(BaseChatModel):
stop_sequences: Optional[List[str]] = Field(None, alias="stop")
"""Default stop sequences."""
anthropic_api_url: Optional[str] = Field(None, alias="base_url")
anthropic_api_url: Optional[str] = Field(
alias="base_url",
default_factory=from_env(
["ANTHROPIC_API_URL", "ANTHROPIC_BASE_URL"],
default="https://api.anthropic.com",
),
)
"""Base URL for API requests. Only specify if using a proxy or service emulator.
If a value isn't passed in and environment variable ANTHROPIC_BASE_URL is set, value
will be read from there.
If a value isn't passed in, will attempt to read the value first from
ANTHROPIC_API_URL and if that is not set, ANTHROPIC_BASE_URL.
If neither is set, the default value of 'https://api.anthropic.com' will
be used.
"""
anthropic_api_key: Optional[SecretStr] = Field(None, alias="api_key")
anthropic_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""),
)
"""Automatically read from env var `ANTHROPIC_API_KEY` if not provided."""
default_headers: Optional[Mapping[str, str]] = None
@@ -623,20 +635,10 @@ class ChatAnthropic(BaseChatModel):
)
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
anthropic_api_key = convert_to_secret_str(
values.get("anthropic_api_key") or os.environ.get("ANTHROPIC_API_KEY") or ""
)
values["anthropic_api_key"] = anthropic_api_key
api_key = anthropic_api_key.get_secret_value()
api_url = (
values.get("anthropic_api_url")
or os.environ.get("ANTHROPIC_API_URL")
or os.environ.get("ANTHROPIC_BASE_URL")
or "https://api.anthropic.com"
)
values["anthropic_api_url"] = api_url
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
api_key = values["anthropic_api_key"].get_secret_value()
api_url = values["anthropic_api_url"]
client_params = {
"api_key": api_key,
"base_url": api_url,

View File

@@ -23,10 +23,13 @@ from langchain_core.outputs import GenerationChunk
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str
from langchain_core.utils.utils import (
build_extra_kwargs,
from_env,
secret_from_env,
)
class _AnthropicCommon(BaseLanguageModel):
@@ -56,9 +59,25 @@ class _AnthropicCommon(BaseLanguageModel):
max_retries: int = 2
"""Number of retries allowed for requests sent to the Anthropic Completion API."""
anthropic_api_url: Optional[str] = None
anthropic_api_url: Optional[str] = Field(
alias="base_url",
default_factory=from_env(
"ANTHROPIC_API_URL",
default="https://api.anthropic.com",
),
)
"""Base URL for API requests. Only specify if using a proxy or service emulator.
anthropic_api_key: Optional[SecretStr] = None
If a value isn't passed in, will attempt to read the value from
ANTHROPIC_API_URL. If not set, the default value of 'https://api.anthropic.com' will
be used.
"""
anthropic_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""),
)
"""Automatically read from env var `ANTHROPIC_API_KEY` if not provided."""
HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None
@@ -74,20 +93,9 @@ class _AnthropicCommon(BaseLanguageModel):
)
return values
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
)
# Get custom api url from environment.
values["anthropic_api_url"] = get_from_dict_or_env(
values,
"anthropic_api_url",
"ANTHROPIC_API_URL",
default="https://api.anthropic.com",
)
values["client"] = anthropic.Anthropic(
base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"].get_secret_value(),
@@ -158,7 +166,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator()
@root_validator(pre=True)
def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated."""
warnings.warn(

View File

@@ -1,5 +1,4 @@
import logging
import os
from typing import Dict, Iterable, List, Optional
import aiohttp
@@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import (
SecretStr,
root_validator,
)
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils import secret_from_env
from pinecone import Pinecone as PineconeClient # type: ignore
logger = logging.getLogger(__name__)
@@ -45,10 +44,21 @@ class PineconeEmbeddings(BaseModel, Embeddings):
dimension: Optional[int] = None
#
show_progress_bar: bool = False
pinecone_api_key: Optional[SecretStr] = None
pinecone_api_key: Optional[SecretStr] = Field(
default_factory=secret_from_env(
"PINECONE_API_KEY",
error_message="Pinecone API key not found. Please set the PINECONE_API_KEY "
"environment variable or pass it via `pinecone_api_key`.",
),
alias="api_key",
)
"""Pinecone API key.
If not provided, will look for the PINECONE_API_KEY environment variable."""
class Config:
extra = "forbid"
allow_population_by_field_name = True
@root_validator(pre=True)
def set_default_config(cls, values: dict) -> dict:
@@ -69,25 +79,10 @@ class PineconeEmbeddings(BaseModel, Embeddings):
values[key] = value
return values
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: dict) -> dict:
"""Validate that Pinecone version and credentials exist in environment."""
pinecone_api_key = values.get("pinecone_api_key") or os.getenv(
"PINECONE_API_KEY", None
)
if pinecone_api_key:
api_key_secretstr = convert_to_secret_str(pinecone_api_key)
values["pinecone_api_key"] = api_key_secretstr
api_key_str = api_key_secretstr.get_secret_value()
else:
api_key_str = None
if api_key_str is None:
raise ValueError(
"Pinecone API key not found. Please set the PINECONE_API_KEY "
"environment variable or pass it via `pinecone_api_key`."
)
api_key_str = values["pinecone_api_key"].get_secret_value()
client = PineconeClient(api_key=api_key_str, source_tag="langchain")
values["_client"] = client

View File

@@ -7,10 +7,22 @@ MODEL_NAME = "multilingual-e5-large"
def test_default_config() -> None:
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME)
e = PineconeEmbeddings(
pinecone_api_key=API_KEY, # type: ignore[call-arg]
model=MODEL_NAME,
)
assert e.batch_size == 96
def test_default_config_with_api_key() -> None:
e = PineconeEmbeddings(api_key=API_KEY, model=MODEL_NAME)
assert e.batch_size == 96
def test_custom_config() -> None:
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME, batch_size=128)
e = PineconeEmbeddings(
pinecone_api_key=API_KEY, # type: ignore[call-arg]
model=MODEL_NAME,
batch_size=128,
)
assert e.batch_size == 128