feat: APIServer supports embeddings (#1256)

commit 74ec8e52cd (parent 5f3ee35804)
Author: Fangyin Cheng
Date: 2024-03-05 20:21:37 +08:00
Committed by: GitHub
9 changed files with 414 additions and 40 deletions

View File

@@ -2,6 +2,7 @@ import logging
 import math
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, List, Optional
 
 from pydantic import BaseModel, Field
@@ -24,11 +25,13 @@ class VectorStoreConfig(BaseModel):
     )
     password: Optional[str] = Field(
         default=None,
-        description="The password of vector store, if not set, will use the default password.",
+        description="The password of vector store, if not set, will use the default "
+        "password.",
     )
     embedding_fn: Optional[Any] = Field(
         default=None,
-        description="The embedding function of vector store, if not set, will use the default embedding function.",
+        description="The embedding function of vector store, if not set, will use the "
+        "default embedding function.",
     )
     max_chunks_once_load: int = Field(
         default=10,
@@ -36,6 +39,11 @@ class VectorStoreConfig(BaseModel):
         "large, you can set this value to a larger number to speed up the loading "
         "process. Default is 10.",
     )
+    max_threads: int = Field(
+        default=1,
+        description="The max number of threads to use. Default is 1. If you set this "
+        "bigger than 1, please make sure your vector store is thread-safe.",
+    )
 
 
 class VectorStoreBase(ABC):
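
For context, the new field is opt-in from the caller's side: multi-threaded loading only kicks in when a config sets max_threads above 1. A minimal sketch of constructing such a config (the import path below is an assumption; only the class body appears in this diff):

    # Minimal sketch, assuming VectorStoreConfig is exposed at this module path
    # (the path is an assumption; the commit only shows the class body).
    from dbgpt.storage.vector_store.base import VectorStoreConfig

    config = VectorStoreConfig(
        max_chunks_once_load=20,  # chunks handed to each load_document() call
        max_threads=4,  # >1 requires a thread-safe vector store backend
    )

Since VectorStoreConfig is a pydantic BaseModel, plain keyword construction like this validates the field types and falls back to the documented defaults (10 and 1) for anything omitted.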
@@ -52,12 +60,13 @@ class VectorStoreBase(ABC):
         pass
 
     def load_document_with_limit(
-        self, chunks: List[Chunk], max_chunks_once_load: int = 10
+        self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
     ) -> List[str]:
         """load document in vector database with limit.
 
         Args:
             chunks: document chunks.
             max_chunks_once_load: Max number of chunks to load at once.
+            max_threads: Max number of threads to use.
         Return:
         """
         # Group the chunks into chunks of size max_chunks
@@ -65,14 +74,22 @@ class VectorStoreBase(ABC):
             chunks[i : i + max_chunks_once_load]
             for i in range(0, len(chunks), max_chunks_once_load)
         ]
-        logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
+        logger.info(
+            f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
+            f"{max_threads} threads."
+        )
         ids = []
         loaded_cnt = 0
         start_time = time.time()
-        for chunk_group in chunk_groups:
-            ids.extend(self.load_document(chunk_group))
-            loaded_cnt += len(chunk_group)
-            logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
+        with ThreadPoolExecutor(max_workers=max_threads) as executor:
+            tasks = []
+            for chunk_group in chunk_groups:
+                tasks.append(executor.submit(self.load_document, chunk_group))
+            for future in tasks:
+                success_ids = future.result()
+                ids.extend(success_ids)
+                loaded_cnt += len(success_ids)
+                logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
         logger.info(
             f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
         )
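
A note on the pattern: all groups are submitted first, then the futures are joined in submission order, so the returned ids stay aligned with the input chunk order, and future.result() re-raises any exception thrown inside load_document. A standalone sketch of the same submit-then-join pattern (load_group is a hypothetical stand-in for self.load_document):

    # Standalone sketch of the submit-then-join pattern used above.
    from concurrent.futures import ThreadPoolExecutor
    from typing import List

    def load_group(group: List[str]) -> List[str]:
        # Stand-in for self.load_document: returns one id per chunk in the group.
        return [f"id-{item}" for item in group]

    groups = [["a", "b"], ["c", "d"], ["e"]]
    ids: List[str] = []
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(load_group, g) for g in groups]
        for future in futures:  # submission order, not completion order
            ids.extend(future.result())  # re-raises any worker exception
    print(ids)  # ['id-a', 'id-b', 'id-c', 'id-d', 'id-e'] — order preserved

Iterating the futures in submission order trades a little latency for deterministic ordering; concurrent.futures.as_completed would report progress sooner but would interleave the ids.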

View File

@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, List, Optional
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Chunk
 from dbgpt.storage import vector_store
@@ -65,7 +65,9 @@ class VectorStoreConnector:
         Return chunk ids.
         """
         return self.client.load_document_with_limit(
-            chunks, self._vector_store_config.max_chunks_once_load
+            chunks,
+            self._vector_store_config.max_chunks_once_load,
+            self._vector_store_config.max_threads,
         )
 
     def similar_search(self, doc: str, topk: int) -> List[Chunk]:
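
End to end, callers still invoke only the connector's load_document; threading is driven entirely by the config. A hypothetical usage sketch (the VectorStoreConnector constructor arguments below are assumptions; only load_document and the config fields appear in this commit):

    # Hypothetical end-to-end sketch: constructor arguments are assumptions,
    # not confirmed by this diff.
    config = VectorStoreConfig(max_chunks_once_load=10, max_threads=4)
    connector = VectorStoreConnector("Chroma", vector_store_config=config)
    chunk_ids = connector.load_document(chunks)  # ids collected from all groups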