Refactor: use SecretStr for llm_rails embeddings (#15090)

This commit is contained in:
chyroc 2024-01-02 07:24:50 +08:00 committed by GitHub
parent b440f92d81
commit 32e96a471c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 11 deletions

View File

@ -54,6 +54,7 @@ from langchain_community.embeddings.javelin_ai_gateway import JavelinAIGatewayEm
from langchain_community.embeddings.jina import JinaEmbeddings from langchain_community.embeddings.jina import JinaEmbeddings
from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
from langchain_community.embeddings.llm_rails import LLMRailsEmbeddings
from langchain_community.embeddings.localai import LocalAIEmbeddings from langchain_community.embeddings.localai import LocalAIEmbeddings
from langchain_community.embeddings.minimax import MiniMaxEmbeddings from langchain_community.embeddings.minimax import MiniMaxEmbeddings
from langchain_community.embeddings.mlflow import MlflowEmbeddings from langchain_community.embeddings.mlflow import MlflowEmbeddings
@ -98,6 +99,7 @@ __all__ = [
"GradientEmbeddings", "GradientEmbeddings",
"JinaEmbeddings", "JinaEmbeddings",
"LlamaCppEmbeddings", "LlamaCppEmbeddings",
"LLMRailsEmbeddings",
"HuggingFaceHubEmbeddings", "HuggingFaceHubEmbeddings",
"MlflowEmbeddings", "MlflowEmbeddings",
"MlflowAIGatewayEmbeddings", "MlflowAIGatewayEmbeddings",

View File

@ -1,11 +1,10 @@
""" This file is for LLMRails Embedding """ """ This file is for LLMRails Embedding """
import logging from typing import Dict, List, Optional
import os
from typing import List, Optional
import requests import requests
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
class LLMRailsEmbeddings(BaseModel, Embeddings): class LLMRailsEmbeddings(BaseModel, Embeddings):
@ -29,7 +28,7 @@ class LLMRailsEmbeddings(BaseModel, Embeddings):
model: str = "embedding-english-v1" model: str = "embedding-english-v1"
"""Model name to use.""" """Model name to use."""
api_key: Optional[str] = None api_key: Optional[SecretStr] = None
"""LLMRails API key.""" """LLMRails API key."""
class Config: class Config:
@ -37,6 +36,15 @@ class LLMRailsEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
api_key = convert_to_secret_str(
get_from_dict_or_env(values, "api_key", "LLM_RAILS_API_KEY")
)
values["api_key"] = api_key
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to Cohere's embedding endpoint. """Call out to Cohere's embedding endpoint.
@ -46,14 +54,9 @@ class LLMRailsEmbeddings(BaseModel, Embeddings):
Returns: Returns:
List of embeddings, one for each text. List of embeddings, one for each text.
""" """
api_key = self.api_key or os.environ.get("LLM_RAILS_API_KEY")
if api_key is None:
logging.warning("Can't find LLMRails credentials in environment.")
raise ValueError("LLM_RAILS_API_KEY is not set")
response = requests.post( response = requests.post(
"https://api.llmrails.com/v1/embeddings", "https://api.llmrails.com/v1/embeddings",
headers={"X-API-KEY": api_key}, headers={"X-API-KEY": self.api_key.get_secret_value()},
json={"input": texts, "model": self.model}, json={"input": texts, "model": self.model},
timeout=60, timeout=60,
) )

View File

@ -14,6 +14,7 @@ EXPECTED_ALL = [
"GradientEmbeddings", "GradientEmbeddings",
"JinaEmbeddings", "JinaEmbeddings",
"LlamaCppEmbeddings", "LlamaCppEmbeddings",
"LLMRailsEmbeddings",
"HuggingFaceHubEmbeddings", "HuggingFaceHubEmbeddings",
"MlflowAIGatewayEmbeddings", "MlflowAIGatewayEmbeddings",
"MlflowEmbeddings", "MlflowEmbeddings",

View File

@ -0,0 +1,21 @@
"""Test LLMRailsEmbeddings embeddings"""
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture
from langchain_community.embeddings import LLMRailsEmbeddings
def test_api_key_is_string() -> None:
llm = LLMRailsEmbeddings(api_key="secret-api-key")
assert isinstance(llm.api_key, SecretStr)
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
llm = LLMRailsEmbeddings(api_key="secret-api-key")
print(llm.api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"