feat(embeddings): text-embeddings-inference (#14288)

- **Description:** Added a notebook illustrating how to use
`text-embeddings-inference` from Hugging Face. Since
`HuggingFaceHubEmbeddings` was using a deprecated client, I took the
opportunity in this PR to update that as well.

- **Issue:** #13286 

- **Dependencies:** None

- **Tag maintainer:** @baskaryan
This commit is contained in:
Massimiliano Pronesti
2023-12-06 02:22:05 +01:00
committed by GitHub
parent 85b88c33f3
commit c215a4c9ec
2 changed files with 195 additions and 18 deletions

View File

@@ -1,3 +1,4 @@
import json
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
@@ -5,7 +6,7 @@ from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.utils import get_from_dict_or_env
DEFAULT_REPO_ID = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
VALID_TASKS = ("feature-extraction",)
@@ -20,17 +21,19 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
.. code-block:: python
from langchain.embeddings import HuggingFaceHubEmbeddings
repo_id = "sentence-transformers/all-mpnet-base-v2"
model = "sentence-transformers/all-mpnet-base-v2"
hf = HuggingFaceHubEmbeddings(
repo_id=repo_id,
model=model,
task="feature-extraction",
huggingfacehub_api_token="my-api-key",
)
"""
client: Any #: :meta private:
repo_id: str = DEFAULT_REPO_ID
model: Optional[str] = None
"""Model name to use."""
repo_id: Optional[str] = None
"""Huggingfacehub repository id, for backward compatibility."""
task: Optional[str] = "feature-extraction"
"""Task to call the model with."""
model_kwargs: Optional[dict] = None
@@ -50,22 +53,23 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
try:
from huggingface_hub.inference_api import InferenceApi
from huggingface_hub import InferenceClient
repo_id = values["repo_id"]
if not repo_id.startswith("sentence-transformers"):
raise ValueError(
"Currently only 'sentence-transformers' embedding models "
f"are supported. Got invalid 'repo_id' {repo_id}."
)
client = InferenceApi(
repo_id=repo_id,
if values["model"]:
values["repo_id"] = values["model"]
elif values["repo_id"]:
values["model"] = values["repo_id"]
else:
values["model"] = DEFAULT_MODEL
values["repo_id"] = DEFAULT_MODEL
client = InferenceClient(
model=values["model"],
token=huggingfacehub_api_token,
task=values.get("task"),
)
if client.task not in VALID_TASKS:
if values["task"] not in VALID_TASKS:
raise ValueError(
f"Got invalid task {client.task}, "
f"Got invalid task {values['task']}, "
f"currently only {VALID_TASKS} are supported"
)
values["client"] = client
@@ -88,8 +92,10 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
# replace newlines, which can negatively affect performance.
texts = [text.replace("\n", " ") for text in texts]
_model_kwargs = self.model_kwargs or {}
responses = self.client(inputs=texts, params=_model_kwargs)
return responses
responses = self.client.post(
json={"inputs": texts, "parameters": _model_kwargs, "task": self.task}
)
return json.loads(responses.decode())
def embed_query(self, text: str) -> List[float]:
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.