mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-16 07:24:05 +00:00
feat: add Jina Embeddings (#1105)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
parent
8f18478fa5
commit
9a2b0e355e
@ -361,3 +361,70 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
|
||||
class JinaEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
This class is used to get embeddings for a list of texts using the Jina AI API.
|
||||
It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
|
||||
"""
|
||||
|
||||
api_url: Any #: :meta private:
|
||||
session: Any #: :meta private:
|
||||
api_key: str
|
||||
"""our API key for the Jina AI API.."""
|
||||
model_name: str = "jina-embeddings-v2-base-en"
|
||||
"""he name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en"."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initialize the JinaEmbeddings.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"The requests python package is not installed. Please install it with `pip install requests`"
|
||||
)
|
||||
self.api_url = "https://api.jina.ai/v1/embeddings"
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update(
|
||||
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
|
||||
)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Get the embeddings for a list of texts.
|
||||
|
||||
Args:
|
||||
texts (Documents): A list of texts to get embeddings for.
|
||||
|
||||
Returns:
|
||||
Embedded texts as List[List[float]], where each inner List[float]
|
||||
corresponds to a single input text.
|
||||
"""
|
||||
# Call Jina AI Embedding API
|
||||
resp = self.session.post( # type: ignore
|
||||
self.api_url, json={"input": texts, "model": self.model_name}
|
||||
).json()
|
||||
if "data" not in resp:
|
||||
raise RuntimeError(resp["detail"])
|
||||
|
||||
embeddings = resp["data"]
|
||||
|
||||
# Sort resulting embeddings by index
|
||||
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
|
||||
|
||||
# Return just the embeddings
|
||||
return [result["embedding"] for result in sorted_embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
Loading…
Reference in New Issue
Block a user