mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 23:26:34 +00:00
core[minor]: Add factory for looking up secrets from the env (#25198)
Add factory method for looking secrets from the env.
This commit is contained in:
parent
da9281feb2
commit
429a0ee7fd
@ -27,6 +27,7 @@ from langchain_core.utils.utils import (
|
|||||||
guard_import,
|
guard_import,
|
||||||
mock_now,
|
mock_now,
|
||||||
raise_for_status_with_text,
|
raise_for_status_with_text,
|
||||||
|
secret_from_env,
|
||||||
xor_args,
|
xor_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -56,4 +57,5 @@ __all__ = [
|
|||||||
"batch_iterate",
|
"batch_iterate",
|
||||||
"abatch_iterate",
|
"abatch_iterate",
|
||||||
"from_env",
|
"from_env",
|
||||||
|
"secret_from_env",
|
||||||
]
|
]
|
||||||
|
@ -313,11 +313,11 @@ def from_env(
|
|||||||
This will be raised as a ValueError.
|
This will be raised as a ValueError.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_from_env_fn() -> str: # type: ignore
|
def get_from_env_fn() -> Optional[str]:
|
||||||
"""Get a value from an environment variable."""
|
"""Get a value from an environment variable."""
|
||||||
if key in os.environ:
|
if key in os.environ:
|
||||||
return os.environ[key]
|
return os.environ[key]
|
||||||
elif isinstance(default, str):
|
elif isinstance(default, (str, type(None))):
|
||||||
return default
|
return default
|
||||||
else:
|
else:
|
||||||
if error_message:
|
if error_message:
|
||||||
@ -330,3 +330,62 @@ def from_env(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return get_from_env_fn
|
return get_from_env_fn
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def secret_from_env(
|
||||||
|
key: str, /, *, default: None
|
||||||
|
) -> Callable[[], Optional[SecretStr]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def secret_from_env(key: str, /, *, error_message: str) -> Callable[[], SecretStr]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def secret_from_env(
|
||||||
|
key: str,
|
||||||
|
/,
|
||||||
|
*,
|
||||||
|
default: Union[str, _NoDefaultType, None] = _NoDefault,
|
||||||
|
error_message: Optional[str] = None,
|
||||||
|
) -> Union[Callable[[], Optional[SecretStr]], Callable[[], SecretStr]]:
|
||||||
|
"""Secret from env.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The environment variable to look up.
|
||||||
|
default: The default value to return if the environment variable is not set.
|
||||||
|
error_message: the error message which will be raised if the key is not found
|
||||||
|
and no default value is provided.
|
||||||
|
This will be raised as a ValueError.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
factory method that will look up the secret from the environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_secret_from_env() -> Optional[SecretStr]:
|
||||||
|
"""Get a value from an environment variable."""
|
||||||
|
if key in os.environ:
|
||||||
|
return SecretStr(os.environ[key])
|
||||||
|
elif isinstance(default, str):
|
||||||
|
return SecretStr(default)
|
||||||
|
elif isinstance(default, type(None)):
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
if error_message:
|
||||||
|
raise ValueError(error_message)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Did not find {key}, please add an environment variable"
|
||||||
|
f" `{key}` which contains it, or pass"
|
||||||
|
f" `{key}` as a named parameter."
|
||||||
|
)
|
||||||
|
|
||||||
|
return get_secret_from_env
|
||||||
|
@ -26,6 +26,7 @@ EXPECTED_ALL = [
|
|||||||
"stringify_value",
|
"stringify_value",
|
||||||
"pre_init",
|
"pre_init",
|
||||||
"from_env",
|
"from_env",
|
||||||
|
"secret_from_env",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from contextlib import AbstractContextManager, nullcontext
|
from contextlib import AbstractContextManager, nullcontext
|
||||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core import utils
|
from langchain_core import utils
|
||||||
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
check_package_version,
|
check_package_version,
|
||||||
from_env,
|
from_env,
|
||||||
@ -15,6 +16,7 @@ from langchain_core.utils import (
|
|||||||
)
|
)
|
||||||
from langchain_core.utils._merge import merge_dicts
|
from langchain_core.utils._merge import merge_dicts
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||||
|
from langchain_core.utils.utils import secret_from_env
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -254,3 +256,107 @@ def test_from_env_with_default_error_message() -> None:
|
|||||||
get_value = from_env(key)
|
get_value = from_env(key)
|
||||||
with pytest.raises(ValueError, match=f"Did not find {key}"):
|
with pytest.raises(ValueError, match=f"Did not find {key}"):
|
||||||
get_value()
|
get_value()
|
||||||
|
|
||||||
|
|
||||||
|
def test_secret_from_env_with_env_variable(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
# Set the environment variable
|
||||||
|
monkeypatch.setenv("TEST_KEY", "secret_value")
|
||||||
|
|
||||||
|
# Get the function
|
||||||
|
get_secret: Callable[[], Optional[SecretStr]] = secret_from_env("TEST_KEY")
|
||||||
|
|
||||||
|
# Assert that it returns the correct value
|
||||||
|
assert get_secret() == SecretStr("secret_value")
|
||||||
|
|
||||||
|
|
||||||
|
def test_secret_from_env_with_default_value(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
# Unset the environment variable
|
||||||
|
monkeypatch.delenv("TEST_KEY", raising=False)
|
||||||
|
|
||||||
|
# Get the function with a default value
|
||||||
|
get_secret: Callable[[], SecretStr] = secret_from_env(
|
||||||
|
"TEST_KEY", default="default_value"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that it returns the default value
|
||||||
|
assert get_secret() == SecretStr("default_value")
|
||||||
|
|
||||||
|
|
||||||
|
def test_secret_from_env_with_none_default(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
# Unset the environment variable
|
||||||
|
monkeypatch.delenv("TEST_KEY", raising=False)
|
||||||
|
|
||||||
|
# Get the function with a default value of None
|
||||||
|
get_secret: Callable[[], Optional[SecretStr]] = secret_from_env(
|
||||||
|
"TEST_KEY", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that it returns None
|
||||||
|
assert get_secret() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_secret_from_env_without_default_raises_error(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
# Unset the environment variable
|
||||||
|
monkeypatch.delenv("TEST_KEY", raising=False)
|
||||||
|
|
||||||
|
# Get the function without a default value
|
||||||
|
get_secret: Callable[[], SecretStr] = secret_from_env("TEST_KEY")
|
||||||
|
|
||||||
|
# Assert that it raises a ValueError with the correct message
|
||||||
|
with pytest.raises(ValueError, match="Did not find TEST_KEY"):
|
||||||
|
get_secret()
|
||||||
|
|
||||||
|
|
||||||
|
def test_secret_from_env_with_custom_error_message(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
# Unset the environment variable
|
||||||
|
monkeypatch.delenv("TEST_KEY", raising=False)
|
||||||
|
|
||||||
|
# Get the function without a default value but with a custom error message
|
||||||
|
get_secret: Callable[[], SecretStr] = secret_from_env(
|
||||||
|
"TEST_KEY", error_message="Custom error message"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that it raises a ValueError with the custom message
|
||||||
|
with pytest.raises(ValueError, match="Custom error message"):
|
||||||
|
get_secret()
|
||||||
|
|
||||||
|
|
||||||
|
def test_using_secret_from_env_as_default_factory(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
# Set the environment variable
|
||||||
|
monkeypatch.setenv("TEST_KEY", "secret_value")
|
||||||
|
# Get the function
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
|
class Foo(BaseModel):
|
||||||
|
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))
|
||||||
|
|
||||||
|
assert Foo().secret.get_secret_value() == "secret_value"
|
||||||
|
|
||||||
|
class Bar(BaseModel):
|
||||||
|
secret: Optional[SecretStr] = Field(
|
||||||
|
default_factory=secret_from_env("TEST_KEY_2", default=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert Bar().secret is None
|
||||||
|
|
||||||
|
class Buzz(BaseModel):
|
||||||
|
secret: Optional[SecretStr] = Field(
|
||||||
|
default_factory=secret_from_env("TEST_KEY_2", default="hello")
|
||||||
|
)
|
||||||
|
|
||||||
|
# We know it will be SecretStr rather than Optional[SecretStr]
|
||||||
|
assert Buzz().secret.get_secret_value() == "hello" # type: ignore
|
||||||
|
|
||||||
|
class OhMy(BaseModel):
|
||||||
|
secret: Optional[SecretStr] = Field(
|
||||||
|
default_factory=secret_from_env("FOOFOOFOOBAR")
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Did not find FOOFOOFOOBAR"):
|
||||||
|
OhMy()
|
||||||
|
Loading…
Reference in New Issue
Block a user