Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-04 01:50:08 +00:00)

feat: APIServer supports embeddings (#1256)

This commit is contained in:
    parent 5f3ee35804
    commit 74ec8e52cd
@@ -23,6 +23,8 @@ from fastchat.protocol.openai_api_protocol import (
     ChatCompletionStreamResponse,
     ChatMessage,
     DeltaMessage,
+    EmbeddingsRequest,
+    EmbeddingsResponse,
     ModelCard,
     ModelList,
     ModelPermission,
@@ -51,6 +53,7 @@ class APIServerException(Exception):


 class APISettings(BaseModel):
     api_keys: Optional[List[str]] = None
+    embedding_batch_size: int = 4


 api_settings = APISettings()
@@ -181,27 +184,29 @@ class APIServer(BaseComponent):
         return controller

     async def get_model_instances_or_raise(
-        self, model_name: str
+        self, model_name: str, worker_type: str = "llm"
     ) -> List[ModelInstance]:
         """Get healthy model instances with request model name

         Args:
             model_name (str): Model name
+            worker_type (str, optional): Worker type. Defaults to "llm".

         Raises:
             APIServerException: If can't get healthy model instances with request model name
         """
         registry = self.get_model_registry()
-        registry_model_name = f"{model_name}@llm"
+        suffix = f"@{worker_type}"
+        registry_model_name = f"{model_name}{suffix}"
         model_instances = await registry.get_all_instances(
             registry_model_name, healthy_only=True
         )
         if not model_instances:
             all_instances = await registry.get_all_model_instances(healthy_only=True)
             models = [
-                ins.model_name.split("@llm")[0]
+                ins.model_name.split(suffix)[0]
                 for ins in all_instances
-                if ins.model_name.endswith("@llm")
+                if ins.model_name.endswith(suffix)
             ]
             if models:
                 models = "&&".join(models)
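The lookup above is purely suffix-based: a worker registers under `{model_name}@{worker_type}`, so the same registry can hold `vicuna-13b-v1.5@llm` and `text2vec@text2vec` side by side. A minimal sketch of that naming convention (the helper below is illustrative, not part of this diff):

```python
def registry_name(model_name: str, worker_type: str = "llm") -> str:
    # Mirrors the suffix logic in get_model_instances_or_raise above.
    return f"{model_name}@{worker_type}"


assert registry_name("vicuna-13b-v1.5") == "vicuna-13b-v1.5@llm"
assert registry_name("text2vec", worker_type="text2vec") == "text2vec@text2vec"
```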
@@ -336,6 +341,25 @@ class APIServer(BaseComponent):

         return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)

+    async def embeddings_generate(
+        self, model: str, texts: List[str]
+    ) -> List[List[float]]:
+        """Generate embeddings
+
+        Args:
+            model (str): Model name
+            texts (List[str]): Texts to embed
+
+        Returns:
+            List[List[float]]: The embeddings of texts
+        """
+        worker_manager: WorkerManager = self.get_worker_manager()
+        params = {
+            "input": texts,
+            "model": model,
+        }
+        return await worker_manager.embeddings(params)
+

 def get_api_server() -> APIServer:
     api_server = global_system_app.get_component(
@@ -389,6 +413,40 @@ async def create_chat_completion(
     return await api_server.chat_completion_generate(request.model, params, request.n)


+@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
+async def create_embeddings(
+    request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
+):
+    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+    texts = request.input
+    if isinstance(texts, str):
+        texts = [texts]
+    batch_size = api_settings.embedding_batch_size
+    batches = [
+        texts[i : min(i + batch_size, len(texts))]
+        for i in range(0, len(texts), batch_size)
+    ]
+    data = []
+    async_tasks = []
+    for num_batch, batch in enumerate(batches):
+        async_tasks.append(api_server.embeddings_generate(request.model, batch))
+
+    # Request all embeddings in parallel
+    batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
+    for num_batch, embeddings in enumerate(batch_embeddings):
+        data += [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": num_batch * batch_size + i,
+            }
+            for i, emb in enumerate(embeddings)
+        ]
+    return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
+        exclude_none=True
+    )
+
+
 def _initialize_all(controller_addr: str, system_app: SystemApp):
     from dbgpt.model.cluster.controller.controller import ModelRegistryClient
     from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
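The new endpoint accepts either a single string or a list, splits the list into `embedding_batch_size`-sized batches, embeds the batches concurrently with `asyncio.gather`, and reassembles the results with globally increasing `index` values. A minimal client-side sketch (host, port, and the `EMPTY` key follow the docs changes later in this commit; adjust for your deployment):

```python
# A minimal sketch: call the new /v1/embeddings endpoint with `requests`.
# Assumes an APIServer on 127.0.0.1:8100 started with `--api_keys EMPTY` and a
# text2vec worker registered, matching the docs updated later in this commit.
import requests

resp = requests.post(
    "http://127.0.0.1:8100/api/v1/embeddings",
    headers={"Authorization": "Bearer EMPTY"},
    json={"model": "text2vec", "input": ["Hello, world!", "How are you?"]},
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
# One {"object", "embedding", "index"} entry per input text, ordered by index.
vectors = [item["embedding"] for item in body["data"]]
print(len(vectors), len(vectors[0]))
```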
@@ -427,6 +485,7 @@ def initialize_apiserver(
     host: str = None,
     port: int = None,
     api_keys: List[str] = None,
+    embedding_batch_size: Optional[int] = None,
 ):
     global global_system_app
     global api_settings
@@ -434,13 +493,6 @@ def initialize_apiserver(
     if not app:
         embedded_mod = False
         app = FastAPI()
-        app.add_middleware(
-            CORSMiddleware,
-            allow_origins=["*"],
-            allow_credentials=True,
-            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
-            allow_headers=["*"],
-        )

     if not system_app:
         system_app = SystemApp(app)
@@ -449,6 +501,9 @@ def initialize_apiserver(
     if api_keys:
         api_settings.api_keys = api_keys

+    if embedding_batch_size:
+        api_settings.embedding_batch_size = embedding_batch_size
+
     app.include_router(router, prefix="/api", tags=["APIServer"])

     @app.exception_handler(APIServerException)
@@ -464,7 +519,15 @@ def initialize_apiserver(
     if not embedded_mod:
         import uvicorn

-        uvicorn.run(app, host=host, port=port, log_level="info")
+        # https://github.com/encode/starlette/issues/617
+        cors_app = CORSMiddleware(
+            app=app,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+            allow_headers=["*"],
+        )
+        uvicorn.run(cors_app, host=host, port=port, log_level="info")


 def run_apiserver():
@@ -488,6 +551,7 @@ def run_apiserver():
         host=apiserver_params.host,
         port=apiserver_params.port,
         api_keys=api_keys,
+        embedding_batch_size=apiserver_params.embedding_batch_size,
     )

@@ -113,6 +113,9 @@ class ModelAPIServerParameters(BaseParameters):
         default=None,
         metadata={"help": "Optional list of comma separated API keys"},
     )
+    embedding_batch_size: Optional[int] = field(
+        default=None, metadata={"help": "Embedding batch size"}
+    )

     log_level: Optional[str] = field(
         default=None,
@@ -0,0 +1,16 @@
+from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory
+from .embeddings import (
+    Embeddings,
+    HuggingFaceEmbeddings,
+    JinaEmbeddings,
+    OpenAPIEmbeddings,
+)
+
+__all__ = [
+    "OpenAPIEmbeddings",
+    "Embeddings",
+    "HuggingFaceEmbeddings",
+    "JinaEmbeddings",
+    "EmbeddingFactory",
+    "DefaultEmbeddingFactory",
+]
@@ -2,8 +2,10 @@ import asyncio
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional

+import aiohttp
 import requests
-from pydantic import BaseModel, Extra, Field
+
+from dbgpt._private.pydantic import BaseModel, Extra, Field

 DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
 DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -363,6 +365,29 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
         return self.embed_documents([text])[0]


+def _handle_request_result(res: requests.Response) -> List[List[float]]:
+    """Parse the result from a request.
+
+    Args:
+        res: The response from the request.
+
+    Returns:
+        List[List[float]]: The embeddings.
+
+    Raises:
+        RuntimeError: If the response is not successful.
+    """
+    res.raise_for_status()
+    resp = res.json()
+    if "data" not in resp:
+        raise RuntimeError(resp["detail"])
+    embeddings = resp["data"]
+    # Sort resulting embeddings by index
+    sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+    # Return just the embeddings
+    return [result["embedding"] for result in sorted_embeddings]
+
+
 class JinaEmbeddings(BaseModel, Embeddings):
     """
     This class is used to get embeddings for a list of texts using the Jina AI API.
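`_handle_request_result` centralizes response parsing for both the Jina client below and the new `OpenAPIEmbeddings` client: raise on HTTP errors, surface the server's `detail` message when no `data` field is present, and re-order the returned embeddings by their `index` field. A small illustration of the re-ordering contract on a hand-built payload:

```python
# Illustrative check of the helper's sorting contract: entries may arrive out
# of order, keyed by "index"; the helper returns embeddings in input order.
payload = {
    "data": [
        {"object": "embedding", "index": 1, "embedding": [0.3, 0.4]},
        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2]},
    ]
}
sorted_embeddings = sorted(payload["data"], key=lambda e: e["index"])
assert [e["embedding"] for e in sorted_embeddings] == [[0.1, 0.2], [0.3, 0.4]]
```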
@@ -406,17 +431,8 @@ class JinaEmbeddings(BaseModel, Embeddings):
         # Call Jina AI Embedding API
         resp = self.session.post(  # type: ignore
             self.api_url, json={"input": texts, "model": self.model_name}
-        ).json()
-        if "data" not in resp:
-            raise RuntimeError(resp["detail"])
-
-        embeddings = resp["data"]
-
-        # Sort resulting embeddings by index
-        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
-
-        # Return just the embeddings
-        return [result["embedding"] for result in sorted_embeddings]
+        )
+        return _handle_request_result(resp)

     def embed_query(self, text: str) -> List[float]:
         """Compute query embeddings using a HuggingFace transformer model.
@@ -428,3 +444,158 @@ class JinaEmbeddings(BaseModel, Embeddings):
             Embeddings for the text.
         """
         return self.embed_documents([text])[0]
+
+
+class OpenAPIEmbeddings(BaseModel, Embeddings):
+    """This class is used to get embeddings for a list of texts using the API.
+
+    This API is compatible with the OpenAI Embedding API.
+
+    Examples:
+
+        Using OpenAI's API:
+        .. code-block:: python
+
+            from dbgpt.rag.embedding import OpenAPIEmbeddings
+
+            openai_embeddings = OpenAPIEmbeddings(
+                api_url="https://api.openai.com/v1/embeddings",
+                api_key="your_api_key",
+                model_name="text-embedding-3-small",
+            )
+            texts = ["Hello, world!", "How are you?"]
+            openai_embeddings.embed_documents(texts)
+
+        Using DB-GPT APIServer's embedding API:
+        To use the DB-GPT APIServer's embedding API, you should deploy DB-GPT according
+        to the `Cluster Deploy
+        <https://docs.dbgpt.site/docs/installation/model_service/cluster>`_.
+
+        A simple example:
+        1. Deploy Model Cluster with following command:
+        .. code-block:: bash
+
+            dbgpt start controller --port 8000
+
+        2. Deploy Embedding Model Worker with following command:
+        .. code-block:: bash
+
+            dbgpt start worker --model_name text2vec \
+                --model_path /app/models/text2vec-large-chinese \
+                --worker_type text2vec \
+                --port 8003 \
+                --controller_addr http://127.0.0.1:8000
+
+        3. Deploy API Server with following command:
+        .. code-block:: bash
+
+            dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
+                --api_keys my_api_token --port 8100
+
+        Now, you can use the API server to get embeddings:
+        .. code-block:: python
+
+            from dbgpt.rag.embedding import OpenAPIEmbeddings
+
+            openai_embeddings = OpenAPIEmbeddings(
+                api_url="http://localhost:8100/api/v1/embeddings",
+                api_key="my_api_token",
+                model_name="text2vec",
+            )
+            texts = ["Hello, world!", "How are you?"]
+            openai_embeddings.embed_documents(texts)
+    """
+
+    api_url: str = Field(
+        default="http://localhost:8100/api/v1/embeddings",
+        description="The URL of the embeddings API.",
+    )
+    api_key: Optional[str] = Field(
+        default=None, description="The API key for the embeddings API."
+    )
+    model_name: str = Field(
+        default="text2vec", description="The name of the model to use."
+    )
+    timeout: int = Field(
+        default=60, description="The timeout for the request in seconds."
+    )
+
+    session: requests.Session = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def __init__(self, **kwargs):
+        """Initialize the OpenAPIEmbeddings."""
+        super().__init__(**kwargs)
+        try:
+            import requests
+        except ImportError:
+            raise ValueError(
+                "The requests python package is not installed. "
+                "Please install it with `pip install requests`"
+            )
+        self.session = requests.Session()
+        self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Get the embeddings for a list of texts.
+
+        Args:
+            texts (Documents): A list of texts to get embeddings for.
+
+        Returns:
+            Embedded texts as List[List[float]], where each inner List[float]
+            corresponds to a single input text.
+        """
+        # Call OpenAI Embedding API
+        res = self.session.post(  # type: ignore
+            self.api_url,
+            json={"input": texts, "model": self.model_name},
+            timeout=self.timeout,
+        )
+        return _handle_request_result(res)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using an OpenAPI embedding model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        return self.embed_documents([text])[0]
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronous Embed search docs.
+
+        Args:
+            texts: A list of texts to get embeddings for.
+
+        Returns:
+            List[List[float]]: Embedded texts as List[List[float]], where each inner
+                List[float] corresponds to a single input text.
+        """
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        async with aiohttp.ClientSession(
+            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
+        ) as session:
+            async with session.post(
+                self.api_url, json={"input": texts, "model": self.model_name}
+            ) as resp:
+                resp.raise_for_status()
+                data = await resp.json()
+                if "data" not in data:
+                    raise RuntimeError(data["detail"])
+                embeddings = data["data"]
+                sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
+                return [result["embedding"] for result in sorted_embeddings]
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Asynchronous Embed query text."""
+        embeddings = await self.aembed_documents([text])
+        return embeddings[0]
@@ -2,6 +2,7 @@ import logging
 import math
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, List, Optional

 from pydantic import BaseModel, Field
@@ -24,11 +25,13 @@ class VectorStoreConfig(BaseModel):
     )
     password: Optional[str] = Field(
         default=None,
-        description="The password of vector store, if not set, will use the default password.",
+        description="The password of vector store, if not set, will use the default "
+        "password.",
     )
     embedding_fn: Optional[Any] = Field(
         default=None,
-        description="The embedding function of vector store, if not set, will use the default embedding function.",
+        description="The embedding function of vector store, if not set, will use the "
+        "default embedding function.",
     )
     max_chunks_once_load: int = Field(
         default=10,
@@ -36,6 +39,11 @@ class VectorStoreConfig(BaseModel):
         "large, you can set this value to a larger number to speed up the loading "
         "process. Default is 10.",
     )
+    max_threads: int = Field(
+        default=1,
+        description="The max number of threads to use. Default is 1. If you set this "
+        "bigger than 1, please make sure your vector store is thread-safe.",
+    )


 class VectorStoreBase(ABC):
@@ -52,12 +60,13 @@ class VectorStoreBase(ABC):
         pass

     def load_document_with_limit(
-        self, chunks: List[Chunk], max_chunks_once_load: int = 10
+        self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
     ) -> List[str]:
         """load document in vector database with limit.
         Args:
             chunks: document chunks.
             max_chunks_once_load: Max number of chunks to load at once.
+            max_threads: Max number of threads to use.
         Return:
         """
         # Group the chunks into chunks of size max_chunks
@@ -65,13 +74,21 @@ class VectorStoreBase(ABC):
             chunks[i : i + max_chunks_once_load]
             for i in range(0, len(chunks), max_chunks_once_load)
         ]
-        logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
+        logger.info(
+            f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
+            f"{max_threads} threads."
+        )
         ids = []
         loaded_cnt = 0
         start_time = time.time()
-        for chunk_group in chunk_groups:
-            ids.extend(self.load_document(chunk_group))
-            loaded_cnt += len(chunk_group)
-            logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
+        with ThreadPoolExecutor(max_workers=max_threads) as executor:
+            tasks = []
+            for chunk_group in chunk_groups:
+                tasks.append(executor.submit(self.load_document, chunk_group))
+            for future in tasks:
+                success_ids = future.result()
+                ids.extend(success_ids)
+                loaded_cnt += len(success_ids)
+                logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
         logger.info(
             f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
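With `max_threads=1` the behavior matches the old sequential loop; larger values submit each chunk group to a thread pool and drain the futures in submission order, so the returned ids stay aligned with the input order. A standalone sketch of the same pattern (`load_one` is an illustrative stand-in for `self.load_document`):

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def load_one(group: List[str]) -> List[str]:
    # Stand-in for VectorStoreBase.load_document: returns ids for one group.
    return [f"id-{text}" for text in group]


groups = [["a", "b"], ["c", "d"], ["e"]]
ids: List[str] = []
with ThreadPoolExecutor(max_workers=2) as executor:
    tasks = [executor.submit(load_one, group) for group in groups]
    # Iterating futures in submission order (rather than as_completed) keeps
    # the resulting ids aligned with the original chunk-group order.
    for future in tasks:
        ids.extend(future.result())
print(ids)
```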
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, List, Optional
+from typing import Any, List, Optional

 from dbgpt.rag.chunk import Chunk
 from dbgpt.storage import vector_store
@@ -65,7 +65,9 @@ class VectorStoreConnector:
             Return chunk ids.
         """
         return self.client.load_document_with_limit(
-            chunks, self._vector_store_config.max_chunks_once_load
+            chunks,
+            self._vector_store_config.max_chunks_once_load,
+            self._vector_store_config.max_threads,
         )

     def similar_search(self, doc: str, topk: int) -> List[Chunk]:
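End to end, the new `max_threads` setting flows from `VectorStoreConfig` through the connector into `load_document_with_limit`. A hedged configuration sketch (store name and persist path are illustrative):

```python
import os

from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

# Per the new field's description, only raise max_threads above 1 if the
# underlying vector store is thread-safe.
config = ChromaVectorConfig(
    name="my_vector_store",  # illustrative
    persist_path=os.path.join("/tmp", "chroma_data"),  # illustrative
    max_chunks_once_load=10,
    max_threads=4,
)
connector = VectorStoreConnector.from_default(
    "Chroma", vector_store_config=config
)  # an embedding_fn is normally passed too, as in the new example file below
```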
@@ -10,7 +10,7 @@ The call of multi-model services is compatible with the OpenAI interface, and th
 ## Start apiserver

 After deploying the model service, you need to start the API Server. By default, the model API Server uses port `8100` to start.
-```python
+```bash
 dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY

 ```
@@ -25,7 +25,7 @@ After the apiserver is started, the service call can be verified. First, let's l
 :::tip
 List models
 :::
-```python
+```bash
 curl http://127.0.0.1:8100/api/v1/models \
 -H "Authorization: Bearer EMPTY" \
 -H "Content-Type: application/json"
@@ -34,17 +34,31 @@ curl http://127.0.0.1:8100/api/v1/models \
 :::tip
 Chat
 :::
-```python
+```bash
 curl http://127.0.0.1:8100/api/v1/chat/completions \
 -H "Authorization: Bearer EMPTY" \
 -H "Content-Type: application/json" \
 -d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
 ```

+:::tip
+Embedding
+:::
+```bash
+curl http://127.0.0.1:8100/api/v1/embeddings \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json" \
+-d '{
+    "model": "text2vec",
+    "input": "Hello world!"
+}'
+```
+
+
 ## Verify via OpenAI SDK

 ```python
 import openai
 openai.api_key = "EMPTY"
 openai.api_base = "http://127.0.0.1:8100/api/v1"
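For embeddings specifically, the same pre-1.0 `openai` SDK configured above can exercise the new endpoint; a hedged sketch mirroring the curl Embedding example (the response shape follows the OpenAI-compatible format):

```python
import openai

openai.api_key = "EMPTY"
openai.api_base = "http://127.0.0.1:8100/api/v1"

# Mirrors the curl Embedding example above, via openai<1.0's Embedding API.
result = openai.Embedding.create(model="text2vec", input="Hello world!")
print(len(result["data"][0]["embedding"]))
```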
@@ -1,7 +1,7 @@
 import asyncio
 import os

-from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
+from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
 from dbgpt.rag.chunk_manager import ChunkParameters
 from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge.factory import KnowledgeFactory
@@ -37,7 +37,7 @@ def _create_vector_connector():


 async def main():
-    file_path = "docs/docs/awel.md"
+    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
     knowledge = KnowledgeFactory.from_file_path(file_path)
     vector_connector = _create_vector_connector()
     chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
87
examples/rag/rag_embedding_api_example.py
Normal file
87
examples/rag/rag_embedding_api_example.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
"""A RAG example using the OpenAPIEmbeddings.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
Test with `OpenAI embeddings
|
||||||
|
<https://platform.openai.com/docs/api-reference/embeddings/create>`_.
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
export API_SERVER_BASE_URL=${OPENAI_API_BASE:-"https://api.openai.com/v1"}
|
||||||
|
export API_SERVER_API_KEY="${OPENAI_API_KEY}"
|
||||||
|
export API_SERVER_EMBEDDINGS_MODEL="text-embedding-ada-002"
|
||||||
|
python examples/rag/rag_embedding_api_example.py
|
||||||
|
|
||||||
|
Test with DB-GPT `API Server
|
||||||
|
<https://docs.dbgpt.site/docs/installation/advanced_usage/OpenAI_SDK_call#start-apiserver>`_.
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
export API_SERVER_BASE_URL="http://localhost:8100/api/v1"
|
||||||
|
export API_SERVER_API_KEY="your_api_key"
|
||||||
|
export API_SERVER_EMBEDDINGS_MODEL="text2vec"
|
||||||
|
python examples/rag/rag_embedding_api_example.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH
|
||||||
|
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||||
|
from dbgpt.rag.embedding import OpenAPIEmbeddings
|
||||||
|
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||||
|
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
|
||||||
|
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||||
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
|
||||||
|
def _create_embeddings(
|
||||||
|
api_url: str = None, api_key: Optional[str] = None, model_name: Optional[str] = None
|
||||||
|
) -> OpenAPIEmbeddings:
|
||||||
|
if not api_url:
|
||||||
|
api_server_base_url = os.getenv(
|
||||||
|
"API_SERVER_BASE_URL", "http://localhost:8100/api/v1/"
|
||||||
|
)
|
||||||
|
api_url = f"{api_server_base_url}/embeddings"
|
||||||
|
if not api_key:
|
||||||
|
api_key = os.getenv("API_SERVER_API_KEY")
|
||||||
|
|
||||||
|
if not model_name:
|
||||||
|
model_name = os.getenv("API_SERVER_EMBEDDINGS_MODEL", "text2vec")
|
||||||
|
|
||||||
|
return OpenAPIEmbeddings(api_url=api_url, api_key=api_key, model_name=model_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_vector_connector():
|
||||||
|
"""Create vector connector."""
|
||||||
|
|
||||||
|
return VectorStoreConnector.from_default(
|
||||||
|
"Chroma",
|
||||||
|
vector_store_config=ChromaVectorConfig(
|
||||||
|
name="example_embedding_api_vector_store_name",
|
||||||
|
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||||
|
),
|
||||||
|
embedding_fn=_create_embeddings(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
|
||||||
|
knowledge = KnowledgeFactory.from_file_path(file_path)
|
||||||
|
vector_connector = _create_vector_connector()
|
||||||
|
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||||
|
# get embedding assembler
|
||||||
|
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||||
|
knowledge=knowledge,
|
||||||
|
chunk_parameters=chunk_parameters,
|
||||||
|
vector_store_connector=vector_connector,
|
||||||
|
)
|
||||||
|
assembler.persist()
|
||||||
|
# get embeddings retriever
|
||||||
|
retriever = assembler.as_retriever(3)
|
||||||
|
chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
|
||||||
|
print(f"embedding rag example results:{chunks}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|