mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
Merge branch 'master' into eugene/openai_root_validators
This commit is contained in:
@@ -23,14 +23,14 @@ class MimeTypeBasedParser(BaseBlobParser):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser
|
||||
from langchain_community.document_loaders.parsers.generic import MimeTypeBasedParser
|
||||
|
||||
parser = MimeTypeBasedParser(
|
||||
handlers={
|
||||
"application/pdf": ...,
|
||||
},
|
||||
fallback_parser=...,
|
||||
)
|
||||
parser = MimeTypeBasedParser(
|
||||
handlers={
|
||||
"application/pdf": ...,
|
||||
},
|
||||
fallback_parser=...,
|
||||
)
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -7,7 +7,7 @@ import importlib
|
||||
import os
|
||||
import warnings
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union, overload
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload
|
||||
|
||||
from packaging.version import parse
|
||||
from requests import HTTPError, Response
|
||||
@@ -280,13 +280,17 @@ def from_env(key: str, /) -> Callable[[], str]: ...
|
||||
def from_env(key: str, /, *, default: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: Sequence[str], /, *, default: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(
|
||||
key: str, /, *, default: str, error_message: Optional[str]
|
||||
key: Union[str, Sequence[str]], /, *, default: str, error_message: Optional[str]
|
||||
) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@@ -301,7 +305,7 @@ def from_env(key: str, /, *, default: None) -> Callable[[], Optional[str]]: ...
|
||||
|
||||
|
||||
def from_env(
|
||||
key: str,
|
||||
key: Union[str, Sequence[str]],
|
||||
/,
|
||||
*,
|
||||
default: Union[str, _NoDefaultType, None] = _NoDefault,
|
||||
@@ -310,7 +314,10 @@ def from_env(
|
||||
"""Create a factory method that gets a value from an environment variable.
|
||||
|
||||
Args:
|
||||
key: The environment variable to look up.
|
||||
key: The environment variable to look up. If a list of keys is provided,
|
||||
the first key found in the environment will be used.
|
||||
If no key is found, the default value will be used if set,
|
||||
otherwise an error will be raised.
|
||||
default: The default value to return if the environment variable is not set.
|
||||
error_message: the error message which will be raised if the key is not found
|
||||
and no default value is provided.
|
||||
@@ -319,9 +326,15 @@ def from_env(
|
||||
|
||||
def get_from_env_fn() -> Optional[str]:
|
||||
"""Get a value from an environment variable."""
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
elif isinstance(default, (str, type(None))):
|
||||
if isinstance(key, (list, tuple)):
|
||||
for k in key:
|
||||
if k in os.environ:
|
||||
return os.environ[k]
|
||||
if isinstance(key, str):
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
|
||||
if isinstance(default, (str, type(None))):
|
||||
return default
|
||||
else:
|
||||
if error_message:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from operator import itemgetter
|
||||
@@ -64,8 +63,9 @@ from langchain_core.runnables import (
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
build_extra_kwargs,
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
secret_from_env,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
@@ -541,14 +541,26 @@ class ChatAnthropic(BaseChatModel):
|
||||
stop_sequences: Optional[List[str]] = Field(None, alias="stop")
|
||||
"""Default stop sequences."""
|
||||
|
||||
anthropic_api_url: Optional[str] = Field(None, alias="base_url")
|
||||
anthropic_api_url: Optional[str] = Field(
|
||||
alias="base_url",
|
||||
default_factory=from_env(
|
||||
["ANTHROPIC_API_URL", "ANTHROPIC_BASE_URL"],
|
||||
default="https://api.anthropic.com",
|
||||
),
|
||||
)
|
||||
"""Base URL for API requests. Only specify if using a proxy or service emulator.
|
||||
|
||||
If a value isn't passed in and environment variable ANTHROPIC_BASE_URL is set, value
|
||||
will be read from there.
|
||||
If a value isn't passed in, will attempt to read the value first from
|
||||
ANTHROPIC_API_URL and if that is not set, ANTHROPIC_BASE_URL.
|
||||
If neither are set, the default value of 'https://api.anthropic.com' will
|
||||
be used.
|
||||
"""
|
||||
|
||||
anthropic_api_key: Optional[SecretStr] = Field(None, alias="api_key")
|
||||
anthropic_api_key: SecretStr = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""),
|
||||
)
|
||||
|
||||
"""Automatically read from env var `ANTHROPIC_API_KEY` if not provided."""
|
||||
|
||||
default_headers: Optional[Mapping[str, str]] = None
|
||||
@@ -623,20 +635,10 @@ class ChatAnthropic(BaseChatModel):
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
anthropic_api_key = convert_to_secret_str(
|
||||
values.get("anthropic_api_key") or os.environ.get("ANTHROPIC_API_KEY") or ""
|
||||
)
|
||||
values["anthropic_api_key"] = anthropic_api_key
|
||||
api_key = anthropic_api_key.get_secret_value()
|
||||
api_url = (
|
||||
values.get("anthropic_api_url")
|
||||
or os.environ.get("ANTHROPIC_API_URL")
|
||||
or os.environ.get("ANTHROPIC_BASE_URL")
|
||||
or "https://api.anthropic.com"
|
||||
)
|
||||
values["anthropic_api_url"] = api_url
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def post_init(cls, values: Dict) -> Dict:
|
||||
api_key = values["anthropic_api_key"].get_secret_value()
|
||||
api_url = values["anthropic_api_url"]
|
||||
client_params = {
|
||||
"api_key": api_key,
|
||||
"base_url": api_url,
|
||||
|
||||
@@ -23,10 +23,13 @@ from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str
|
||||
from langchain_core.utils.utils import (
|
||||
build_extra_kwargs,
|
||||
from_env,
|
||||
secret_from_env,
|
||||
)
|
||||
|
||||
|
||||
class _AnthropicCommon(BaseLanguageModel):
|
||||
@@ -56,9 +59,25 @@ class _AnthropicCommon(BaseLanguageModel):
|
||||
max_retries: int = 2
|
||||
"""Number of retries allowed for requests sent to the Anthropic Completion API."""
|
||||
|
||||
anthropic_api_url: Optional[str] = None
|
||||
anthropic_api_url: Optional[str] = Field(
|
||||
alias="base_url",
|
||||
default_factory=from_env(
|
||||
"ANTHROPIC_API_URL",
|
||||
default="https://api.anthropic.com",
|
||||
),
|
||||
)
|
||||
"""Base URL for API requests. Only specify if using a proxy or service emulator.
|
||||
|
||||
anthropic_api_key: Optional[SecretStr] = None
|
||||
If a value isn't passed in, will attempt to read the value from
|
||||
ANTHROPIC_API_URL. If not set, the default value of 'https://api.anthropic.com' will
|
||||
be used.
|
||||
"""
|
||||
|
||||
anthropic_api_key: SecretStr = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""),
|
||||
)
|
||||
"""Automatically read from env var `ANTHROPIC_API_KEY` if not provided."""
|
||||
|
||||
HUMAN_PROMPT: Optional[str] = None
|
||||
AI_PROMPT: Optional[str] = None
|
||||
@@ -74,20 +93,9 @@ class _AnthropicCommon(BaseLanguageModel):
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["anthropic_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
|
||||
)
|
||||
# Get custom api url from environment.
|
||||
values["anthropic_api_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"anthropic_api_url",
|
||||
"ANTHROPIC_API_URL",
|
||||
default="https://api.anthropic.com",
|
||||
)
|
||||
|
||||
values["client"] = anthropic.Anthropic(
|
||||
base_url=values["anthropic_api_url"],
|
||||
api_key=values["anthropic_api_key"].get_secret_value(),
|
||||
@@ -158,7 +166,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def raise_warning(cls, values: Dict) -> Dict:
|
||||
"""Raise warning that this class is deprecated."""
|
||||
warnings.warn(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import (
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import convert_to_secret_str
|
||||
from langchain_core.utils import secret_from_env
|
||||
from pinecone import Pinecone as PineconeClient # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -45,10 +44,21 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
||||
dimension: Optional[int] = None
|
||||
#
|
||||
show_progress_bar: bool = False
|
||||
pinecone_api_key: Optional[SecretStr] = None
|
||||
pinecone_api_key: Optional[SecretStr] = Field(
|
||||
default_factory=secret_from_env(
|
||||
"PINECONE_API_KEY",
|
||||
error_message="Pinecone API key not found. Please set the PINECONE_API_KEY "
|
||||
"environment variable or pass it via `pinecone_api_key`.",
|
||||
),
|
||||
alias="api_key",
|
||||
)
|
||||
"""Pinecone API key.
|
||||
|
||||
If not provided, will look for the PINECONE_API_KEY environment variable."""
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_default_config(cls, values: dict) -> dict:
|
||||
@@ -69,25 +79,10 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
||||
values[key] = value
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
"""Validate that Pinecone version and credentials exist in environment."""
|
||||
|
||||
pinecone_api_key = values.get("pinecone_api_key") or os.getenv(
|
||||
"PINECONE_API_KEY", None
|
||||
)
|
||||
if pinecone_api_key:
|
||||
api_key_secretstr = convert_to_secret_str(pinecone_api_key)
|
||||
values["pinecone_api_key"] = api_key_secretstr
|
||||
|
||||
api_key_str = api_key_secretstr.get_secret_value()
|
||||
else:
|
||||
api_key_str = None
|
||||
if api_key_str is None:
|
||||
raise ValueError(
|
||||
"Pinecone API key not found. Please set the PINECONE_API_KEY "
|
||||
"environment variable or pass it via `pinecone_api_key`."
|
||||
)
|
||||
api_key_str = values["pinecone_api_key"].get_secret_value()
|
||||
client = PineconeClient(api_key=api_key_str, source_tag="langchain")
|
||||
values["_client"] = client
|
||||
|
||||
|
||||
@@ -7,10 +7,22 @@ MODEL_NAME = "multilingual-e5-large"
|
||||
|
||||
|
||||
def test_default_config() -> None:
|
||||
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME)
|
||||
e = PineconeEmbeddings(
|
||||
pinecone_api_key=API_KEY, # type: ignore[call-arg]
|
||||
model=MODEL_NAME,
|
||||
)
|
||||
assert e.batch_size == 96
|
||||
|
||||
|
||||
def test_default_config_with_api_key() -> None:
|
||||
e = PineconeEmbeddings(api_key=API_KEY, model=MODEL_NAME)
|
||||
assert e.batch_size == 96
|
||||
|
||||
|
||||
def test_custom_config() -> None:
|
||||
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME, batch_size=128)
|
||||
e = PineconeEmbeddings(
|
||||
pinecone_api_key=API_KEY, # type: ignore[call-arg]
|
||||
model=MODEL_NAME,
|
||||
batch_size=128,
|
||||
)
|
||||
assert e.batch_size == 128
|
||||
|
||||
Reference in New Issue
Block a user