feat: APIServer supports embeddings (#1256)
commit 74ec8e52cd (parent 5f3ee35804)
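This change adds an OpenAI-compatible `/api/v1/embeddings` endpoint to the DB-GPT API Server, threads a configurable `embedding_batch_size` through the CLI, introduces an `OpenAPIEmbeddings` client class, and parallelizes vector-store loading. A minimal client sketch against the new endpoint, assuming an API Server already running on `127.0.0.1:8100` with API key `EMPTY` (host, port, and key depend on your deployment):

```python
import requests

# Assumed local deployment; adjust the URL and API key to your setup.
API_URL = "http://127.0.0.1:8100/api/v1/embeddings"

resp = requests.post(
    API_URL,
    headers={"Authorization": "Bearer EMPTY"},
    json={"model": "text2vec", "input": ["Hello, world!", "How are you?"]},
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
# One embedding per input text, in the OpenAI response format.
vectors = [item["embedding"] for item in body["data"]]
print(len(vectors), len(vectors[0]))
```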
dbgpt/model/cluster/apiserver/api.py
@@ -23,6 +23,8 @@ from fastchat.protocol.openai_api_protocol import (
     ChatCompletionStreamResponse,
     ChatMessage,
     DeltaMessage,
+    EmbeddingsRequest,
+    EmbeddingsResponse,
     ModelCard,
     ModelList,
     ModelPermission,
@@ -51,6 +53,7 @@ class APIServerException(Exception):
 
 class APISettings(BaseModel):
     api_keys: Optional[List[str]] = None
+    embedding_batch_size: int = 4
 
 
 api_settings = APISettings()
@@ -181,27 +184,29 @@ class APIServer(BaseComponent):
         return controller
 
     async def get_model_instances_or_raise(
-        self, model_name: str
+        self, model_name: str, worker_type: str = "llm"
     ) -> List[ModelInstance]:
         """Get healthy model instances for the requested model name
 
         Args:
             model_name (str): Model name
+            worker_type (str, optional): Worker type. Defaults to "llm".
 
         Raises:
            APIServerException: If no healthy model instances match the requested model name
         """
         registry = self.get_model_registry()
-        registry_model_name = f"{model_name}@llm"
+        suffix = f"@{worker_type}"
+        registry_model_name = f"{model_name}{suffix}"
         model_instances = await registry.get_all_instances(
             registry_model_name, healthy_only=True
         )
         if not model_instances:
             all_instances = await registry.get_all_model_instances(healthy_only=True)
             models = [
-                ins.model_name.split("@llm")[0]
+                ins.model_name.split(suffix)[0]
                 for ins in all_instances
-                if ins.model_name.endswith("@llm")
+                if ins.model_name.endswith(suffix)
             ]
             if models:
                 models = "&&".join(models)
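The lookup above relies on the registry convention that instances are named `{model_name}@{worker_type}`. The embeddings endpoint added later in this commit passes `worker_type="text2vec"`, so an embedding model named `text2vec` resolves to `text2vec@text2vec`, while chat models keep the default `@llm` suffix. A standalone sketch of the convention (not DB-GPT code):

```python
# Sketch of the registry-name convention used above.
def registry_name(model_name: str, worker_type: str = "llm") -> str:
    return f"{model_name}@{worker_type}"

assert registry_name("vicuna-13b-v1.5") == "vicuna-13b-v1.5@llm"
assert registry_name("text2vec", worker_type="text2vec") == "text2vec@text2vec"
```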
@@ -336,6 +341,25 @@ class APIServer(BaseComponent):
 
         return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
 
+    async def embeddings_generate(
+        self, model: str, texts: List[str]
+    ) -> List[List[float]]:
+        """Generate embeddings
+
+        Args:
+            model (str): Model name
+            texts (List[str]): Texts to embed
+
+        Returns:
+            List[List[float]]: The embeddings of texts
+        """
+        worker_manager: WorkerManager = self.get_worker_manager()
+        params = {
+            "input": texts,
+            "model": model,
+        }
+        return await worker_manager.embeddings(params)
+
 
 def get_api_server() -> APIServer:
     api_server = global_system_app.get_component(
@@ -389,6 +413,40 @@ async def create_chat_completion(
     return await api_server.chat_completion_generate(request.model, params, request.n)
 
 
+@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
+async def create_embeddings(
+    request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
+):
+    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+    texts = request.input
+    if isinstance(texts, str):
+        texts = [texts]
+    batch_size = api_settings.embedding_batch_size
+    batches = [
+        texts[i : min(i + batch_size, len(texts))]
+        for i in range(0, len(texts), batch_size)
+    ]
+    data = []
+    async_tasks = []
+    for num_batch, batch in enumerate(batches):
+        async_tasks.append(api_server.embeddings_generate(request.model, batch))
+
+    # Request all embeddings in parallel
+    batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
+    for num_batch, embeddings in enumerate(batch_embeddings):
+        data += [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": num_batch * batch_size + i,
+            }
+            for i, emb in enumerate(embeddings)
+        ]
+    return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
+        exclude_none=True
+    )
+
+
 def _initialize_all(controller_addr: str, system_app: SystemApp):
     from dbgpt.model.cluster.controller.controller import ModelRegistryClient
     from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
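`create_embeddings` splits the input into groups of `embedding_batch_size`, embeds the groups concurrently, and reassembles global indices as `num_batch * batch_size + i`; the `min()` in the slice is redundant since Python slicing already clamps at the end of the list, but harmless. A standalone check of the arithmetic with the default batch size of 4:

```python
texts = [f"text-{i}" for i in range(6)]
batch_size = 4  # default APISettings value in this change

batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
assert batches == [["text-0", "text-1", "text-2", "text-3"], ["text-4", "text-5"]]

# Recompose global indices the same way the endpoint does.
indices = [
    num_batch * batch_size + i
    for num_batch, batch in enumerate(batches)
    for i, _ in enumerate(batch)
]
assert indices == [0, 1, 2, 3, 4, 5]
```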
@@ -427,6 +485,7 @@ def initialize_apiserver(
     host: str = None,
     port: int = None,
     api_keys: List[str] = None,
+    embedding_batch_size: Optional[int] = None,
 ):
     global global_system_app
     global api_settings
@@ -434,13 +493,6 @@ def initialize_apiserver(
     if not app:
         embedded_mod = False
         app = FastAPI()
-        app.add_middleware(
-            CORSMiddleware,
-            allow_origins=["*"],
-            allow_credentials=True,
-            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
-            allow_headers=["*"],
-        )
 
     if not system_app:
         system_app = SystemApp(app)
@@ -449,6 +501,9 @@ def initialize_apiserver(
     if api_keys:
         api_settings.api_keys = api_keys
 
+    if embedding_batch_size:
+        api_settings.embedding_batch_size = embedding_batch_size
+
     app.include_router(router, prefix="/api", tags=["APIServer"])
 
     @app.exception_handler(APIServerException)
@@ -464,7 +519,15 @@ def initialize_apiserver(
     if not embedded_mod:
         import uvicorn
 
-        uvicorn.run(app, host=host, port=port, log_level="info")
+        # https://github.com/encode/starlette/issues/617
+        cors_app = CORSMiddleware(
+            app=app,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+            allow_headers=["*"],
+        )
+        uvicorn.run(cors_app, host=host, port=port, log_level="info")
 
 
 def run_apiserver():
@@ -488,6 +551,7 @@ def run_apiserver():
         host=apiserver_params.host,
         port=apiserver_params.port,
         api_keys=api_keys,
+        embedding_batch_size=apiserver_params.embedding_batch_size,
     )
dbgpt/model/parameter.py
@@ -113,6 +113,9 @@ class ModelAPIServerParameters(BaseParameters):
         default=None,
         metadata={"help": "Optional list of comma separated API keys"},
     )
+    embedding_batch_size: Optional[int] = field(
+        default=None, metadata={"help": "Embedding batch size"}
+    )
 
     log_level: Optional[str] = field(
         default=None,
dbgpt/rag/embedding/__init__.py (new file)
@@ -0,0 +1,16 @@
+from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory
+from .embeddings import (
+    Embeddings,
+    HuggingFaceEmbeddings,
+    JinaEmbeddings,
+    OpenAPIEmbeddings,
+)
+
+__all__ = [
+    "OpenAPIEmbeddings",
+    "Embeddings",
+    "HuggingFaceEmbeddings",
+    "JinaEmbeddings",
+    "EmbeddingFactory",
+    "DefaultEmbeddingFactory",
+]
dbgpt/rag/embedding/embeddings.py
@@ -2,8 +2,10 @@ import asyncio
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional
 
+import aiohttp
 import requests
-from pydantic import BaseModel, Extra, Field
+
+from dbgpt._private.pydantic import BaseModel, Extra, Field
 
 DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
 DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -363,6 +365,29 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
         return self.embed_documents([text])[0]
 
 
+def _handle_request_result(res: requests.Response) -> List[List[float]]:
+    """Parse the result from a request.
+
+    Args:
+        res: The response from the request.
+
+    Returns:
+        List[List[float]]: The embeddings.
+
+    Raises:
+        RuntimeError: If the response is not successful.
+    """
+    res.raise_for_status()
+    resp = res.json()
+    if "data" not in resp:
+        raise RuntimeError(resp["detail"])
+    embeddings = resp["data"]
+    # Sort resulting embeddings by index
+    sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+    # Return just the embeddings
+    return [result["embedding"] for result in sorted_embeddings]
+
+
 class JinaEmbeddings(BaseModel, Embeddings):
     """
     This class is used to get embeddings for a list of texts using the Jina AI API.
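`_handle_request_result` expects an OpenAI-style body with a `data` list and re-orders entries by their `index` field before projecting out the vectors. An illustration of the sort-then-project step with fabricated values:

```python
# Fabricated response body in the OpenAI embeddings format.
resp = {
    "object": "list",
    "data": [
        {"object": "embedding", "index": 1, "embedding": [0.3, 0.4]},
        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2]},
    ],
}
sorted_embeddings = sorted(resp["data"], key=lambda e: e["index"])
assert [e["embedding"] for e in sorted_embeddings] == [[0.1, 0.2], [0.3, 0.4]]
```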
@@ -406,17 +431,8 @@ class JinaEmbeddings(BaseModel, Embeddings):
         # Call Jina AI Embedding API
         resp = self.session.post(  # type: ignore
             self.api_url, json={"input": texts, "model": self.model_name}
-        ).json()
-        if "data" not in resp:
-            raise RuntimeError(resp["detail"])
-
-        embeddings = resp["data"]
-
-        # Sort resulting embeddings by index
-        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
-
-        # Return just the embeddings
-        return [result["embedding"] for result in sorted_embeddings]
+        )
+        return _handle_request_result(resp)
 
     def embed_query(self, text: str) -> List[float]:
         """Compute query embeddings using a HuggingFace transformer model.
@@ -428,3 +444,158 @@ class JinaEmbeddings(BaseModel, Embeddings):
             Embeddings for the text.
         """
         return self.embed_documents([text])[0]
+
+
+class OpenAPIEmbeddings(BaseModel, Embeddings):
+    """This class is used to get embeddings for a list of texts using the API.
+
+    This API is compatible with the OpenAI Embedding API.
+
+    Examples:
+
+        Using OpenAI's API:
+        .. code-block:: python
+
+            from dbgpt.rag.embedding import OpenAPIEmbeddings
+
+            openai_embeddings = OpenAPIEmbeddings(
+                api_url="https://api.openai.com/v1/embeddings",
+                api_key="your_api_key",
+                model_name="text-embedding-3-small",
+            )
+            texts = ["Hello, world!", "How are you?"]
+            openai_embeddings.embed_documents(texts)
+
+        Using DB-GPT APIServer's embedding API:
+        To use the DB-GPT APIServer's embedding API, you should deploy DB-GPT according
+        to the `Cluster Deploy
+        <https://docs.dbgpt.site/docs/installation/model_service/cluster>`_.
+
+        A simple example:
+        1. Deploy Model Cluster with the following command:
+        .. code-block:: bash
+
+            dbgpt start controller --port 8000
+
+        2. Deploy Embedding Model Worker with the following command:
+        .. code-block:: bash
+
+            dbgpt start worker --model_name text2vec \
+                --model_path /app/models/text2vec-large-chinese \
+                --worker_type text2vec \
+                --port 8003 \
+                --controller_addr http://127.0.0.1:8000
+
+        3. Deploy API Server with the following command:
+        .. code-block:: bash
+
+            dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
+                --api_keys my_api_token --port 8100
+
+        Now, you can use the API server to get embeddings:
+        .. code-block:: python
+
+            from dbgpt.rag.embedding import OpenAPIEmbeddings
+
+            openai_embeddings = OpenAPIEmbeddings(
+                api_url="http://localhost:8100/api/v1/embeddings",
+                api_key="my_api_token",
+                model_name="text2vec",
+            )
+            texts = ["Hello, world!", "How are you?"]
+            openai_embeddings.embed_documents(texts)
+    """
+
+    api_url: str = Field(
+        default="http://localhost:8100/api/v1/embeddings",
+        description="The URL of the embeddings API.",
+    )
+    api_key: Optional[str] = Field(
+        default=None, description="The API key for the embeddings API."
+    )
+    model_name: str = Field(
+        default="text2vec", description="The name of the model to use."
+    )
+    timeout: int = Field(
+        default=60, description="The timeout for the request in seconds."
+    )
+
+    session: requests.Session = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def __init__(self, **kwargs):
+        """Initialize the OpenAPIEmbeddings."""
+        super().__init__(**kwargs)
+        try:
+            import requests
+        except ImportError:
+            raise ValueError(
+                "The requests python package is not installed. "
+                "Please install it with `pip install requests`"
+            )
+        self.session = requests.Session()
+        self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Get the embeddings for a list of texts.
+
+        Args:
+            texts (Documents): A list of texts to get embeddings for.
+
+        Returns:
+            Embedded texts as List[List[float]], where each inner List[float]
+            corresponds to a single input text.
+        """
+        # Call OpenAI Embedding API
+        res = self.session.post(  # type: ignore
+            self.api_url,
+            json={"input": texts, "model": self.model_name},
+            timeout=self.timeout,
+        )
+        return _handle_request_result(res)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using an OpenAPI embedding model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        return self.embed_documents([text])[0]
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronously embed search docs.
+
+        Args:
+            texts: A list of texts to get embeddings for.
+
+        Returns:
+            List[List[float]]: Embedded texts as List[List[float]], where each inner
+            List[float] corresponds to a single input text.
+        """
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        async with aiohttp.ClientSession(
+            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
+        ) as session:
+            async with session.post(
+                self.api_url, json={"input": texts, "model": self.model_name}
+            ) as resp:
+                resp.raise_for_status()
+                data = await resp.json()
+                if "data" not in data:
+                    raise RuntimeError(data["detail"])
+                embeddings = data["data"]
+                sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
+                return [result["embedding"] for result in sorted_embeddings]
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Asynchronously embed query text."""
+        embeddings = await self.aembed_documents([text])
+        return embeddings[0]
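A usage sketch for the new asynchronous path, assuming a reachable API Server and a valid key (the field defaults target a local DB-GPT deployment):

```python
import asyncio

from dbgpt.rag.embedding import OpenAPIEmbeddings


async def demo() -> None:
    # api_url defaults to http://localhost:8100/api/v1/embeddings;
    # adjust the key and model name for your deployment.
    embeddings = OpenAPIEmbeddings(api_key="my_api_token", model_name="text2vec")
    vectors = await embeddings.aembed_documents(["Hello, world!", "How are you?"])
    print(len(vectors), len(vectors[0]))


if __name__ == "__main__":
    asyncio.run(demo())
```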
dbgpt/storage/vector_store/base.py
@@ -2,6 +2,7 @@ import logging
 import math
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, List, Optional
 
 from pydantic import BaseModel, Field
@@ -24,11 +25,13 @@ class VectorStoreConfig(BaseModel):
     )
     password: Optional[str] = Field(
         default=None,
-        description="The password of vector store, if not set, will use the default password.",
+        description="The password of vector store, if not set, will use the default "
+        "password.",
     )
     embedding_fn: Optional[Any] = Field(
         default=None,
-        description="The embedding function of vector store, if not set, will use the default embedding function.",
+        description="The embedding function of vector store, if not set, will use the "
+        "default embedding function.",
     )
     max_chunks_once_load: int = Field(
         default=10,
@@ -36,6 +39,11 @@ class VectorStoreConfig(BaseModel):
         "large, you can set this value to a larger number to speed up the loading "
         "process. Default is 10.",
     )
+    max_threads: int = Field(
+        default=1,
+        description="The max number of threads to use. Default is 1. If you set this "
+        "bigger than 1, please make sure your vector store is thread-safe.",
+    )
 
 
 class VectorStoreBase(ABC):
@@ -52,12 +60,13 @@ class VectorStoreBase(ABC):
         pass
 
     def load_document_with_limit(
-        self, chunks: List[Chunk], max_chunks_once_load: int = 10
+        self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
     ) -> List[str]:
         """Load documents into the vector database with a limit.
         Args:
             chunks: document chunks.
             max_chunks_once_load: Max number of chunks to load at once.
+            max_threads: Max number of threads to use.
         Return:
         """
         # Group the chunks into chunks of size max_chunks
@@ -65,14 +74,22 @@ class VectorStoreBase(ABC):
             chunks[i : i + max_chunks_once_load]
             for i in range(0, len(chunks), max_chunks_once_load)
         ]
-        logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
+        logger.info(
+            f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
+            f"{max_threads} threads."
+        )
         ids = []
         loaded_cnt = 0
         start_time = time.time()
-        for chunk_group in chunk_groups:
-            ids.extend(self.load_document(chunk_group))
-            loaded_cnt += len(chunk_group)
-            logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
+        with ThreadPoolExecutor(max_workers=max_threads) as executor:
+            tasks = []
+            for chunk_group in chunk_groups:
+                tasks.append(executor.submit(self.load_document, chunk_group))
+            for future in tasks:
+                success_ids = future.result()
+                ids.extend(success_ids)
+                loaded_cnt += len(success_ids)
+                logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
         logger.info(
             f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
         )
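Design note: chunk groups are submitted to the pool in order and `future.result()` is consumed in that same order, so the returned ids stay aligned with the input order even when `max_threads > 1` (which, per the new config field, is only safe when the vector store is thread-safe). The pattern in isolation (not DB-GPT code):

```python
from concurrent.futures import ThreadPoolExecutor


def load(group):  # stand-in for self.load_document
    return [f"id-{x}" for x in group]


groups = [[1, 2], [3, 4], [5]]
ids = []
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(load, g) for g in groups]
    for future in futures:  # results consumed in submission order
        ids.extend(future.result())
assert ids == ["id-1", "id-2", "id-3", "id-4", "id-5"]
```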
dbgpt/storage/vector_store/connector.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, List, Optional
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Chunk
 from dbgpt.storage import vector_store
@@ -65,7 +65,9 @@ class VectorStoreConnector:
         Return chunk ids.
         """
         return self.client.load_document_with_limit(
-            chunks, self._vector_store_config.max_chunks_once_load
+            chunks,
+            self._vector_store_config.max_chunks_once_load,
+            self._vector_store_config.max_threads,
         )
 
     def similar_search(self, doc: str, topk: int) -> List[Chunk]:
docs/docs/installation/advanced_usage/OpenAI_SDK_call.md
@@ -10,7 +10,7 @@ The call of multi-model services is compatible with the OpenAI interface, and th
 ## Start apiserver
 
 After deploying the model service, you need to start the API Server. By default, the model API Server uses port `8100` to start.
-```python
+```bash
 dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY
 
 ```
@@ -25,7 +25,7 @@ After the apiserver is started, the service call can be verified. First, let's l
 :::tip
 List models
 :::
-```python
+```bash
 curl http://127.0.0.1:8100/api/v1/models \
 -H "Authorization: Bearer EMPTY" \
 -H "Content-Type: application/json"
@@ -34,17 +34,31 @@ curl http://127.0.0.1:8100/api/v1/models \
 :::tip
 Chat
 :::
-```python
+```bash
 curl http://127.0.0.1:8100/api/v1/chat/completions \
 -H "Authorization: Bearer EMPTY" \
 -H "Content-Type: application/json" \
 -d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
 ```
 
+:::tip
+Embedding
+:::
+```bash
+curl http://127.0.0.1:8100/api/v1/embeddings \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json" \
+-d '{
+    "model": "text2vec",
+    "input": "Hello world!"
+}'
+```
+
+
 ## Verify via OpenAI SDK
 
 ```python
 import openai
 openai.api_key = "EMPTY"
 openai.api_base = "http://127.0.0.1:8100/api/v1"
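The SDK snippet above only shows configuration; an embeddings call in the same style, assuming the pre-1.0 `openai` package that this page targets (where `openai.Embedding.create` is the embeddings entry point):

```python
import openai

openai.api_key = "EMPTY"
openai.api_base = "http://127.0.0.1:8100/api/v1"

# Pre-1.0 OpenAI SDK style, matching the snippet above.
res = openai.Embedding.create(model="text2vec", input=["Hello world!"])
print(res["data"][0]["embedding"][:8])
```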
examples/rag/rag_embedding_example.py
@@ -1,7 +1,7 @@
 import asyncio
 import os
 
-from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
+from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
 from dbgpt.rag.chunk_manager import ChunkParameters
 from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge.factory import KnowledgeFactory
@@ -37,7 +37,7 @@ def _create_vector_connector():
 
 
 async def main():
-    file_path = "docs/docs/awel.md"
+    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
     knowledge = KnowledgeFactory.from_file_path(file_path)
     vector_connector = _create_vector_connector()
     chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
examples/rag/rag_embedding_api_example.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""A RAG example using the OpenAPIEmbeddings.
|
||||
|
||||
Example:
|
||||
|
||||
Test with `OpenAI embeddings
|
||||
<https://platform.openai.com/docs/api-reference/embeddings/create>`_.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
export API_SERVER_BASE_URL=${OPENAI_API_BASE:-"https://api.openai.com/v1"}
|
||||
export API_SERVER_API_KEY="${OPENAI_API_KEY}"
|
||||
export API_SERVER_EMBEDDINGS_MODEL="text-embedding-ada-002"
|
||||
python examples/rag/rag_embedding_api_example.py
|
||||
|
||||
Test with DB-GPT `API Server
|
||||
<https://docs.dbgpt.site/docs/installation/advanced_usage/OpenAI_SDK_call#start-apiserver>`_.
|
||||
|
||||
.. code-block:: shell
|
||||
export API_SERVER_BASE_URL="http://localhost:8100/api/v1"
|
||||
export API_SERVER_API_KEY="your_api_key"
|
||||
export API_SERVER_EMBEDDINGS_MODEL="text2vec"
|
||||
python examples/rag/rag_embedding_api_example.py
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag.embedding import OpenAPIEmbeddings
|
||||
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
def _create_embeddings(
|
||||
api_url: str = None, api_key: Optional[str] = None, model_name: Optional[str] = None
|
||||
) -> OpenAPIEmbeddings:
|
||||
if not api_url:
|
||||
api_server_base_url = os.getenv(
|
||||
"API_SERVER_BASE_URL", "http://localhost:8100/api/v1/"
|
||||
)
|
||||
api_url = f"{api_server_base_url}/embeddings"
|
||||
if not api_key:
|
||||
api_key = os.getenv("API_SERVER_API_KEY")
|
||||
|
||||
if not model_name:
|
||||
model_name = os.getenv("API_SERVER_EMBEDDINGS_MODEL", "text2vec")
|
||||
|
||||
return OpenAPIEmbeddings(api_url=api_url, api_key=api_key, model_name=model_name)
|
||||
|
||||
|
||||
def _create_vector_connector():
|
||||
"""Create vector connector."""
|
||||
|
||||
return VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="example_embedding_api_vector_store_name",
|
||||
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||
),
|
||||
embedding_fn=_create_embeddings(),
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
|
||||
knowledge = KnowledgeFactory.from_file_path(file_path)
|
||||
vector_connector = _create_vector_connector()
|
||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||
# get embedding assembler
|
||||
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||
knowledge=knowledge,
|
||||
chunk_parameters=chunk_parameters,
|
||||
vector_store_connector=vector_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
# get embeddings retriever
|
||||
retriever = assembler.as_retriever(3)
|
||||
chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
|
||||
print(f"embedding rag example results:{chunks}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|