fix(embeddings): huggingface hub embeddings and TEI (#14489)

**Description:** This PR fixes `HuggingFaceHubEmbeddings` by making the
API token optional (as it already is in the underlying `InferenceClient`).
Most models don't require one. I also updated the notebook for TEI
(text-embeddings-inference) accordingly, as requested in #14288. In
addition, I fixed a mistake in the POST call parameters: `task` is now
passed as a keyword argument to `client.post` instead of being embedded
in the JSON payload.

**Tag maintainers:** @baskaryan
This commit is contained in:
Massimiliano Pronesti
2023-12-13 02:21:52 +01:00
committed by GitHub
parent 5da79e150b
commit 6080c98108
2 changed files with 36 additions and 38 deletions

View File

@@ -1,9 +1,9 @@
import json
import os
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
VALID_TASKS = ("feature-extraction",)
@@ -48,9 +48,10 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv(
"HUGGINGFACEHUB_API_TOKEN"
)
try:
from huggingface_hub import InferenceClient
@@ -92,7 +93,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
texts = [text.replace("\n", " ") for text in texts]
_model_kwargs = self.model_kwargs or {}
responses = self.client.post(
json={"inputs": texts, "parameters": _model_kwargs, "task": self.task}
json={"inputs": texts, "parameters": _model_kwargs}, task=self.task
)
return json.loads(responses.decode())