Refactor: use SecretStr for embaas embeddings (#15091)

This commit is contained in:
chyroc
2024-01-02 07:24:00 +08:00
committed by GitHub
parent ea6cf0f1b1
commit b440f92d81
2 changed files with 27 additions and 6 deletions

View File

@@ -2,8 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from requests.adapters import HTTPAdapter, Retry
from typing_extensions import NotRequired, TypedDict
@@ -50,7 +50,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
"""Instruction used for domain-specific embeddings."""
api_url: str = EMBAAS_API_URL
"""The URL for the embaas embeddings API."""
embaas_api_key: Optional[str] = None
embaas_api_key: Optional[SecretStr] = None
"""max number of retries for requests"""
max_retries: Optional[int] = 3
"""request timeout in seconds"""
@@ -64,8 +64,8 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
embaas_api_key = get_from_dict_or_env(
values, "embaas_api_key", "EMBAAS_API_KEY"
embaas_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "embaas_api_key", "EMBAAS_API_KEY")
)
values["embaas_api_key"] = embaas_api_key
return values
@@ -85,7 +85,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
"""Sends a request to the Embaas API and handles the response."""
headers = {
"Authorization": f"Bearer {self.embaas_api_key}",
"Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}",
"Content-Type": "application/json",
}