From 30fb345342f2ae10d1d8fe10f93c548761a0395e Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 8 Aug 2024 14:52:35 -0400 Subject: [PATCH] core[minor]: Add from_env utility (#25189) Add a utility that can be used as a default factory The goal will be to start migrating from of the pydantic models to use `from_env` as a default factory if possible. ```python from pydantic import Field, BaseModel from langchain_core.utils import from_env class Foo(BaseModel): name: str = Field(default_factory=from_env('HELLO')) ``` --- libs/core/langchain_core/utils/__init__.py | 2 + libs/core/langchain_core/utils/utils.py | 72 ++++++++++++++++++- .../tests/unit_tests/utils/test_imports.py | 1 + .../core/tests/unit_tests/utils/test_utils.py | 35 +++++++++ 4 files changed, 109 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/utils/__init__.py b/libs/core/langchain_core/utils/__init__.py index 654f81df56b..ce407f11609 100644 --- a/libs/core/langchain_core/utils/__init__.py +++ b/libs/core/langchain_core/utils/__init__.py @@ -22,6 +22,7 @@ from langchain_core.utils.utils import ( build_extra_kwargs, check_package_version, convert_to_secret_str, + from_env, get_pydantic_field_names, guard_import, mock_now, @@ -54,4 +55,5 @@ __all__ = [ "pre_init", "batch_iterate", "abatch_iterate", + "from_env", ] diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 8711b399cd3..291bdf2adfb 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -4,9 +4,10 @@ import contextlib import datetime import functools import importlib +import os import warnings from importlib.metadata import version -from typing import Any, Callable, Dict, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Optional, Set, Tuple, Union, overload from packaging.version import parse from requests import HTTPError, Response @@ -260,3 +261,72 @@ def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr: if isinstance(value, SecretStr): return value return SecretStr(value) + + +class _NoDefaultType: + """Type to indicate no default value is provided.""" + + pass + + +_NoDefault = _NoDefaultType() + + +@overload +def from_env(key: str, /) -> Callable[[], str]: ... + + +@overload +def from_env(key: str, /, *, default: str) -> Callable[[], str]: ... + + +@overload +def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ... + + +@overload +def from_env( + key: str, /, *, default: str, error_message: Optional[str] +) -> Callable[[], str]: ... + + +@overload +def from_env( + key: str, /, *, default: None, error_message: Optional[str] +) -> Callable[[], Optional[str]]: ... + + +def from_env( + key: str, + /, + *, + default: Union[str, _NoDefaultType, None] = _NoDefault, + error_message: Optional[str] = None, +) -> Union[Callable[[], str], Callable[[], Optional[str]]]: + """Create a factory method that gets a value from an environment variable. + + Args: + key: The environment variable to look up. + default: The default value to return if the environment variable is not set. + error_message: the error message which will be raised if the key is not found + and no default value is provided. + This will be raised as a ValueError. + """ + + def get_from_env_fn() -> str: # type: ignore + """Get a value from an environment variable.""" + if key in os.environ: + return os.environ[key] + elif isinstance(default, str): + return default + else: + if error_message: + raise ValueError(error_message) + else: + raise ValueError( + f"Did not find {key}, please add an environment variable" + f" `{key}` which contains it, or pass" + f" `{key}` as a named parameter." + ) + + return get_from_env_fn diff --git a/libs/core/tests/unit_tests/utils/test_imports.py b/libs/core/tests/unit_tests/utils/test_imports.py index 8cb909d3f70..a22a8802b56 100644 --- a/libs/core/tests/unit_tests/utils/test_imports.py +++ b/libs/core/tests/unit_tests/utils/test_imports.py @@ -25,6 +25,7 @@ EXPECTED_ALL = [ "comma_list", "stringify_value", "pre_init", + "from_env", ] diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 0aed89e1a64..69b52e94fbc 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -1,3 +1,4 @@ +import os import re from contextlib import AbstractContextManager, nullcontext from typing import Any, Dict, Optional, Tuple, Type, Union @@ -8,6 +9,7 @@ import pytest from langchain_core import utils from langchain_core.utils import ( check_package_version, + from_env, get_pydantic_field_names, guard_import, ) @@ -219,3 +221,36 @@ def test_get_pydantic_field_names_v1() -> None: result = get_pydantic_field_names(PydanticModel) expected = {"field1", "field2", "aliased_field", "alias_field"} assert result == expected + + +def test_from_env_with_env_variable() -> None: + key = "TEST_KEY" + value = "test_value" + with patch.dict(os.environ, {key: value}): + get_value = from_env(key) + assert get_value() == value + + +def test_from_env_with_default_value() -> None: + key = "TEST_KEY" + default_value = "default_value" + with patch.dict(os.environ, {}, clear=True): + get_value = from_env(key, default=default_value) + assert get_value() == default_value + + +def test_from_env_with_error_message() -> None: + key = "TEST_KEY" + error_message = "Custom error message" + with patch.dict(os.environ, {}, clear=True): + get_value = from_env(key, error_message=error_message) + with pytest.raises(ValueError, match=error_message): + get_value() + + +def test_from_env_with_default_error_message() -> None: + key = "TEST_KEY" + with patch.dict(os.environ, {}, clear=True): + get_value = from_env(key) + with pytest.raises(ValueError, match=f"Did not find {key}"): + get_value()