Refactor: use SecretStr for embaas embeddings (#15091)

This commit is contained in:
chyroc 2024-01-02 07:24:00 +08:00 committed by GitHub
parent ea6cf0f1b1
commit b440f92d81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 6 deletions

View File

@ -2,8 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from requests.adapters import HTTPAdapter, Retry
from typing_extensions import NotRequired, TypedDict
@ -50,7 +50,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
"""Instruction used for domain-specific embeddings."""
api_url: str = EMBAAS_API_URL
"""The URL for the embaas embeddings API."""
embaas_api_key: Optional[str] = None
embaas_api_key: Optional[SecretStr] = None
"""max number of retries for requests"""
max_retries: Optional[int] = 3
"""request timeout in seconds"""
@ -64,8 +64,8 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
embaas_api_key = get_from_dict_or_env(
values, "embaas_api_key", "EMBAAS_API_KEY"
embaas_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "embaas_api_key", "EMBAAS_API_KEY")
)
values["embaas_api_key"] = embaas_api_key
return values
@ -85,7 +85,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
"""Sends a request to the Embaas API and handles the response."""
headers = {
"Authorization": f"Bearer {self.embaas_api_key}",
"Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}",
"Content-Type": "application/json",
}

View File

@ -0,0 +1,21 @@
"""Test EmbaasEmbeddings embeddings"""
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture
from langchain_community.embeddings import EmbaasEmbeddings
def test_api_key_is_string() -> None:
llm = EmbaasEmbeddings(embaas_api_key="secret-api-key")
assert isinstance(llm.embaas_api_key, SecretStr)
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
llm = EmbaasEmbeddings(embaas_api_key="secret-api-key")
print(llm.embaas_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"