From 429a0ee7fd2de9e411d6fa32af5675b7f4e25cf0 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 8 Aug 2024 16:41:58 -0400 Subject: [PATCH] core[minor]: Add factory for looking up secrets from the env (#25198) Add factory method for looking secrets from the env. --- libs/core/langchain_core/utils/__init__.py | 2 + libs/core/langchain_core/utils/utils.py | 63 +++++++++- .../tests/unit_tests/utils/test_imports.py | 1 + .../core/tests/unit_tests/utils/test_utils.py | 108 +++++++++++++++++- 4 files changed, 171 insertions(+), 3 deletions(-) diff --git a/libs/core/langchain_core/utils/__init__.py b/libs/core/langchain_core/utils/__init__.py index ce407f11609..2e560b21e96 100644 --- a/libs/core/langchain_core/utils/__init__.py +++ b/libs/core/langchain_core/utils/__init__.py @@ -27,6 +27,7 @@ from langchain_core.utils.utils import ( guard_import, mock_now, raise_for_status_with_text, + secret_from_env, xor_args, ) @@ -56,4 +57,5 @@ __all__ = [ "batch_iterate", "abatch_iterate", "from_env", + "secret_from_env", ] diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 291bdf2adfb..960646ecb70 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -313,11 +313,11 @@ def from_env( This will be raised as a ValueError. """ - def get_from_env_fn() -> str: # type: ignore + def get_from_env_fn() -> Optional[str]: """Get a value from an environment variable.""" if key in os.environ: return os.environ[key] - elif isinstance(default, str): + elif isinstance(default, (str, type(None))): return default else: if error_message: @@ -330,3 +330,62 @@ def from_env( ) return get_from_env_fn + + +@overload +def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ... + + +@overload +def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: ... + + +@overload +def secret_from_env( + key: str, /, *, default: None +) -> Callable[[], Optional[SecretStr]]: ... + + +@overload +def secret_from_env(key: str, /, *, error_message: str) -> Callable[[], SecretStr]: ... + + +def secret_from_env( + key: str, + /, + *, + default: Union[str, _NoDefaultType, None] = _NoDefault, + error_message: Optional[str] = None, +) -> Union[Callable[[], Optional[SecretStr]], Callable[[], SecretStr]]: + """Secret from env. + + Args: + key: The environment variable to look up. + default: The default value to return if the environment variable is not set. + error_message: the error message which will be raised if the key is not found + and no default value is provided. + This will be raised as a ValueError. + + Returns: + factory method that will look up the secret from the environment. + """ + + def get_secret_from_env() -> Optional[SecretStr]: + """Get a value from an environment variable.""" + if key in os.environ: + return SecretStr(os.environ[key]) + elif isinstance(default, str): + return SecretStr(default) + elif isinstance(default, type(None)): + return None + else: + if error_message: + raise ValueError(error_message) + else: + raise ValueError( + f"Did not find {key}, please add an environment variable" + f" `{key}` which contains it, or pass" + f" `{key}` as a named parameter." + ) + + return get_secret_from_env diff --git a/libs/core/tests/unit_tests/utils/test_imports.py b/libs/core/tests/unit_tests/utils/test_imports.py index a22a8802b56..f33491ed295 100644 --- a/libs/core/tests/unit_tests/utils/test_imports.py +++ b/libs/core/tests/unit_tests/utils/test_imports.py @@ -26,6 +26,7 @@ EXPECTED_ALL = [ "stringify_value", "pre_init", "from_env", + "secret_from_env", ] diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 69b52e94fbc..88d3b056ea2 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -1,12 +1,13 @@ import os import re from contextlib import AbstractContextManager, nullcontext -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union from unittest.mock import patch import pytest from langchain_core import utils +from langchain_core.pydantic_v1 import SecretStr from langchain_core.utils import ( check_package_version, from_env, @@ -15,6 +16,7 @@ from langchain_core.utils import ( ) from langchain_core.utils._merge import merge_dicts from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION +from langchain_core.utils.utils import secret_from_env @pytest.mark.parametrize( @@ -254,3 +256,107 @@ def test_from_env_with_default_error_message() -> None: get_value = from_env(key) with pytest.raises(ValueError, match=f"Did not find {key}"): get_value() + + +def test_secret_from_env_with_env_variable(monkeypatch: pytest.MonkeyPatch) -> None: + # Set the environment variable + monkeypatch.setenv("TEST_KEY", "secret_value") + + # Get the function + get_secret: Callable[[], Optional[SecretStr]] = secret_from_env("TEST_KEY") + + # Assert that it returns the correct value + assert get_secret() == SecretStr("secret_value") + + +def test_secret_from_env_with_default_value(monkeypatch: pytest.MonkeyPatch) -> None: + # Unset the environment variable + monkeypatch.delenv("TEST_KEY", raising=False) + + # Get the function with a default value + get_secret: Callable[[], SecretStr] = secret_from_env( + "TEST_KEY", default="default_value" + ) + + # Assert that it returns the default value + assert get_secret() == SecretStr("default_value") + + +def test_secret_from_env_with_none_default(monkeypatch: pytest.MonkeyPatch) -> None: + # Unset the environment variable + monkeypatch.delenv("TEST_KEY", raising=False) + + # Get the function with a default value of None + get_secret: Callable[[], Optional[SecretStr]] = secret_from_env( + "TEST_KEY", default=None + ) + + # Assert that it returns None + assert get_secret() is None + + +def test_secret_from_env_without_default_raises_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Unset the environment variable + monkeypatch.delenv("TEST_KEY", raising=False) + + # Get the function without a default value + get_secret: Callable[[], SecretStr] = secret_from_env("TEST_KEY") + + # Assert that it raises a ValueError with the correct message + with pytest.raises(ValueError, match="Did not find TEST_KEY"): + get_secret() + + +def test_secret_from_env_with_custom_error_message( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Unset the environment variable + monkeypatch.delenv("TEST_KEY", raising=False) + + # Get the function without a default value but with a custom error message + get_secret: Callable[[], SecretStr] = secret_from_env( + "TEST_KEY", error_message="Custom error message" + ) + + # Assert that it raises a ValueError with the custom message + with pytest.raises(ValueError, match="Custom error message"): + get_secret() + + +def test_using_secret_from_env_as_default_factory( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Set the environment variable + monkeypatch.setenv("TEST_KEY", "secret_value") + # Get the function + from langchain_core.pydantic_v1 import BaseModel, Field + + class Foo(BaseModel): + secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY")) + + assert Foo().secret.get_secret_value() == "secret_value" + + class Bar(BaseModel): + secret: Optional[SecretStr] = Field( + default_factory=secret_from_env("TEST_KEY_2", default=None) + ) + + assert Bar().secret is None + + class Buzz(BaseModel): + secret: Optional[SecretStr] = Field( + default_factory=secret_from_env("TEST_KEY_2", default="hello") + ) + + # We know it will be SecretStr rather than Optional[SecretStr] + assert Buzz().secret.get_secret_value() == "hello" # type: ignore + + class OhMy(BaseModel): + secret: Optional[SecretStr] = Field( + default_factory=secret_from_env("FOOFOOFOOBAR") + ) + + with pytest.raises(ValueError, match="Did not find FOOFOOFOOBAR"): + OhMy()