feat: APIServer supports embeddings (#1256)

Fangyin Cheng 2024-03-05 20:21:37 +08:00 committed by GitHub
parent 5f3ee35804
commit 74ec8e52cd
9 changed files with 414 additions and 40 deletions

View File

@@ -23,6 +23,8 @@ from fastchat.protocol.openai_api_protocol import (
    ChatCompletionStreamResponse,
    ChatMessage,
    DeltaMessage,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ModelCard,
    ModelList,
    ModelPermission,
@@ -51,6 +53,7 @@ class APIServerException(Exception):
class APISettings(BaseModel):
    api_keys: Optional[List[str]] = None
    embedding_batch_size: int = 4


api_settings = APISettings()
@@ -181,27 +184,29 @@ class APIServer(BaseComponent):
        return controller

    async def get_model_instances_or_raise(
        self, model_name: str, worker_type: str = "llm"
    ) -> List[ModelInstance]:
        """Get healthy model instances with request model name

        Args:
            model_name (str): Model name
            worker_type (str, optional): Worker type. Defaults to "llm".

        Raises:
            APIServerException: If can't get healthy model instances with request
                model name
        """
        registry = self.get_model_registry()
        suffix = f"@{worker_type}"
        registry_model_name = f"{model_name}{suffix}"
        model_instances = await registry.get_all_instances(
            registry_model_name, healthy_only=True
        )
        if not model_instances:
            all_instances = await registry.get_all_model_instances(healthy_only=True)
            models = [
                ins.model_name.split(suffix)[0]
                for ins in all_instances
                if ins.model_name.endswith(suffix)
            ]
            if models:
                models = "&&".join(models)
@@ -336,6 +341,25 @@ class APIServer(BaseComponent):
        return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)

    async def embeddings_generate(
        self, model: str, texts: List[str]
    ) -> List[List[float]]:
        """Generate embeddings

        Args:
            model (str): Model name
            texts (List[str]): Texts to embed

        Returns:
            List[List[float]]: The embeddings of texts
        """
        worker_manager: WorkerManager = self.get_worker_manager()
        params = {
            "input": texts,
            "model": model,
        }
        return await worker_manager.embeddings(params)


def get_api_server() -> APIServer:
    api_server = global_system_app.get_component(
@@ -389,6 +413,40 @@ async def create_chat_completion(
    return await api_server.chat_completion_generate(request.model, params, request.n)


@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
    request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
):
    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
    texts = request.input
    if isinstance(texts, str):
        texts = [texts]
    batch_size = api_settings.embedding_batch_size
    batches = [
        texts[i : min(i + batch_size, len(texts))]
        for i in range(0, len(texts), batch_size)
    ]
    data = []
    async_tasks = []
    for batch in batches:
        async_tasks.append(api_server.embeddings_generate(request.model, batch))

    # Request all embeddings in parallel
    batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
    for num_batch, embeddings in enumerate(batch_embeddings):
        data += [
            {
                "object": "embedding",
                "embedding": emb,
                "index": num_batch * batch_size + i,
            }
            for i, emb in enumerate(embeddings)
        ]
    return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
        exclude_none=True
    )


def _initialize_all(controller_addr: str, system_app: SystemApp):
    from dbgpt.model.cluster.controller.controller import ModelRegistryClient
    from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
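For reference, a minimal client-side sketch of calling the new endpoint (not part of the commit; it assumes a locally running APIServer started with `--api_keys EMPTY` on port 8100, as in the docs below):

```python
import requests

# Hypothetical local deployment; adjust host, port, key, and model to your setup.
resp = requests.post(
    "http://127.0.0.1:8100/api/v1/embeddings",
    headers={"Authorization": "Bearer EMPTY"},
    json={"model": "text2vec", "input": ["Hello, world!", "How are you?"]},
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
# One embedding per input text, ordered by the "index" field set above
vectors = [d["embedding"] for d in sorted(body["data"], key=lambda d: d["index"])]
print(len(vectors))
```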
@@ -427,6 +485,7 @@ def initialize_apiserver(
    host: str = None,
    port: int = None,
    api_keys: List[str] = None,
    embedding_batch_size: Optional[int] = None,
):
    global global_system_app
    global api_settings
@@ -434,13 +493,6 @@ def initialize_apiserver(
    if not app:
        embedded_mod = False
        app = FastAPI()
-       app.add_middleware(
-           CORSMiddleware,
-           allow_origins=["*"],
-           allow_credentials=True,
-           allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
-           allow_headers=["*"],
-       )

    if not system_app:
        system_app = SystemApp(app)
@@ -449,6 +501,9 @@ def initialize_apiserver(
    if api_keys:
        api_settings.api_keys = api_keys

    if embedding_batch_size:
        api_settings.embedding_batch_size = embedding_batch_size

    app.include_router(router, prefix="/api", tags=["APIServer"])

    @app.exception_handler(APIServerException)
@@ -464,7 +519,15 @@ def initialize_apiserver(
    if not embedded_mod:
        import uvicorn

        # https://github.com/encode/starlette/issues/617
        cors_app = CORSMiddleware(
            app=app,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
            allow_headers=["*"],
        )
        uvicorn.run(cors_app, host=host, port=port, log_level="info")


def run_apiserver():
@@ -488,6 +551,7 @@ def run_apiserver():
        host=apiserver_params.host,
        port=apiserver_params.port,
        api_keys=api_keys,
        embedding_batch_size=apiserver_params.embedding_batch_size,
    )

View File

@@ -113,6 +113,9 @@ class ModelAPIServerParameters(BaseParameters):
        default=None,
        metadata={"help": "Optional list of comma separated API keys"},
    )
    embedding_batch_size: Optional[int] = field(
        default=None, metadata={"help": "Embedding batch size"}
    )
    log_level: Optional[str] = field(
        default=None,

View File

@@ -0,0 +1,16 @@
from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory
from .embeddings import (
    Embeddings,
    HuggingFaceEmbeddings,
    JinaEmbeddings,
    OpenAPIEmbeddings,
)

__all__ = [
    "OpenAPIEmbeddings",
    "Embeddings",
    "HuggingFaceEmbeddings",
    "JinaEmbeddings",
    "EmbeddingFactory",
    "DefaultEmbeddingFactory",
]

View File

@@ -2,8 +2,10 @@ import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import aiohttp
import requests

from dbgpt._private.pydantic import BaseModel, Extra, Field

DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -363,6 +365,29 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
        return self.embed_documents([text])[0]


def _handle_request_result(res: requests.Response) -> List[List[float]]:
    """Parse the result from a request.

    Args:
        res: The response from the request.

    Returns:
        List[List[float]]: The embeddings.

    Raises:
        RuntimeError: If the response is not successful.
    """
    res.raise_for_status()
    resp = res.json()
    if "data" not in resp:
        raise RuntimeError(resp["detail"])
    embeddings = resp["data"]
    # Sort resulting embeddings by index
    sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
    # Return just the embeddings
    return [result["embedding"] for result in sorted_embeddings]


class JinaEmbeddings(BaseModel, Embeddings):
    """
    This class is used to get embeddings for a list of texts using the Jina AI API.
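For orientation, a tiny sketch (not part of the commit) of the OpenAI-compatible body that `_handle_request_result` parses, showing why it sorts on `index`:

```python
# Illustrative response body; values are made up.
resp_json = {
    "data": [
        {"object": "embedding", "embedding": [0.3, 0.4], "index": 1},
        {"object": "embedding", "embedding": [0.1, 0.2], "index": 0},
    ]
}
# Sorting by "index" restores the original input order even if the
# server returns the entries out of order.
ordered = [e["embedding"] for e in sorted(resp_json["data"], key=lambda e: e["index"])]
assert ordered == [[0.1, 0.2], [0.3, 0.4]]
```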
@@ -406,17 +431,8 @@ class JinaEmbeddings(BaseModel, Embeddings):
        # Call Jina AI Embedding API
        resp = self.session.post(  # type: ignore
            self.api_url, json={"input": texts, "model": self.model_name}
        )
        return _handle_request_result(resp)
    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model.
@@ -428,3 +444,158 @@ class JinaEmbeddings(BaseModel, Embeddings):
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]


class OpenAPIEmbeddings(BaseModel, Embeddings):
    """This class is used to get embeddings for a list of texts using the API.

    This API is compatible with the OpenAI Embedding API.

    Examples:
        Using OpenAI's API:
        .. code-block:: python

            from dbgpt.rag.embedding import OpenAPIEmbeddings

            openai_embeddings = OpenAPIEmbeddings(
                api_url="https://api.openai.com/v1/embeddings",
                api_key="your_api_key",
                model_name="text-embedding-3-small",
            )
            texts = ["Hello, world!", "How are you?"]
            openai_embeddings.embed_documents(texts)

        Using DB-GPT APIServer's embedding API:
        To use the DB-GPT APIServer's embedding API, you should deploy DB-GPT according
        to the `Cluster Deploy
        <https://docs.dbgpt.site/docs/installation/model_service/cluster>`_.

        A simple example:

        1. Deploy Model Cluster with the following command:
        .. code-block:: bash

            dbgpt start controller --port 8000

        2. Deploy Embedding Model Worker with the following command:
        .. code-block:: bash

            dbgpt start worker --model_name text2vec \
                --model_path /app/models/text2vec-large-chinese \
                --worker_type text2vec \
                --port 8003 \
                --controller_addr http://127.0.0.1:8000

        3. Deploy API Server with the following command:
        .. code-block:: bash

            dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
                --api_keys my_api_token --port 8100

        Now, you can use the API server to get embeddings:
        .. code-block:: python

            from dbgpt.rag.embedding import OpenAPIEmbeddings

            openai_embeddings = OpenAPIEmbeddings(
                api_url="http://localhost:8100/api/v1/embeddings",
                api_key="my_api_token",
                model_name="text2vec",
            )
            texts = ["Hello, world!", "How are you?"]
            openai_embeddings.embed_documents(texts)
    """

    api_url: str = Field(
        default="http://localhost:8100/api/v1/embeddings",
        description="The URL of the embeddings API.",
    )
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    model_name: str = Field(
        default="text2vec", description="The name of the model to use."
    )
    timeout: int = Field(
        default=60, description="The timeout for the request in seconds."
    )
    session: Optional[requests.Session] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def __init__(self, **kwargs):
        """Initialize the OpenAPIEmbeddings."""
        super().__init__(**kwargs)
        try:
            import requests
        except ImportError:
            raise ValueError(
                "The requests python package is not installed. "
                "Please install it with `pip install requests`"
            )
        self.session = requests.Session()
        self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
            corresponds to a single input text.
        """
        # Call OpenAI Embedding API
        res = self.session.post(  # type: ignore
            self.api_url,
            json={"input": texts, "model": self.model_name},
            timeout=self.timeout,
        )
        return _handle_request_result(res)

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using an OpenAPI embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs.

        Args:
            texts: A list of texts to get embeddings for.

        Returns:
            List[List[float]]: Embedded texts as List[List[float]], where each inner
                List[float] corresponds to a single input text.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        async with aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        ) as session:
            async with session.post(
                self.api_url, json={"input": texts, "model": self.model_name}
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()
                if "data" not in data:
                    raise RuntimeError(data["detail"])
                embeddings = data["data"]
                # Sort by index to restore the input order
                sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
                return [result["embedding"] for result in sorted_embeddings]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text."""
        embeddings = await self.aembed_documents([text])
        return embeddings[0]
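A short usage sketch for the new async path (not part of the commit; the URL, API key, and model name are assumptions matching the deployment shown in the docstring above):

```python
import asyncio

from dbgpt.rag.embedding import OpenAPIEmbeddings


async def main():
    embeddings = OpenAPIEmbeddings(
        api_url="http://localhost:8100/api/v1/embeddings",
        api_key="my_api_token",  # assumption: APIServer started with --api_keys my_api_token
        model_name="text2vec",
    )
    # Embeds both texts in one request via aiohttp
    vectors = await embeddings.aembed_documents(["Hello, world!", "How are you?"])
    print(len(vectors), len(vectors[0]))


asyncio.run(main())
```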

View File

@@ -2,6 +2,7 @@ import logging
import math
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Optional

from pydantic import BaseModel, Field
@@ -24,11 +25,13 @@ class VectorStoreConfig(BaseModel):
    )
    password: Optional[str] = Field(
        default=None,
        description="The password of vector store, if not set, will use the default "
        "password.",
    )
    embedding_fn: Optional[Any] = Field(
        default=None,
        description="The embedding function of vector store, if not set, will use the "
        "default embedding function.",
    )
    max_chunks_once_load: int = Field(
        default=10,
@@ -36,6 +39,11 @@ class VectorStoreConfig(BaseModel):
        "large, you can set this value to a larger number to speed up the loading "
        "process. Default is 10.",
    )
    max_threads: int = Field(
        default=1,
        description="The max number of threads to use. Default is 1. If you set this "
        "bigger than 1, please make sure your vector store is thread-safe.",
    )
class VectorStoreBase(ABC):
@@ -52,12 +60,13 @@ class VectorStoreBase(ABC):
        pass

    def load_document_with_limit(
        self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
    ) -> List[str]:
        """Load document in vector database with limit.

        Args:
            chunks: document chunks.
            max_chunks_once_load: Max number of chunks to load at once.
            max_threads: Max number of threads to use.

        Return:
            List[str]: Chunk ids.
        """
        # Group the chunks into chunks of size max_chunks
@@ -65,13 +74,21 @@ class VectorStoreBase(ABC):
            chunks[i : i + max_chunks_once_load]
            for i in range(0, len(chunks), max_chunks_once_load)
        ]
        logger.info(
            f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
            f"{max_threads} threads."
        )
        ids = []
        loaded_cnt = 0
        start_time = time.time()
        with ThreadPoolExecutor(max_workers=max_threads) as executor:
            tasks = []
            for chunk_group in chunk_groups:
                tasks.append(executor.submit(self.load_document, chunk_group))
            for future in tasks:
                success_ids = future.result()
                ids.extend(success_ids)
                loaded_cnt += len(success_ids)
                logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
        logger.info(
            f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
        )

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, List, Optional

from dbgpt.rag.chunk import Chunk
from dbgpt.storage import vector_store
@@ -65,7 +65,9 @@ class VectorStoreConnector:
        Return chunk ids.
        """
        return self.client.load_document_with_limit(
            chunks,
            self._vector_store_config.max_chunks_once_load,
            self._vector_store_config.max_threads,
        )

    def similar_search(self, doc: str, topk: int) -> List[Chunk]:
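A hedged sketch of wiring the new `max_threads` field through the connector (not part of the commit; it assumes `ChromaVectorConfig` inherits the `VectorStoreConfig` fields and that `from_default` accepts an `embedding_fn`, as the example file below suggests). Only raise `max_threads` above 1 if the underlying vector store client is thread-safe:

```python
from dbgpt.rag.embedding import OpenAPIEmbeddings
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

connector = VectorStoreConnector.from_default(
    "Chroma",
    vector_store_config=ChromaVectorConfig(
        name="demo_vector_store",  # hypothetical store name
        max_chunks_once_load=10,
        max_threads=4,  # chunk groups are submitted to a ThreadPoolExecutor
    ),
    embedding_fn=OpenAPIEmbeddings(),  # defaults to the local APIServer endpoint
)
```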

View File

@@ -10,7 +10,7 @@ The call of multi-model services is compatible with the OpenAI interface, and th

## Start apiserver

After deploying the model service, you need to start the API Server. By default, the model API Server uses port `8100` to start.

```bash
dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY
```
@@ -25,7 +25,7 @@ After the apiserver is started, the service call can be verified. First, let's l

:::tip
List models
:::

```bash
curl http://127.0.0.1:8100/api/v1/models \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json"
```
@@ -34,17 +34,31 @@ curl http://127.0.0.1:8100/api/v1/models \

:::tip
Chat
:::

```bash
curl http://127.0.0.1:8100/api/v1/chat/completions \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json" \
-d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
```

:::tip
Embedding
:::

```bash
curl http://127.0.0.1:8100/api/v1/embeddings \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json" \
-d '{
    "model": "text2vec",
    "input": "Hello world!"
}'
```
## Verify via OpenAI SDK

```python
import openai

openai.api_key = "EMPTY"
openai.api_base = "http://127.0.0.1:8100/api/v1"
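# Continuation sketch, not in the original snippet: with the pre-1.0 openai SDK
# configured above, the new embeddings endpoint can be exercised as well. The
# model name "text2vec" is an assumption matching the deployed embedding worker.
emb = openai.Embedding.create(model="text2vec", input=["Hello world!"])
print(len(emb["data"][0]["embedding"]))
```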

View File

@@ -1,7 +1,7 @@
import asyncio
import os

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.knowledge.factory import KnowledgeFactory
@@ -37,7 +37,7 @@ def _create_vector_connector():

async def main():
    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    knowledge = KnowledgeFactory.from_file_path(file_path)
    vector_connector = _create_vector_connector()
    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")

View File

@@ -0,0 +1,87 @@
"""A RAG example using the OpenAPIEmbeddings.

Example:

    Test with `OpenAI embeddings
    <https://platform.openai.com/docs/api-reference/embeddings/create>`_.

    .. code-block:: shell

        export API_SERVER_BASE_URL=${OPENAI_API_BASE:-"https://api.openai.com/v1"}
        export API_SERVER_API_KEY="${OPENAI_API_KEY}"
        export API_SERVER_EMBEDDINGS_MODEL="text-embedding-ada-002"
        python examples/rag/rag_embedding_api_example.py

    Test with DB-GPT `API Server
    <https://docs.dbgpt.site/docs/installation/advanced_usage/OpenAI_SDK_call#start-apiserver>`_.

    .. code-block:: shell

        export API_SERVER_BASE_URL="http://localhost:8100/api/v1"
        export API_SERVER_API_KEY="your_api_key"
        export API_SERVER_EMBEDDINGS_MODEL="text2vec"
        python examples/rag/rag_embedding_api_example.py
"""
import asyncio
import os
from typing import Optional

from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding import OpenAPIEmbeddings
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector


def _create_embeddings(
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
) -> OpenAPIEmbeddings:
    if not api_url:
        api_server_base_url = os.getenv(
            "API_SERVER_BASE_URL", "http://localhost:8100/api/v1"
        )
        api_url = f"{api_server_base_url}/embeddings"
    if not api_key:
        api_key = os.getenv("API_SERVER_API_KEY")
    if not model_name:
        model_name = os.getenv("API_SERVER_EMBEDDINGS_MODEL", "text2vec")
    return OpenAPIEmbeddings(api_url=api_url, api_key=api_key, model_name=model_name)


def _create_vector_connector():
    """Create vector connector."""
    return VectorStoreConnector.from_default(
        "Chroma",
        vector_store_config=ChromaVectorConfig(
            name="example_embedding_api_vector_store_name",
            persist_path=os.path.join(PILOT_PATH, "data"),
        ),
        embedding_fn=_create_embeddings(),
    )


async def main():
    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    knowledge = KnowledgeFactory.from_file_path(file_path)
    vector_connector = _create_vector_connector()
    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
    # Get an embedding assembler
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=chunk_parameters,
        vector_store_connector=vector_connector,
    )
    assembler.persist()
    # Get an embeddings retriever
    retriever = assembler.as_retriever(3)
    chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
    print(f"embedding rag example results:{chunks}")


if __name__ == "__main__":
    asyncio.run(main())