From b440f92d8139a10d8c12a77cf3b301782e2009fc Mon Sep 17 00:00:00 2001 From: chyroc Date: Tue, 2 Jan 2024 07:24:00 +0800 Subject: [PATCH] Refactor: use SecretStr for embaas embeddings (#15091) --- .../langchain_community/embeddings/embaas.py | 12 +++++------ .../unit_tests/embeddings/test_embaas.py | 21 +++++++++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 libs/community/tests/unit_tests/embeddings/test_embaas.py diff --git a/libs/community/langchain_community/embeddings/embaas.py b/libs/community/langchain_community/embeddings/embaas.py index 5800126b366..2f0b31f4439 100644 --- a/libs/community/langchain_community/embeddings/embaas.py +++ b/libs/community/langchain_community/embeddings/embaas.py @@ -2,8 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional import requests from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from requests.adapters import HTTPAdapter, Retry from typing_extensions import NotRequired, TypedDict @@ -50,7 +50,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings): """Instruction used for domain-specific embeddings.""" api_url: str = EMBAAS_API_URL """The URL for the embaas embeddings API.""" - embaas_api_key: Optional[str] = None + embaas_api_key: Optional[SecretStr] = None """max number of retries for requests""" max_retries: Optional[int] = 3 """request timeout in seconds""" @@ -64,8 +64,8 @@ class EmbaasEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - embaas_api_key = get_from_dict_or_env( - values, "embaas_api_key", "EMBAAS_API_KEY" + embaas_api_key = convert_to_secret_str( + get_from_dict_or_env(values, "embaas_api_key", "EMBAAS_API_KEY") ) values["embaas_api_key"] = embaas_api_key return values @@ -85,7 +85,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings): def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]: """Sends a request to the Embaas API and handles the response.""" headers = { - "Authorization": f"Bearer {self.embaas_api_key}", + "Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}", "Content-Type": "application/json", } diff --git a/libs/community/tests/unit_tests/embeddings/test_embaas.py b/libs/community/tests/unit_tests/embeddings/test_embaas.py new file mode 100644 index 00000000000..05420ff42fa --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_embaas.py @@ -0,0 +1,21 @@ +"""Test EmbaasEmbeddings embeddings""" + +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture + +from langchain_community.embeddings import EmbaasEmbeddings + + +def test_api_key_is_string() -> None: + llm = EmbaasEmbeddings(embaas_api_key="secret-api-key") + assert isinstance(llm.embaas_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + llm = EmbaasEmbeddings(embaas_api_key="secret-api-key") + print(llm.embaas_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********"