mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
Community[patch]use secret str in Tavily and HuggingFaceInferenceEmbeddings (#16109)
So the api keys don't show up in repr's Still need to do tests
This commit is contained in:
parent
f3601b0aaf
commit
e5cf1e2414
@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
|
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, SecretStr
|
||||||
|
|
||||||
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
||||||
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
|
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
|
||||||
@ -275,7 +275,7 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
|
|||||||
Requires a HuggingFace Inference API key and a model name.
|
Requires a HuggingFace Inference API key and a model name.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
api_key: str
|
api_key: SecretStr
|
||||||
"""Your API key for the HuggingFace Inference API."""
|
"""Your API key for the HuggingFace Inference API."""
|
||||||
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
|
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
"""The name of the model to use for text embeddings."""
|
"""The name of the model to use for text embeddings."""
|
||||||
@ -297,7 +297,7 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _headers(self) -> dict:
|
def _headers(self) -> dict:
|
||||||
return {"Authorization": f"Bearer {self.api_key}"}
|
return {"Authorization": f"Bearer {self.api_key.get_secret_value()}"}
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Get the embeddings for a list of texts.
|
"""Get the embeddings for a list of texts.
|
||||||
|
@ -7,7 +7,7 @@ from typing import Dict, List, Optional
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
|
|
||||||
TAVILY_API_URL = "https://api.tavily.com"
|
TAVILY_API_URL = "https://api.tavily.com"
|
||||||
@ -16,7 +16,7 @@ TAVILY_API_URL = "https://api.tavily.com"
|
|||||||
class TavilySearchAPIWrapper(BaseModel):
|
class TavilySearchAPIWrapper(BaseModel):
|
||||||
"""Wrapper for Tavily Search API."""
|
"""Wrapper for Tavily Search API."""
|
||||||
|
|
||||||
tavily_api_key: str
|
tavily_api_key: SecretStr
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -45,7 +45,7 @@ class TavilySearchAPIWrapper(BaseModel):
|
|||||||
include_images: Optional[bool] = False,
|
include_images: Optional[bool] = False,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
params = {
|
params = {
|
||||||
"api_key": self.tavily_api_key,
|
"api_key": self.tavily_api_key.get_secret_value(),
|
||||||
"query": query,
|
"query": query,
|
||||||
"max_results": max_results,
|
"max_results": max_results,
|
||||||
"search_depth": search_depth,
|
"search_depth": search_depth,
|
||||||
@ -126,7 +126,7 @@ class TavilySearchAPIWrapper(BaseModel):
|
|||||||
# Function to perform the API call
|
# Function to perform the API call
|
||||||
async def fetch() -> str:
|
async def fetch() -> str:
|
||||||
params = {
|
params = {
|
||||||
"api_key": self.tavily_api_key,
|
"api_key": self.tavily_api_key.get_secret_value(),
|
||||||
"query": query,
|
"query": query,
|
||||||
"max_results": max_results,
|
"max_results": max_results,
|
||||||
"search_depth": search_depth,
|
"search_depth": search_depth,
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
def test_hugginggface_inferenceapi_embedding_documents_init() -> None:
|
||||||
|
"""Test huggingface embeddings."""
|
||||||
|
embedding = HuggingFaceInferenceAPIEmbeddings(api_key="abcd123")
|
||||||
|
assert "abcd123" not in repr(embedding)
|
7
libs/community/tests/unit_tests/utilities/test_tavily.py
Normal file
7
libs/community/tests/unit_tests/utilities/test_tavily.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_wrapper_api_key_not_visible() -> None:
|
||||||
|
"""Test that an exception is raised if the API key is not present."""
|
||||||
|
wrapper = TavilySearchAPIWrapper(tavily_api_key="abcd123")
|
||||||
|
assert "abcd123" not in repr(wrapper)
|
@ -97,6 +97,7 @@ class TestResult(dict):
|
|||||||
for col in df.columns
|
for col in df.columns
|
||||||
if col.startswith("inputs.")
|
if col.startswith("inputs.")
|
||||||
or col.startswith("outputs.")
|
or col.startswith("outputs.")
|
||||||
|
or col in {"input", "output"}
|
||||||
or col.startswith("reference")
|
or col.startswith("reference")
|
||||||
]
|
]
|
||||||
return df.describe(include="all").drop(to_drop, axis=1)
|
return df.describe(include="all").drop(to_drop, axis=1)
|
||||||
|
Loading…
Reference in New Issue
Block a user