feat: add Jina Embeddings (#1105)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Joan Fontanals 2024-01-24 14:41:19 +01:00 committed by GitHub
parent 8f18478fa5
commit 9a2b0e355e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -361,3 +361,70 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
Embeddings for the text. Embeddings for the text.
""" """
return self.embed_documents([text])[0] return self.embed_documents([text])[0]
class JinaEmbeddings(BaseModel, Embeddings):
    """Embed texts through the Jina AI embeddings HTTP API.

    An API key is required. The model defaults to
    "jina-embeddings-v2-base-en".
    """

    api_url: Any  #: :meta private:
    session: Any  #: :meta private:
    api_key: str
    """Your API key for the Jina AI API."""
    model_name: str = "jina-embeddings-v2-base-en"
    """The name of the model to use for text embeddings.
    Defaults to "jina-embeddings-v2-base-en"."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the JinaEmbeddings.

        Validates the fields via the parent initializer, then creates an
        HTTP session that sends the API key as a Bearer token.

        Raises:
            ValueError: If the ``requests`` package is not installed.
        """
        super().__init__(**kwargs)
        try:
            import requests
        except ImportError as exc:
            # ValueError is kept for backward compatibility with existing
            # callers; chaining preserves the original import failure.
            raise ValueError(
                "The requests python package is not installed. Please install it with `pip install requests`"
            ) from exc
        self.api_url = "https://api.jina.ai/v1/embeddings"
        self.session = requests.Session()
        self.session.headers.update(
            {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        Args:
            texts: A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
            corresponds to a single input text, in input order.

        Raises:
            RuntimeError: If the API response does not contain a "data" field.
        """
        # Nothing to embed: skip the network round trip entirely.
        if not texts:
            return []

        # Call Jina AI Embedding API
        resp = self.session.post(  # type: ignore
            self.api_url, json={"input": texts, "model": self.model_name}
        ).json()
        if "data" not in resp:
            # Prefer the "detail" message when the payload has one; otherwise
            # surface the whole payload instead of failing with a KeyError.
            detail = resp.get("detail", resp) if isinstance(resp, dict) else resp
            raise RuntimeError(detail)

        # Sort resulting embeddings by index so the output order matches the
        # input order regardless of how the response lists them.
        sorted_embeddings = sorted(resp["data"], key=lambda e: e["index"])  # type: ignore

        # Return just the embeddings
        return [result["embedding"] for result in sorted_embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Compute the embedding of a single query text using the Jina AI API.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]