mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat: APIServer supports embeddings (#1256)
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory
|
||||
from .embeddings import (
|
||||
Embeddings,
|
||||
HuggingFaceEmbeddings,
|
||||
JinaEmbeddings,
|
||||
OpenAPIEmbeddings,
|
||||
)
|
||||
|
||||
# Public API of the embedding package.
# NOTE: the name must be the lowercase dunder ``__all__``; Python ignores any
# other spelling (such as ``__ALL__``) when resolving ``from package import *``.
__all__ = [
    "OpenAPIEmbeddings",
    "Embeddings",
    "HuggingFaceEmbeddings",
    "JinaEmbeddings",
    "EmbeddingFactory",
    "DefaultEmbeddingFactory",
]
|
||||
|
@@ -2,8 +2,10 @@ import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Extra, Field
|
||||
|
||||
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
||||
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
|
||||
@@ -363,6 +365,29 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
|
||||
def _handle_request_result(res: "requests.Response") -> List[List[float]]:
    """Parse the result from an embeddings request.

    Args:
        res: The response from the request.

    Returns:
        List[List[float]]: The embeddings, ordered by the ``index`` field of
            each item in the response's ``data`` list.

    Raises:
        requests.HTTPError: Propagated from ``res.raise_for_status()`` when the
            response carries an HTTP error status.
        RuntimeError: If the response body has no ``data`` field. The message is
            the ``detail`` field when present, otherwise the ``error`` field,
            otherwise the whole decoded body.
    """
    res.raise_for_status()
    resp = res.json()
    if not isinstance(resp, dict) or "data" not in resp:
        # Error bodies are not uniform across servers: some report ``detail``,
        # OpenAI-style servers report ``error``. Never index a key that may be
        # missing, otherwise a KeyError would hide the real failure reason.
        reason = resp
        if isinstance(resp, dict):
            reason = resp.get("detail", resp.get("error", resp))
        raise RuntimeError(reason)
    # Sort resulting embeddings by index so the output matches the input order
    sorted_embeddings = sorted(resp["data"], key=lambda e: e["index"])  # type: ignore
    # Return just the embeddings
    return [result["embedding"] for result in sorted_embeddings]
|
||||
|
||||
|
||||
class JinaEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
This class is used to get embeddings for a list of texts using the Jina AI API.
|
||||
@@ -406,17 +431,8 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
# Call Jina AI Embedding API
|
||||
resp = self.session.post( # type: ignore
|
||||
self.api_url, json={"input": texts, "model": self.model_name}
|
||||
).json()
|
||||
if "data" not in resp:
|
||||
raise RuntimeError(resp["detail"])
|
||||
|
||||
embeddings = resp["data"]
|
||||
|
||||
# Sort resulting embeddings by index
|
||||
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
|
||||
|
||||
# Return just the embeddings
|
||||
return [result["embedding"] for result in sorted_embeddings]
|
||||
)
|
||||
return _handle_request_result(res)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
@@ -428,3 +444,158 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
|
||||
class OpenAPIEmbeddings(BaseModel, Embeddings):
    """This class is used to get embeddings for a list of texts using the API.

    This API is compatible with the OpenAI Embedding API.

    Examples:

        Using OpenAI's API:

        .. code-block:: python

            from dbgpt.rag.embedding import OpenAPIEmbeddings

            openai_embeddings = OpenAPIEmbeddings(
                api_url="https://api.openai.com/v1/embeddings",
                api_key="your_api_key",
                model_name="text-embedding-3-small",
            )
            texts = ["Hello, world!", "How are you?"]
            openai_embeddings.embed_documents(texts)

        Using DB-GPT APIServer's embedding API:
        To use the DB-GPT APIServer's embedding API, you should deploy DB-GPT
        according to the `Cluster Deploy
        <https://docs.dbgpt.site/docs/installation/model_service/cluster>`_.

        A simple example:

        1. Deploy Model Cluster with following command:

        .. code-block:: bash

            dbgpt start controller --port 8000

        2. Deploy Embedding Model Worker with following command:

        .. code-block:: bash

            dbgpt start worker --model_name text2vec \\
            --model_path /app/models/text2vec-large-chinese \\
            --worker_type text2vec \\
            --port 8003 \\
            --controller_addr http://127.0.0.1:8000

        3. Deploy API Server with following command:

        .. code-block:: bash

            dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \\
            --api_keys my_api_token --port 8100

        Now, you can use the API server to get embeddings:

        .. code-block:: python

            from dbgpt.rag.embedding import OpenAPIEmbeddings

            openai_embeddings = OpenAPIEmbeddings(
                api_url="http://localhost:8100/api/v1/embeddings",
                api_key="my_api_token",
                model_name="text2vec",
            )
            texts = ["Hello, world!", "How are you?"]
            openai_embeddings.embed_documents(texts)

    """

    api_url: str = Field(
        default="http://localhost:8100/api/v1/embeddings",
        description="The URL of the embeddings API.",
    )
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    model_name: str = Field(
        default="text2vec", description="The name of the model to use."
    )
    timeout: int = Field(
        default=60, description="The timeout for the request in seconds."
    )

    # HTTP session used by the synchronous methods; always (re)created in
    # ``__init__``, so the ``None`` default is only a placeholder.
    session: Optional[requests.Session] = None

    class Config:
        """Configuration for this pydantic object."""

        # ``requests.Session`` is not a pydantic-native type.
        arbitrary_types_allowed = True

    def __init__(self, **kwargs):
        """Initialize the OpenAPIEmbeddings.

        Args:
            **kwargs: Field values (``api_url``, ``api_key``, ``model_name``,
                ``timeout``) forwarded to the pydantic model constructor.

        Raises:
            ValueError: If the ``requests`` package is not installed.
        """
        super().__init__(**kwargs)
        try:
            import requests
        except ImportError as exc:
            raise ValueError(
                "The requests python package is not installed. "
                "Please install it with `pip install requests`"
            ) from exc
        self.session = requests.Session()
        self.session.headers.update(self._auth_headers())

    def _auth_headers(self) -> Dict[str, str]:
        """Build the authorization headers for a request.

        Returns:
            Dict[str, str]: A bearer ``Authorization`` header when ``api_key`` is
                set, otherwise an empty dict so that no bogus ``Bearer None``
                token is sent.
        """
        if self.api_key is None:
            return {}
        return {"Authorization": f"Bearer {self.api_key}"}

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
                corresponds to a single input text.
        """
        # Call OpenAI Embedding API
        res = self.session.post(  # type: ignore
            self.api_url,
            json={"input": texts, "model": self.model_name},
            timeout=self.timeout,
        )
        return _handle_request_result(res)

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a OpenAPI embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Args:
            texts: A list of texts to get embeddings for.

        Returns:
            List[List[float]]: Embedded texts as List[List[float]], where each inner
                List[float] corresponds to a single input text.

        Raises:
            RuntimeError: If the response body has no ``data`` field.
        """
        # A fresh client session is opened per call; it is closed when the
        # ``async with`` block exits.
        async with aiohttp.ClientSession(
            headers=self._auth_headers(),
            timeout=aiohttp.ClientTimeout(total=self.timeout),
        ) as session:
            async with session.post(
                self.api_url, json={"input": texts, "model": self.model_name}
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()
                if not isinstance(data, dict) or "data" not in data:
                    # Do not index a possibly missing key: a KeyError here would
                    # hide the actual error reported by the server.
                    reason = data
                    if isinstance(data, dict):
                        reason = data.get("detail", data.get("error", data))
                    raise RuntimeError(reason)
                # Sort by index so the output order matches the input order.
                sorted_embeddings = sorted(data["data"], key=lambda e: e["index"])
                return [result["embedding"] for result in sorted_embeddings]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]
|
||||
|
Reference in New Issue
Block a user