From 9a2b0e355ecf3ec81830eb2be4384a7fc69e1145 Mon Sep 17 00:00:00 2001
From: Joan Fontanals
Date: Wed, 24 Jan 2024 14:41:19 +0100
Subject: [PATCH] feat: add Jina Embeddings (#1105)

Co-authored-by: Fangyin Cheng
---
 dbgpt/rag/embedding/embeddings.py | 86 +++++++++++++++++++++++++++++++
 1 file changed, 86 insertions(+)

diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py
index 76524a2a7..e9fb38784 100644
--- a/dbgpt/rag/embedding/embeddings.py
+++ b/dbgpt/rag/embedding/embeddings.py
@@ -361,3 +361,89 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
             Embeddings for the text.
         """
         return self.embed_documents([text])[0]
+
+
+class JinaEmbeddings(BaseModel, Embeddings):
+    """Get embeddings for a list of texts using the Jina AI API.
+
+    It requires an API key and a model name. The default model name is
+    "jina-embeddings-v2-base-en".
+    """
+
+    api_url: Any  #: :meta private:
+    session: Any  #: :meta private:
+    api_key: str
+    """Your API key for the Jina AI API."""
+    model_name: str = "jina-embeddings-v2-base-en"
+    """The name of the model to use for text embeddings.
+
+    Defaults to "jina-embeddings-v2-base-en".
+    """
+
+    def __init__(self, **kwargs):
+        """Initialize the JinaEmbeddings.
+
+        Args:
+            **kwargs: Field values passed to the base-class constructor,
+                such as ``api_key`` and, optionally, ``model_name``.
+
+        Raises:
+            ValueError: If the ``requests`` package is not installed.
+        """
+        super().__init__(**kwargs)
+        try:
+            import requests
+        except ImportError as exc:
+            raise ValueError(
+                "The requests python package is not installed. "
+                "Please install it with `pip install requests`"
+            ) from exc
+        self.api_url = "https://api.jina.ai/v1/embeddings"
+        # One session for all calls, so the Authorization header is set once.
+        self.session = requests.Session()
+        self.session.headers.update(
+            {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
+        )
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Get the embeddings for a list of texts.
+
+        Args:
+            texts (List[str]): A list of texts to get embeddings for.
+
+        Returns:
+            Embedded texts as List[List[float]], where each inner List[float]
+                corresponds to a single input text.
+
+        Raises:
+            RuntimeError: If the API response contains no "data" field.
+        """
+        # Nothing to embed: skip the network round trip.
+        if not texts:
+            return []
+
+        # Call Jina AI Embedding API
+        resp = self.session.post(  # type: ignore
+            self.api_url, json={"input": texts, "model": self.model_name}
+        ).json()
+        if "data" not in resp:
+            # Prefer the "detail" message, but fall back to the whole payload
+            # so a missing key cannot mask the failure behind a KeyError.
+            raise RuntimeError(resp.get("detail", resp))
+
+        # Sort resulting embeddings by index
+        sorted_embeddings = sorted(resp["data"], key=lambda e: e["index"])  # type: ignore
+
+        # Return just the embeddings
+        return [result["embedding"] for result in sorted_embeddings]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute the embedding of a single query text using the Jina AI API.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        return self.embed_documents([text])[0]