feat: APIServer supports embeddings (#1256)

Fangyin Cheng 2024-03-05 20:21:37 +08:00 committed by GitHub
parent 5f3ee35804
commit 74ec8e52cd
9 changed files with 414 additions and 40 deletions

View File

@@ -23,6 +23,8 @@ from fastchat.protocol.openai_api_protocol import (
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
EmbeddingsRequest,
EmbeddingsResponse,
ModelCard,
ModelList,
ModelPermission,
@@ -51,6 +53,7 @@ class APIServerException(Exception):
class APISettings(BaseModel):
api_keys: Optional[List[str]] = None
embedding_batch_size: int = 4
api_settings = APISettings()
@@ -181,27 +184,29 @@ class APIServer(BaseComponent):
return controller
async def get_model_instances_or_raise(
self, model_name: str
self, model_name: str, worker_type: str = "llm"
) -> List[ModelInstance]:
"""Get healthy model instances with request model name
Args:
model_name (str): Model name
worker_type (str, optional): Worker type. Defaults to "llm".
Raises:
APIServerException: If no healthy model instance can be found for the requested model name.
"""
registry = self.get_model_registry()
registry_model_name = f"{model_name}@llm"
suffix = f"@{worker_type}"
registry_model_name = f"{model_name}{suffix}"
model_instances = await registry.get_all_instances(
registry_model_name, healthy_only=True
)
if not model_instances:
all_instances = await registry.get_all_model_instances(healthy_only=True)
models = [
ins.model_name.split("@llm")[0]
ins.model_name.split(suffix)[0]
for ins in all_instances
if ins.model_name.endswith("@llm")
if ins.model_name.endswith(suffix)
]
if models:
models = "&&".join(models)
@@ -336,6 +341,25 @@ class APIServer(BaseComponent):
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
async def embeddings_generate(
self, model: str, texts: List[str]
) -> List[List[float]]:
"""Generate embeddings
Args:
model (str): Model name
texts (List[str]): Texts to embed
Returns:
List[List[float]]: The embeddings of texts
"""
worker_manager: WorkerManager = self.get_worker_manager()
params = {
"input": texts,
"model": model,
}
return await worker_manager.embeddings(params)
def get_api_server() -> APIServer:
api_server = global_system_app.get_component(
@@ -389,6 +413,40 @@ async def create_chat_completion(
return await api_server.chat_completion_generate(request.model, params, request.n)
@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
):
await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
texts = request.input
if isinstance(texts, str):
texts = [texts]
batch_size = api_settings.embedding_batch_size
batches = [
texts[i : i + batch_size]
for i in range(0, len(texts), batch_size)
]
data = []
async_tasks = []
for batch in batches:
async_tasks.append(api_server.embeddings_generate(request.model, batch))
# Request all embeddings in parallel
batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
for num_batch, embeddings in enumerate(batch_embeddings):
data += [
{
"object": "embedding",
"embedding": emb,
"index": num_batch * batch_size + i,
}
for i, emb in enumerate(embeddings)
]
return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
exclude_none=True
)
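The gather above issues one worker request per batch, and `num_batch * batch_size + i` maps each embedding back to its position in the original `input` list. A minimal, self-contained sketch of the same index bookkeeping (the `fake_embed` coroutine is a hypothetical stand-in for `embeddings_generate`):

```python
import asyncio
from typing import List

BATCH_SIZE = 4  # mirrors the default api_settings.embedding_batch_size


async def fake_embed(batch: List[str]) -> List[List[float]]:
    # Hypothetical stand-in for APIServer.embeddings_generate: one vector per text.
    return [[float(len(text))] for text in batch]


async def main():
    texts = [f"text-{i}" for i in range(10)]
    batches = [texts[i : i + BATCH_SIZE] for i in range(0, len(texts), BATCH_SIZE)]
    results = await asyncio.gather(*(fake_embed(b) for b in batches))
    data = [
        {"object": "embedding", "embedding": emb, "index": n * BATCH_SIZE + i}
        for n, embs in enumerate(results)
        for i, emb in enumerate(embs)
    ]
    # Indices 0..9 come back in input order, regardless of batch boundaries.
    assert [d["index"] for d in data] == list(range(len(texts)))


asyncio.run(main())
```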
def _initialize_all(controller_addr: str, system_app: SystemApp):
from dbgpt.model.cluster.controller.controller import ModelRegistryClient
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
@@ -427,6 +485,7 @@ def initialize_apiserver(
host: str = None,
port: int = None,
api_keys: List[str] = None,
embedding_batch_size: Optional[int] = None,
):
global global_system_app
global api_settings
@@ -434,13 +493,6 @@ def initialize_apiserver(
if not app:
embedded_mod = False
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
if not system_app:
system_app = SystemApp(app)
@@ -449,6 +501,9 @@ def initialize_apiserver(
if api_keys:
api_settings.api_keys = api_keys
if embedding_batch_size:
api_settings.embedding_batch_size = embedding_batch_size
app.include_router(router, prefix="/api", tags=["APIServer"])
@app.exception_handler(APIServerException)
@@ -464,7 +519,15 @@ def initialize_apiserver(
if not embedded_mod:
import uvicorn
uvicorn.run(app, host=host, port=port, log_level="info")
# https://github.com/encode/starlette/issues/617
cors_app = CORSMiddleware(
app=app,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
uvicorn.run(cors_app, host=host, port=port, log_level="info")
def run_apiserver():
@@ -488,6 +551,7 @@ def run_apiserver():
host=apiserver_params.host,
port=apiserver_params.port,
api_keys=api_keys,
embedding_batch_size=apiserver_params.embedding_batch_size,
)

View File

@@ -113,6 +113,9 @@ class ModelAPIServerParameters(BaseParameters):
default=None,
metadata={"help": "Optional list of comma separated API keys"},
)
embedding_batch_size: Optional[int] = field(
default=None, metadata={"help": "Embedding batch size"}
)
log_level: Optional[str] = field(
default=None,

View File

@@ -0,0 +1,16 @@
from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory
from .embeddings import (
Embeddings,
HuggingFaceEmbeddings,
JinaEmbeddings,
OpenAPIEmbeddings,
)
__all__ = [
"OpenAPIEmbeddings",
"Embeddings",
"HuggingFaceEmbeddings",
"JinaEmbeddings",
"EmbeddingFactory",
"DefaultEmbeddingFactory",
]

View File

@@ -2,8 +2,10 @@ import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from pydantic import BaseModel, Extra, Field
from dbgpt._private.pydantic import BaseModel, Extra, Field
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -363,6 +365,29 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
return self.embed_documents([text])[0]
def _handle_request_result(res: requests.Response) -> List[List[float]]:
"""Parse the result from a request.
Args:
res: The response from the request.
Returns:
List[List[float]]: The embeddings.
Raises:
RuntimeError: If the response is not successful.
"""
res.raise_for_status()
resp = res.json()
if "data" not in resp:
raise RuntimeError(resp["detail"])
embeddings = resp["data"]
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
# Return just the embeddings
return [result["embedding"] for result in sorted_embeddings]
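A tiny check of that re-ordering, using a hypothetical stub in place of `requests.Response` (the helper only touches `raise_for_status()` and `json()`; the import path assumes this commit's module layout):

```python
from typing import Any, Dict

from dbgpt.rag.embedding.embeddings import _handle_request_result


class _FakeResponse:
    """Hypothetical stand-in for requests.Response."""

    def __init__(self, payload: Dict[str, Any]):
        self._payload = payload

    def raise_for_status(self) -> None:
        pass  # pretend the request returned HTTP 200

    def json(self) -> Dict[str, Any]:
        return self._payload


# Entries arrive out of order; the helper sorts them back by "index".
res = _FakeResponse(
    {"data": [{"index": 1, "embedding": [0.2]}, {"index": 0, "embedding": [0.1]}]}
)
assert _handle_request_result(res) == [[0.1], [0.2]]
```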
class JinaEmbeddings(BaseModel, Embeddings):
"""
This class is used to get embeddings for a list of texts using the Jina AI API.
@@ -406,17 +431,8 @@ class JinaEmbeddings(BaseModel, Embeddings):
# Call Jina AI Embedding API
resp = self.session.post( # type: ignore
self.api_url, json={"input": texts, "model": self.model_name}
).json()
if "data" not in resp:
raise RuntimeError(resp["detail"])
embeddings = resp["data"]
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
# Return just the embeddings
return [result["embedding"] for result in sorted_embeddings]
)
return _handle_request_result(resp)
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
@@ -428,3 +444,158 @@ class JinaEmbeddings(BaseModel, Embeddings):
Embeddings for the text.
"""
return self.embed_documents([text])[0]
class OpenAPIEmbeddings(BaseModel, Embeddings):
"""This class is used to get embeddings for a list of texts using the API.
This API is compatible with the OpenAI Embedding API.
Examples:
Using OpenAI's API:
.. code-block:: python
from dbgpt.rag.embedding import OpenAPIEmbeddings
openai_embeddings = OpenAPIEmbeddings(
api_url="https://api.openai.com/v1/embeddings",
api_key="your_api_key",
model_name="text-embedding-3-small",
)
texts = ["Hello, world!", "How are you?"]
openai_embeddings.embed_documents(texts)
Using DB-GPT APIServer's embedding API:
To use the DB-GPT APIServer's embedding API, you should deploy DB-GPT according
to the `Cluster Deploy
<https://docs.dbgpt.site/docs/installation/model_service/cluster>`_.
A simple example:
1. Deploy Model Cluster with the following command:
.. code-block:: bash
dbgpt start controller --port 8000
2. Deploy Embedding Model Worker with the following command:
.. code-block:: bash
dbgpt start worker --model_name text2vec \
--model_path /app/models/text2vec-large-chinese \
--worker_type text2vec \
--port 8003 \
--controller_addr http://127.0.0.1:8000
3. Deploy API Server with the following command:
.. code-block:: bash
dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
--api_keys my_api_token --port 8100
Now, you can use the API server to get embeddings:
.. code-block:: python
from dbgpt.rag.embedding import OpenAPIEmbeddings
openai_embeddings = OpenAPIEmbeddings(
api_url="http://localhost:8100/api/v1/embeddings",
api_key="my_api_token",
model_name="text2vec",
)
texts = ["Hello, world!", "How are you?"]
openai_embeddings.embed_documents(texts)
"""
api_url: str = Field(
default="http://localhost:8100/api/v1/embeddings",
description="The URL of the embeddings API.",
)
api_key: Optional[str] = Field(
default=None, description="The API key for the embeddings API."
)
model_name: str = Field(
default="text2vec", description="The name of the model to use."
)
timeout: int = Field(
default=60, description="The timeout for the request in seconds."
)
session: Optional[requests.Session] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def __init__(self, **kwargs):
"""Initialize the OpenAPIEmbeddings."""
super().__init__(**kwargs)
try:
import requests
except ImportError:
raise ValueError(
"The requests python package is not installed. "
"Please install it with `pip install requests`"
)
self.session = requests.Session()
self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
Args:
texts (List[str]): A list of texts to get embeddings for.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
corresponds to a single input text.
"""
# Call OpenAI Embedding API
res = self.session.post( # type: ignore
self.api_url,
json={"input": texts, "model": self.model_name},
timeout=self.timeout,
)
return _handle_request_result(res)
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a OpenAPI embedding model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.
Args:
texts: A list of texts to get embeddings for.
Returns:
List[List[float]]: Embedded texts as List[List[float]], where each inner
List[float] corresponds to a single input text.
"""
headers = {"Authorization": f"Bearer {self.api_key}"}
async with aiohttp.ClientSession(
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:
async with session.post(
self.api_url, json={"input": texts, "model": self.model_name}
) as resp:
resp.raise_for_status()
data = await resp.json()
if "data" not in data:
raise RuntimeError(data["detail"])
embeddings = data["data"]
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
return [result["embedding"] for result in sorted_embeddings]
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
embeddings = await self.aembed_documents([text])
return embeddings[0]
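A short usage sketch of the async path, assuming a DB-GPT APIServer is reachable at the default URL (the key and model name below are illustrative):

```python
import asyncio

from dbgpt.rag.embedding import OpenAPIEmbeddings


async def main():
    embeddings = OpenAPIEmbeddings(
        api_url="http://localhost:8100/api/v1/embeddings",
        api_key="my_api_token",
        model_name="text2vec",
    )
    vectors = await embeddings.aembed_documents(["Hello, world!", "How are you?"])
    print(f"{len(vectors)} embeddings, dimension {len(vectors[0])}")


asyncio.run(main())
```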

View File

@@ -2,6 +2,7 @@ import logging
import math
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Optional
from pydantic import BaseModel, Field
@@ -24,11 +25,13 @@ class VectorStoreConfig(BaseModel):
)
password: Optional[str] = Field(
default=None,
description="The password of vector store, if not set, will use the default password.",
description="The password of vector store, if not set, will use the default "
"password.",
)
embedding_fn: Optional[Any] = Field(
default=None,
description="The embedding function of vector store, if not set, will use the default embedding function.",
description="The embedding function of vector store, if not set, will use the "
"default embedding function.",
)
max_chunks_once_load: int = Field(
default=10,
@@ -36,6 +39,11 @@ class VectorStoreConfig(BaseModel):
"large, you can set this value to a larger number to speed up the loading "
"process. Default is 10.",
)
max_threads: int = Field(
default=1,
description="The max number of threads to use. Default is 1. If you set this "
"bigger than 1, please make sure your vector store is thread-safe.",
)
class VectorStoreBase(ABC):
@@ -52,12 +60,13 @@ class VectorStoreBase(ABC):
pass
def load_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
) -> List[str]:
"""load document in vector database with limit.
Args:
chunks: document chunks.
max_chunks_once_load: Max number of chunks to load at once.
max_threads: Max number of threads to use.
Returns:
List[str]: Chunk ids of the loaded documents.
"""
# Group the chunks into chunks of size max_chunks
@@ -65,14 +74,22 @@ class VectorStoreBase(ABC):
chunks[i : i + max_chunks_once_load]
for i in range(0, len(chunks), max_chunks_once_load)
]
logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
logger.info(
f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
f"{max_threads} threads."
)
ids = []
loaded_cnt = 0
start_time = time.time()
for chunk_group in chunk_groups:
ids.extend(self.load_document(chunk_group))
loaded_cnt += len(chunk_group)
logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
with ThreadPoolExecutor(max_workers=max_threads) as executor:
tasks = []
for chunk_group in chunk_groups:
tasks.append(executor.submit(self.load_document, chunk_group))
for future in tasks:
success_ids = future.result()
ids.extend(success_ids)
loaded_cnt += len(success_ids)
logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
logger.info(
f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
)
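With `max_threads > 1`, the chunk groups now load concurrently; iterating the futures in submit order keeps the returned ids aligned with the input. A toy sketch of the same pattern (the hypothetical `slow_load` stands in for `load_document`):

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List


def slow_load(chunk_group: List[str]) -> List[str]:
    # Hypothetical stand-in for VectorStoreBase.load_document: one id per chunk.
    time.sleep(0.1)
    return [f"id-{chunk}" for chunk in chunk_group]


chunk_groups = [["a", "b"], ["c", "d"], ["e"]]
ids: List[str] = []
start_time = time.time()
with ThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(slow_load, group) for group in chunk_groups]
    for future in futures:  # submit order, so ids stay deterministic
        ids.extend(future.result())
print(ids, f"{time.time() - start_time:.2f}s")  # ~0.1s rather than ~0.3s serial
```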

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Callable, List, Optional
from typing import Any, List, Optional
from dbgpt.rag.chunk import Chunk
from dbgpt.storage import vector_store
@@ -65,7 +65,9 @@ class VectorStoreConnector:
Return chunk ids.
"""
return self.client.load_document_with_limit(
chunks, self._vector_store_config.max_chunks_once_load
chunks,
self._vector_store_config.max_chunks_once_load,
self._vector_store_config.max_threads,
)
def similar_search(self, doc: str, topk: int) -> List[Chunk]:

View File

@@ -10,7 +10,7 @@ The call of multi-model services is compatible with the OpenAI interface, and th
## Start apiserver
After deploying the model service, you need to start the API Server. By default, the API Server listens on port `8100`.
```python
```bash
dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY
```
@@ -25,7 +25,7 @@ After the apiserver is started, the service call can be verified. First, let's l
:::tip
List models
:::
```python
```bash
curl http://127.0.0.1:8100/api/v1/models \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json"
@@ -34,17 +34,31 @@ curl http://127.0.0.1:8100/api/v1/models \
:::tip
Chat
:::
```python
```bash
curl http://127.0.0.1:8100/api/v1/chat/completions \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json" \
-d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
```
:::tip
Embedding
:::
```bash
curl http://127.0.0.1:8100/api/v1/embeddings \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json" \
-d '{
"model": "text2vec",
"input": "Hello world!"
}'
```
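The same request from Python, as a minimal sketch with `requests` (URL, key, and model mirror the curl call above):

```python
import requests

res = requests.post(
    "http://127.0.0.1:8100/api/v1/embeddings",
    headers={"Authorization": "Bearer EMPTY"},
    json={"model": "text2vec", "input": "Hello world!"},
    timeout=60,
)
res.raise_for_status()
print(res.json()["data"][0]["embedding"][:4])  # first few dimensions
```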
## Verify via OpenAI SDK
```python
```python
import openai
openai.api_key = "EMPTY"
openai.api_base = "http://127.0.0.1:8100/api/v1"

View File

@@ -1,7 +1,7 @@
import asyncio
import os
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.knowledge.factory import KnowledgeFactory
@@ -37,7 +37,7 @@ def _create_vector_connector():
async def main():
file_path = "docs/docs/awel.md"
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
knowledge = KnowledgeFactory.from_file_path(file_path)
vector_connector = _create_vector_connector()
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")

View File

@@ -0,0 +1,87 @@
"""A RAG example using the OpenAPIEmbeddings.
Example:
Test with `OpenAI embeddings
<https://platform.openai.com/docs/api-reference/embeddings/create>`_.
.. code-block:: shell
export API_SERVER_BASE_URL=${OPENAI_API_BASE:-"https://api.openai.com/v1"}
export API_SERVER_API_KEY="${OPENAI_API_KEY}"
export API_SERVER_EMBEDDINGS_MODEL="text-embedding-ada-002"
python examples/rag/rag_embedding_api_example.py
Test with DB-GPT `API Server
<https://docs.dbgpt.site/docs/installation/advanced_usage/OpenAI_SDK_call#start-apiserver>`_.
.. code-block:: shell
export API_SERVER_BASE_URL="http://localhost:8100/api/v1"
export API_SERVER_API_KEY="your_api_key"
export API_SERVER_EMBEDDINGS_MODEL="text2vec"
python examples/rag/rag_embedding_api_example.py
"""
import asyncio
import os
from typing import Optional
from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding import OpenAPIEmbeddings
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
def _create_embeddings(
api_url: Optional[str] = None, api_key: Optional[str] = None, model_name: Optional[str] = None
) -> OpenAPIEmbeddings:
if not api_url:
api_server_base_url = os.getenv(
"API_SERVER_BASE_URL", "http://localhost:8100/api/v1/"
)
api_url = f"{api_server_base_url}/embeddings"
if not api_key:
api_key = os.getenv("API_SERVER_API_KEY")
if not model_name:
model_name = os.getenv("API_SERVER_EMBEDDINGS_MODEL", "text2vec")
return OpenAPIEmbeddings(api_url=api_url, api_key=api_key, model_name=model_name)
def _create_vector_connector():
"""Create vector connector."""
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="example_embedding_api_vector_store_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=_create_embeddings(),
)
async def main():
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
knowledge = KnowledgeFactory.from_file_path(file_path)
vector_connector = _create_vector_connector()
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
# get embedding assembler
assembler = EmbeddingAssembler.load_from_knowledge(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
vector_store_connector=vector_connector,
)
assembler.persist()
# get embeddings retriever
retriever = assembler.as_retriever(3)
chunks = await retriever.aretrieve_with_scores("what does awel talk about", 0.3)
print(f"embedding rag example results: {chunks}")
if __name__ == "__main__":
asyncio.run(main())