diff --git a/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py b/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py index dc772ece593..342f33d2b2a 100644 --- a/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py +++ b/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py @@ -29,8 +29,6 @@ class EdenAiSpeechToTextTool(EdenaiTool): You can find your token here: https://app.edenai.run/admin/account/settings """ - edenai_api_key: Optional[str] = None - name: str = "edenai_speech_to_text" description = ( "A wrapper around edenai Services speech to text " diff --git a/libs/community/langchain_community/tools/edenai/edenai_base_tool.py b/libs/community/langchain_community/tools/edenai/edenai_base_tool.py index 7e8a35ede9c..2f83ef37bc4 100644 --- a/libs/community/langchain_community/tools/edenai/edenai_base_tool.py +++ b/libs/community/langchain_community/tools/edenai/edenai_base_tool.py @@ -6,9 +6,9 @@ from typing import Any, Dict, List, Optional import requests from langchain_core.callbacks import CallbackManagerForToolRun -from langchain_core.pydantic_v1 import root_validator +from langchain_core.pydantic_v1 import Field, SecretStr from langchain_core.tools import BaseTool -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import secret_from_env logger = logging.getLogger(__name__) @@ -23,20 +23,14 @@ class EdenaiTool(BaseTool): feature: str subfeature: str - edenai_api_key: Optional[str] = None + edenai_api_key: SecretStr = Field( + default_factory=secret_from_env("EDENAI_API_KEY", default=None) + ) is_async: bool = False providers: List[str] """provider to use for the API call.""" - @root_validator(allow_reuse=True) - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key exists in environment.""" - values["edenai_api_key"] = get_from_dict_or_env( - values, "edenai_api_key", "EDENAI_API_KEY" - ) - return values - @staticmethod def get_user_agent() -> str: from langchain_community import __version__ @@ -54,11 +48,8 @@ class EdenaiTool(BaseTool): requests.Response: The response from the EdenAI API call. """ - - # faire l'API call - headers = { - "Authorization": f"Bearer {self.edenai_api_key}", + "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", "User-Agent": self.get_user_agent(), } diff --git a/libs/community/scripts/check_pydantic.sh b/libs/community/scripts/check_pydantic.sh index 4a5cfaf89ee..1b091d0f4eb 100755 --- a/libs/community/scripts/check_pydantic.sh +++ b/libs/community/scripts/check_pydantic.sh @@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@pre_init)' -- "*.py" | wc # PRs that increase the current count will not be accepted. # PRs that decrease update the code in the repository # and allow decreasing the count of are welcome! -current_count=337 +current_count=336 if [ "$count" -gt "$current_count" ]; then echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator." diff --git a/libs/community/tests/unit_tests/tools/eden_ai/test_tools.py b/libs/community/tests/unit_tests/tools/eden_ai/test_tools.py index f3976ae8c0f..f32aec69552 100644 --- a/libs/community/tests/unit_tests/tools/eden_ai/test_tools.py +++ b/libs/community/tests/unit_tests/tools/eden_ai/test_tools.py @@ -6,7 +6,9 @@ import pytest from langchain_community.tools.edenai import EdenAiTextModerationTool tool = EdenAiTextModerationTool( # type: ignore[call-arg] - providers=["openai"], language="en", edenai_api_key="fake_key" + providers=["openai"], + language="en", + edenai_api_key="fake_key", # type: ignore[arg-type] ) diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 88d3b056ea2..7905bfb62dc 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -328,14 +328,17 @@ def test_secret_from_env_with_custom_error_message( def test_using_secret_from_env_as_default_factory( monkeypatch: pytest.MonkeyPatch, ) -> None: - # Set the environment variable - monkeypatch.setenv("TEST_KEY", "secret_value") - # Get the function from langchain_core.pydantic_v1 import BaseModel, Field class Foo(BaseModel): secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY")) + # Pass the secret as a parameter + foo = Foo(secret="super_secret") # type: ignore[arg-type] + assert foo.secret.get_secret_value() == "super_secret" + + # Set the environment variable + monkeypatch.setenv("TEST_KEY", "secret_value") assert Foo().secret.get_secret_value() == "secret_value" class Bar(BaseModel):