mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 22:19:28 +00:00
feat: APIServer supports embeddings (#1256)
This commit is contained in:
@@ -2,6 +2,7 @@ import logging
|
||||
import math
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -24,11 +25,13 @@ class VectorStoreConfig(BaseModel):
|
||||
)
|
||||
password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password of vector store, if not set, will use the default password.",
|
||||
description="The password of vector store, if not set, will use the default "
|
||||
"password.",
|
||||
)
|
||||
embedding_fn: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="The embedding function of vector store, if not set, will use the default embedding function.",
|
||||
description="The embedding function of vector store, if not set, will use the "
|
||||
"default embedding function.",
|
||||
)
|
||||
max_chunks_once_load: int = Field(
|
||||
default=10,
|
||||
@@ -36,6 +39,11 @@ class VectorStoreConfig(BaseModel):
|
||||
"large, you can set this value to a larger number to speed up the loading "
|
||||
"process. Default is 10.",
|
||||
)
|
||||
max_threads: int = Field(
|
||||
default=1,
|
||||
description="The max number of threads to use. Default is 1. If you set this "
|
||||
"bigger than 1, please make sure your vector store is thread-safe.",
|
||||
)
|
||||
|
||||
|
||||
class VectorStoreBase(ABC):
|
||||
@@ -52,12 +60,13 @@ class VectorStoreBase(ABC):
|
||||
pass
|
||||
|
||||
def load_document_with_limit(
|
||||
self, chunks: List[Chunk], max_chunks_once_load: int = 10
|
||||
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
|
||||
) -> List[str]:
|
||||
"""load document in vector database with limit.
|
||||
Args:
|
||||
chunks: document chunks.
|
||||
max_chunks_once_load: Max number of chunks to load at once.
|
||||
max_threads: Max number of threads to use.
|
||||
Return:
|
||||
"""
|
||||
# Group the chunks into chunks of size max_chunks
|
||||
@@ -65,14 +74,22 @@ class VectorStoreBase(ABC):
|
||||
chunks[i : i + max_chunks_once_load]
|
||||
for i in range(0, len(chunks), max_chunks_once_load)
|
||||
]
|
||||
logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
|
||||
logger.info(
|
||||
f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
|
||||
f"{max_threads} threads."
|
||||
)
|
||||
ids = []
|
||||
loaded_cnt = 0
|
||||
start_time = time.time()
|
||||
for chunk_group in chunk_groups:
|
||||
ids.extend(self.load_document(chunk_group))
|
||||
loaded_cnt += len(chunk_group)
|
||||
logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
|
||||
with ThreadPoolExecutor(max_workers=max_threads) as executor:
|
||||
tasks = []
|
||||
for chunk_group in chunk_groups:
|
||||
tasks.append(executor.submit(self.load_document, chunk_group))
|
||||
for future in tasks:
|
||||
success_ids = future.result()
|
||||
ids.extend(success_ids)
|
||||
loaded_cnt += len(success_ids)
|
||||
logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
|
||||
logger.info(
|
||||
f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
|
||||
)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage import vector_store
|
||||
@@ -65,7 +65,9 @@ class VectorStoreConnector:
|
||||
Return chunk ids.
|
||||
"""
|
||||
return self.client.load_document_with_limit(
|
||||
chunks, self._vector_store_config.max_chunks_once_load
|
||||
chunks,
|
||||
self._vector_store_config.max_chunks_once_load,
|
||||
self._vector_store_config.max_threads,
|
||||
)
|
||||
|
||||
def similar_search(self, doc: str, topk: int) -> List[Chunk]:
|
||||
|
Reference in New Issue
Block a user