diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index 3ae7e8ac4f4..9b9deba027c 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -54,6 +54,7 @@ from langchain_community.embeddings.javelin_ai_gateway import JavelinAIGatewayEm from langchain_community.embeddings.jina import JinaEmbeddings from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings +from langchain_community.embeddings.llm_rails import LLMRailsEmbeddings from langchain_community.embeddings.localai import LocalAIEmbeddings from langchain_community.embeddings.minimax import MiniMaxEmbeddings from langchain_community.embeddings.mlflow import MlflowEmbeddings @@ -98,6 +99,7 @@ __all__ = [ "GradientEmbeddings", "JinaEmbeddings", "LlamaCppEmbeddings", + "LLMRailsEmbeddings", "HuggingFaceHubEmbeddings", "MlflowEmbeddings", "MlflowAIGatewayEmbeddings", diff --git a/libs/community/langchain_community/embeddings/llm_rails.py b/libs/community/langchain_community/embeddings/llm_rails.py index 6f233d59d33..44bc0171803 100644 --- a/libs/community/langchain_community/embeddings/llm_rails.py +++ b/libs/community/langchain_community/embeddings/llm_rails.py @@ -1,11 +1,10 @@ """ This file is for LLMRails Embedding """ -import logging -import os -from typing import List, Optional +from typing import Dict, List, Optional import requests from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env class LLMRailsEmbeddings(BaseModel, Embeddings): @@ -29,7 +28,7 @@ class LLMRailsEmbeddings(BaseModel, Embeddings): model: str = "embedding-english-v1" """Model name to use.""" - api_key: Optional[str] = None + api_key: Optional[SecretStr] = None """LLMRails API key.""" class Config: @@ -37,6 +36,15 @@ class LLMRailsEmbeddings(BaseModel, Embeddings): extra = Extra.forbid + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key exists in environment.""" + api_key = convert_to_secret_str( + get_from_dict_or_env(values, "api_key", "LLM_RAILS_API_KEY") + ) + values["api_key"] = api_key + return values + def embed_documents(self, texts: List[str]) -> List[List[float]]: """Call out to Cohere's embedding endpoint. @@ -46,14 +54,9 @@ class LLMRailsEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ - api_key = self.api_key or os.environ.get("LLM_RAILS_API_KEY") - if api_key is None: - logging.warning("Can't find LLMRails credentials in environment.") - raise ValueError("LLM_RAILS_API_KEY is not set") - response = requests.post( "https://api.llmrails.com/v1/embeddings", - headers={"X-API-KEY": api_key}, + headers={"X-API-KEY": self.api_key.get_secret_value()}, json={"input": texts, "model": self.model}, timeout=60, ) diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index cd91e675da5..6aac6609a99 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -14,6 +14,7 @@ EXPECTED_ALL = [ "GradientEmbeddings", "JinaEmbeddings", "LlamaCppEmbeddings", + "LLMRailsEmbeddings", "HuggingFaceHubEmbeddings", "MlflowAIGatewayEmbeddings", "MlflowEmbeddings", diff --git a/libs/community/tests/unit_tests/embeddings/test_llm_rails.py b/libs/community/tests/unit_tests/embeddings/test_llm_rails.py new file mode 100644 index 00000000000..05a40c726e5 --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_llm_rails.py @@ -0,0 +1,21 @@ +"""Test LLMRailsEmbeddings embeddings""" + +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture + +from langchain_community.embeddings import LLMRailsEmbeddings + + +def test_api_key_is_string() -> None: + llm = LLMRailsEmbeddings(api_key="secret-api-key") + assert isinstance(llm.api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + llm = LLMRailsEmbeddings(api_key="secret-api-key") + print(llm.api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********"