feat: APIServer supports embeddings (#1256)

This commit is contained in:
Fangyin Cheng
2024-03-05 20:21:37 +08:00
committed by GitHub
parent 5f3ee35804
commit 74ec8e52cd
9 changed files with 414 additions and 40 deletions

View File

@@ -0,0 +1,16 @@
from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory
from .embeddings import (
Embeddings,
HuggingFaceEmbeddings,
JinaEmbeddings,
OpenAPIEmbeddings,
)
# Public API of the embedding package. Python only honours the lowercase
# ``__all__`` name for ``from package import *``; the previous upper-case
# spelling was an ordinary variable with no effect on star-imports.
__all__ = [
    "OpenAPIEmbeddings",
    "Embeddings",
    "HuggingFaceEmbeddings",
    "JinaEmbeddings",
    "EmbeddingFactory",
    "DefaultEmbeddingFactory",
]

# Backward-compatible alias for any code that read the old upper-case name.
__ALL__ = __all__

View File

@@ -2,8 +2,10 @@ import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from pydantic import BaseModel, Extra, Field
from dbgpt._private.pydantic import BaseModel, Extra, Field
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -363,6 +365,29 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
return self.embed_documents([text])[0]
def _handle_request_result(res: "requests.Response") -> List[List[float]]:
    """Parse the result from an embeddings request.

    The JSON body is expected to hold a ``data`` list whose items each carry
    an ``index`` and an ``embedding`` entry.

    Args:
        res: The response from the request.

    Returns:
        List[List[float]]: The embeddings, ordered by each item's ``index``.

    Raises:
        RuntimeError: If the JSON body has no ``data`` entry. The message is
            the body's ``detail`` entry when present, otherwise the body
            itself.
        Exception: Whatever ``res.raise_for_status()`` raises for an
            unsuccessful HTTP status (``requests.HTTPError`` for a
            ``requests`` response).
    """
    res.raise_for_status()
    payload = res.json()

    if not isinstance(payload, dict):
        # A non-object body cannot contain "data"; report it as-is.
        raise RuntimeError(payload)
    if "data" not in payload:
        # Not every error body carries "detail"; fall back to the whole body
        # so the caller gets a RuntimeError rather than a KeyError.
        raise RuntimeError(payload.get("detail", payload))

    # Sort resulting embeddings by index so the output order matches the
    # order given by the "index" entries, then return just the vectors.
    ordered = sorted(payload["data"], key=lambda item: item["index"])
    return [item["embedding"] for item in ordered]
class JinaEmbeddings(BaseModel, Embeddings):
"""
This class is used to get embeddings for a list of texts using the Jina AI API.
@@ -406,17 +431,8 @@ class JinaEmbeddings(BaseModel, Embeddings):
# Call Jina AI Embedding API
resp = self.session.post( # type: ignore
self.api_url, json={"input": texts, "model": self.model_name}
).json()
if "data" not in resp:
raise RuntimeError(resp["detail"])
embeddings = resp["data"]
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
# Return just the embeddings
return [result["embedding"] for result in sorted_embeddings]
)
return _handle_request_result(res)
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
@@ -428,3 +444,158 @@ class JinaEmbeddings(BaseModel, Embeddings):
Embeddings for the text.
"""
return self.embed_documents([text])[0]
class OpenAPIEmbeddings(BaseModel, Embeddings):
    r"""This class is used to get embeddings for a list of texts using the API.

    This API is compatible with the OpenAI Embedding API.

    Examples:
        Using OpenAI's API:

        .. code-block:: python

            from dbgpt.rag.embedding import OpenAPIEmbeddings

            openai_embeddings = OpenAPIEmbeddings(
                api_url="https://api.openai.com/v1/embeddings",
                api_key="your_api_key",
                model_name="text-embedding-3-small",
            )
            texts = ["Hello, world!", "How are you?"]
            openai_embeddings.embed_documents(texts)

        Using DB-GPT APIServer's embedding API:
        To use the DB-GPT APIServer's embedding API, you should deploy DB-GPT according
        to the `Cluster Deploy
        <https://docs.dbgpt.site/docs/installation/model_service/cluster>`_.

        A simple example:

        1. Deploy Model Cluster with following command:

        .. code-block:: bash

            dbgpt start controller --port 8000

        2. Deploy Embedding Model Worker with following command:

        .. code-block:: bash

            dbgpt start worker --model_name text2vec \
            --model_path /app/models/text2vec-large-chinese \
            --worker_type text2vec \
            --port 8003 \
            --controller_addr http://127.0.0.1:8000

        3. Deploy API Server with following command:

        .. code-block:: bash

            dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
            --api_keys my_api_token --port 8100

        Now, you can use the API server to get embeddings:

        .. code-block:: python

            from dbgpt.rag.embedding import OpenAPIEmbeddings

            openai_embeddings = OpenAPIEmbeddings(
                api_url="http://localhost:8100/api/v1/embeddings",
                api_key="my_api_token",
                model_name="text2vec",
            )
            texts = ["Hello, world!", "How are you?"]
            openai_embeddings.embed_documents(texts)
    """

    api_url: str = Field(
        default="http://localhost:8100/api/v1/embeddings",
        description="The URL of the embeddings API.",
    )
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    model_name: str = Field(
        default="text2vec", description="The name of the model to use."
    )
    timeout: int = Field(
        default=60, description="The timeout for the request in seconds."
    )
    # Synchronous HTTP session; always (re)created in ``__init__``.
    session: Optional[requests.Session] = None

    class Config:
        """Configuration for this pydantic object."""

        # ``requests.Session`` is not a pydantic-native type.
        arbitrary_types_allowed = True

    def __init__(self, **kwargs):
        """Initialize the OpenAPIEmbeddings.

        Validates the fields through the pydantic base class, then creates a
        fresh ``requests.Session`` for the synchronous calls.

        Raises:
            ValueError: If the ``requests`` package cannot be imported.
        """
        super().__init__(**kwargs)
        try:
            import requests
        except ImportError as exc:
            raise ValueError(
                "The requests python package is not installed. "
                "Please install it with `pip install requests`"
            ) from exc
        self.session = requests.Session()
        # Adds nothing when no API key is configured, so the literal
        # "Bearer None" is never sent.
        self.session.headers.update(self._auth_headers())

    def _auth_headers(self) -> Dict[str, str]:
        """Return the Authorization header, or no headers without an API key."""
        if not self.api_key:
            return {}
        return {"Authorization": f"Bearer {self.api_key}"}

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        Args:
            texts (List[str]): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
            corresponds to a single input text.
        """
        # Call the OpenAI-compatible embedding API
        res = self.session.post(  # type: ignore
            self.api_url,
            json={"input": texts, "model": self.model_name},
            timeout=self.timeout,
        )
        return _handle_request_result(res)

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a OpenAPI embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Opens a new ``aiohttp`` client session for each call, using the
        configured timeout as the total request timeout.

        Args:
            texts: A list of texts to get embeddings for.

        Returns:
            List[List[float]]: Embedded texts as List[List[float]], where each inner
                List[float] corresponds to a single input text.

        Raises:
            RuntimeError: If the JSON body has no ``data`` entry.
        """
        request_timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(
            headers=self._auth_headers(), timeout=request_timeout
        ) as session:
            async with session.post(
                self.api_url, json={"input": texts, "model": self.model_name}
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()

        if not isinstance(data, dict):
            # A non-object body cannot contain "data"; report it as-is.
            raise RuntimeError(data)
        if "data" not in data:
            # Fall back to the whole body when there is no "detail" entry so
            # the caller gets a RuntimeError rather than a KeyError.
            raise RuntimeError(data.get("detail", data))

        # Sort resulting embeddings by index, then return just the vectors.
        ordered = sorted(data["data"], key=lambda item: item["index"])
        return [item["embedding"] for item in ordered]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]