perf(rag): Support load large document (#1233)

@@ -75,6 +75,10 @@ EMBEDDING_MODEL=text2vec
 #EMBEDDING_MODEL=bge-large-zh
 KNOWLEDGE_CHUNK_SIZE=500
 KNOWLEDGE_SEARCH_TOP_SIZE=5
+## Maximum number of chunks to load at once, if your single document is too large,
+## you can set this value to a higher value for better performance.
+## if out of memory when load large document, you can set this value to a lower value.
+# KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD=10
 #KNOWLEDGE_CHUNK_OVERLAP=50
 # Control whether to display the source document of knowledge on the front end.
 KNOWLEDGE_CHAT_SHOW_RELATIONS=False

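Note: the new KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD setting caps how many chunks are written to the vector store per batch during document sync. Raising it can speed up ingestion of a large document on a machine with enough memory; lowering it helps if ingestion runs out of memory. An illustrative override (the value is an example, not a recommendation):

KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD=100
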
@@ -233,6 +233,9 @@ class Config(metaclass=Singleton):
         self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
         self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
         self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
+        self.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD = int(
+            os.getenv("KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD", 10)
+        )
         # default recall similarity score, between 0 and 1
         self.KNOWLEDGE_SEARCH_RECALL_SCORE = float(
             os.getenv("KNOWLEDGE_SEARCH_RECALL_SCORE", 0.3)

@@ -43,6 +43,7 @@ from dbgpt.serve.rag.assembler.summary import SummaryAssembler
 from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
+from dbgpt.util.tracer import root_tracer, trace
 
 knowledge_space_dao = KnowledgeSpaceDao()
 knowledge_document_dao = KnowledgeDocumentDao()

@@ -335,7 +336,11 @@ class KnowledgeService:
         )
         from dbgpt.storage.vector_store.base import VectorStoreConfig
 
-        config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
+        config = VectorStoreConfig(
+            name=space_name,
+            embedding_fn=embedding_fn,
+            max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
+        )
         vector_store_connector = VectorStoreConnector(
             vector_store_type=CFG.VECTOR_STORE_TYPE,
             vector_store_config=config,

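Note: this is where the setting takes effect for knowledge-space sync: the value read from the environment in Config flows into VectorStoreConfig.max_chunks_once_load and from there into the connector. A minimal sketch of the same wiring, assuming the usual Config import and leaving the embedding function out (the space name is hypothetical):

from dbgpt._private.config import Config
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

CFG = Config()

config = VectorStoreConfig(
    name="my_space",  # hypothetical knowledge space name
    embedding_fn=None,  # normally the configured embedding function
    max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
)
connector = VectorStoreConnector(
    vector_store_type=CFG.VECTOR_STORE_TYPE,
    vector_store_config=config,
)
# connector.load_document(chunks) now batches writes through
# load_document_with_limit() instead of handing every chunk to the store at once.
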
@@ -499,6 +504,7 @@ class KnowledgeService:
         res.page = request.page
         return res
 
+    @trace("async_doc_embedding")
     def async_doc_embedding(self, assembler, chunk_docs, doc):
         """async document embedding into vector db
         Args:

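Note: the commit instruments the sync path with both the @trace decorator and explicit root_tracer.start_span spans. A minimal sketch of the pattern (the function and span names below are made up for illustration):

from dbgpt.util.tracer import root_tracer, trace


@trace("my_component.do_work")
def do_work(items):
    # Spans opened inside the call show up as children of the decorated span.
    with root_tracer.start_span(
        "my_component.do_work.sort", metadata={"size": len(items)}
    ):
        return sorted(items)
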
@@ -511,6 +517,10 @@ class KnowledgeService:
             f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
         )
         try:
+            with root_tracer.start_span(
+                "app.knowledge.assembler.persist",
+                metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)},
+            ):
                 vector_ids = assembler.persist()
             doc.status = SyncStatus.FINISHED.name
             doc.result = "document embedding success"

@@ -11,7 +11,6 @@ from dbgpt.app.knowledge.document_db import (
 )
 from dbgpt.app.knowledge.service import KnowledgeService
 from dbgpt.app.scene import BaseChat, ChatScene
-from dbgpt.component import ComponentType
 from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
 from dbgpt.core import (
     ChatPromptTemplate,

@@ -19,10 +18,8 @@ from dbgpt.core import (
     MessagesPlaceholder,
     SystemPromptTemplate,
 )
-from dbgpt.model import DefaultLLMClient
-from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.rag.retriever.rewrite import QueryRewrite
-from dbgpt.util.tracer import trace
+from dbgpt.util.tracer import root_tracer, trace
 
 CFG = Config()
 

@@ -226,6 +223,9 @@ class ChatKnowledge(BaseChat):
 
     async def execute_similar_search(self, query):
         """execute similarity search"""
+        with root_tracer.start_span(
+            "execute_similar_search", metadata={"query": query}
+        ):
             return await self.embedding_retriever.aretrieve_with_scores(
                 query, self.recall_score
             )

@@ -167,6 +167,8 @@ EMBEDDING_MODEL_CONFIG = {
     # https://huggingface.co/BAAI/bge-large-zh
     "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
     "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
+    "gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
+    "gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "proxy_openai": "proxy_openai",
     "proxy_azure": "proxy_azure",

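Note: with these entries registered, a GTE embedding model can be selected like any other local model, assuming the weights have been downloaded into the matching directory under MODEL_PATH (for example models/gte-large-zh):

EMBEDDING_MODEL=gte-large-zh
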
@@ -1,12 +1,14 @@
 try:
-    from dbgpt.model.cluster.client import DefaultLLMClient
+    from dbgpt.model.cluster.client import DefaultLLMClient, RemoteLLMClient
 except ImportError as exc:
-    # logging.warning("Can't import dbgpt.model.DefaultLLMClient")
     DefaultLLMClient = None
+    RemoteLLMClient = None
 
 
 _exports = []
 if DefaultLLMClient:
     _exports.append("DefaultLLMClient")
+if RemoteLLMClient:
+    _exports.append("RemoteLLMClient")
 
 __ALL__ = _exports

@@ -104,3 +104,60 @@ class DefaultLLMClient(LLMClient):
 
     async def count_token(self, model: str, prompt: str) -> int:
         return await self.worker_manager.count_token({"model": model, "prompt": prompt})
+
+
+@register_resource(
+    label="Remote LLM Client",
+    name="remote_llm_client",
+    category=ResourceCategory.LLM_CLIENT,
+    description="Remote LLM client(Connect to the remote DB-GPT model serving)",
+    parameters=[
+        Parameter.build_from(
+            "Controller Address",
+            name="controller_address",
+            type=str,
+            optional=True,
+            default="http://127.0.0.1:8000",
+            description="Model controller address",
+        ),
+        Parameter.build_from(
+            "Auto Convert Message",
+            name="auto_convert_message",
+            type=bool,
+            optional=True,
+            default=False,
+            description="Whether to auto convert the messages that are not supported "
+            "by the LLM to a compatible format",
+        ),
+    ],
+)
+class RemoteLLMClient(DefaultLLMClient):
+    """Remote LLM client implementation.
+
+    Connect to the remote worker manager and send the request to the remote worker manager.
+
+    Args:
+        controller_address (str): model controller address
+        auto_convert_message (bool, optional): auto convert the message to
+            ModelRequest. Defaults to False.
+
+    If you start DB-GPT model cluster, the controller address is the address of the
+    Model Controller(`dbgpt start controller`, the default port of model controller
+    is 8000).
+    Otherwise, if you already have a running DB-GPT server(start it by
+    `dbgpt start webserver --port ${remote_port}`), you can use the address of the
+    `http://${remote_ip}:${remote_port}`.
+
+    """
+
+    def __init__(
+        self,
+        controller_address: str = "http://127.0.0.1:8000",
+        auto_convert_message: bool = False,
+    ):
+        """Initialize the RemoteLLMClient."""
+        from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager
+
+        model_registry_client = ModelRegistryClient(controller_address)
+        worker_manager = RemoteWorkerManager(model_registry_client)
+        super().__init__(worker_manager, auto_convert_message)

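Note: a minimal usage sketch for the new client, assuming a DB-GPT model controller is already running on the default port; the model name is a placeholder for whatever your cluster serves:

import asyncio

from dbgpt.model import RemoteLLMClient


async def main():
    client = RemoteLLMClient(
        controller_address="http://127.0.0.1:8000",
        auto_convert_message=True,
    )
    # count_token() is part of the DefaultLLMClient interface shown above.
    print(await client.count_token("vicuna-13b-v1.5", "Hello, DB-GPT!"))


asyncio.run(main())
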
@@ -7,6 +7,7 @@ from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
 from dbgpt.rag.retriever.rewrite import QueryRewrite
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 from dbgpt.util.chat_util import run_async_tasks
+from dbgpt.util.tracer import root_tracer
 
 
 class EmbeddingRetriever(BaseRetriever):

@@ -129,13 +130,26 @@ class EmbeddingRetriever(BaseRetriever):
         """
         queries = [query]
         if self._query_rewrite:
+            with root_tracer.start_span(
+                "EmbeddingRetriever.query_rewrite.similarity_search",
+                metadata={"query": query, "score_threshold": score_threshold},
+            ):
                 candidates_tasks = [self._similarity_search(query) for query in queries]
                 chunks = await self._run_async_tasks(candidates_tasks)
                 context = "\n".join([chunk.content for chunk in chunks])
+            with root_tracer.start_span(
+                "EmbeddingRetriever.query_rewrite.rewrite",
+                metadata={"query": query, "context": context, "nums": 1},
+            ):
                 new_queries = await self._query_rewrite.rewrite(
                     origin_query=query, context=context, nums=1
                 )
                 queries.extend(new_queries)
+
+        with root_tracer.start_span(
+            "EmbeddingRetriever.similarity_search_with_score",
+            metadata={"query": query, "score_threshold": score_threshold},
+        ):
             candidates_with_score = [
                 self._similarity_search_with_score(query, score_threshold)
                 for query in queries

@@ -144,6 +158,15 @@ class EmbeddingRetriever(BaseRetriever):
                 tasks=candidates_with_score, concurrency_limit=1
             )
             candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
+
+        with root_tracer.start_span(
+            "EmbeddingRetriever.rerank",
+            metadata={
+                "query": query,
+                "score_threshold": score_threshold,
+                "rerank_cls": self._rerank.__class__.__name__,
+            },
+        ):
             candidates_with_score = self._rerank.rank(candidates_with_score)
             return candidates_with_score
 

@@ -1,6 +1,5 @@
 import copy
 import logging
-import re
 from abc import ABC, abstractmethod
 from typing import (
     Any,

@@ -66,10 +65,14 @@ class TextSplitter(ABC):
             chunks.append(new_doc)
         return chunks
 
-    def split_documents(self, documents: List[Document], **kwargs) -> List[Chunk]:
+    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
         """Split documents."""
-        texts = [doc.content for doc in documents]
-        metadatas = [doc.metadata for doc in documents]
+        texts = []
+        metadatas = []
+        for doc in documents:
+            # Iterable just supports one iteration
+            texts.append(doc.content)
+            metadatas.append(doc.metadata)
         return self.create_documents(texts, metadatas, **kwargs)
 
     def _join_docs(self, docs: List[str], separator: str, **kwargs) -> Optional[str]:

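Note: the signature change matters because an Iterable (for example a generator of loaded documents) can only be walked once, so the two original list comprehensions would leave metadatas empty. A plain-Python illustration of the pitfall the single loop avoids (not DB-GPT code):

from types import SimpleNamespace


def document_stream():
    for i in range(3):
        yield SimpleNamespace(content=f"chunk {i}", metadata={"idx": i})


docs = document_stream()
texts = [d.content for d in docs]  # consumes the generator
metadatas = [d.metadata for d in docs]  # already exhausted -> []
print(len(texts), len(metadatas))  # 3 0

Collecting content and metadata in one pass keeps the two lists aligned even for one-shot iterables.
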
@@ -6,6 +6,7 @@ from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
 from dbgpt.rag.extractor.base import Extractor
 from dbgpt.rag.knowledge.base import Knowledge
 from dbgpt.rag.retriever.base import BaseRetriever
+from dbgpt.util.tracer import root_tracer, trace
 
 
 class BaseAssembler(ABC):

@@ -30,11 +31,24 @@ class BaseAssembler(ABC):
             knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
         )
         self._chunks = None
+        metadata = {
+            "knowledge_cls": self._knowledge.__class__.__name__
+            if self._knowledge
+            else None,
+            "knowledge_type": self._knowledge.type().value if self._knowledge else None,
+            "path": self._knowledge._path
+            if self._knowledge and hasattr(self._knowledge, "_path")
+            else None,
+            "chunk_parameters": self._chunk_parameters.dict(),
+        }
+        with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
             self.load_knowledge(self._knowledge)
 
     def load_knowledge(self, knowledge) -> None:
         """Load knowledge Pipeline."""
+        with root_tracer.start_span("BaseAssembler.knowledge.load"):
             documents = knowledge.load()
+        with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
             self._chunks = self._chunk_manager.split(documents)
 
     @abstractmethod

@@ -1,4 +1,6 @@
+import logging
 import math
+import time
 from abc import ABC, abstractmethod
 from typing import Any, Callable, List, Optional
 

@@ -6,6 +8,8 @@ from pydantic import BaseModel, Field
 
 from dbgpt.rag.chunk import Chunk
 
+logger = logging.getLogger(__name__)
+
 
 class VectorStoreConfig(BaseModel):
     """Vector store config."""

@@ -26,6 +30,12 @@ class VectorStoreConfig(BaseModel):
         default=None,
         description="The embedding function of vector store, if not set, will use the default embedding function.",
     )
+    max_chunks_once_load: int = Field(
+        default=10,
+        description="The max number of chunks to load at once. If your document is "
+        "large, you can set this value to a larger number to speed up the loading "
+        "process. Default is 10.",
+    )
 
 
 class VectorStoreBase(ABC):

@@ -41,6 +51,33 @@ class VectorStoreBase(ABC):
         """
         pass
+
+    def load_document_with_limit(
+        self, chunks: List[Chunk], max_chunks_once_load: int = 10
+    ) -> List[str]:
+        """load document in vector database with limit.
+        Args:
+            chunks: document chunks.
+            max_chunks_once_load: Max number of chunks to load at once.
+        Return:
+        """
+        # Group the chunks into chunks of size max_chunks
+        chunk_groups = [
+            chunks[i : i + max_chunks_once_load]
+            for i in range(0, len(chunks), max_chunks_once_load)
+        ]
+        logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
+        ids = []
+        loaded_cnt = 0
+        start_time = time.time()
+        for chunk_group in chunk_groups:
+            ids.extend(self.load_document(chunk_group))
+            loaded_cnt += len(chunk_group)
+            logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
+        logger.info(
+            f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
+        )
+        return ids
 
     @abstractmethod
     def similar_search(self, text, topk) -> List[Chunk]:
         """similar search in vector database.

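Note: the grouping is a simple fixed-size slicing of the chunk list, so the number of load_document() calls is ceil(len(chunks) / max_chunks_once_load). The same logic in isolation, with illustrative numbers:

chunks = list(range(2503))  # stand-ins for Chunk objects
max_chunks_once_load = 10
groups = [
    chunks[i : i + max_chunks_once_load]
    for i in range(0, len(chunks), max_chunks_once_load)
]
print(len(groups))  # 251
print(len(groups[-1]))  # 3 (the final, partial batch)
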
@@ -64,7 +64,9 @@ class VectorStoreConnector:
             - chunks: document chunks.
         Return chunk ids.
         """
-        return self.client.load_document(chunks)
+        return self.client.load_document_with_limit(
+            chunks, self._vector_store_config.max_chunks_once_load
+        )
 
     def similar_search(self, doc: str, topk: int) -> List[Chunk]:
         """similar search in vector database.

@@ -1,11 +1,12 @@
 from __future__ import annotations
 
+import json
 import uuid
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 from dbgpt.component import BaseComponent, ComponentType, SystemApp
 

@@ -95,7 +96,7 @@ class Span:
             "end_time": None
             if not self.end_time
             else self.end_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
-            "metadata": self.metadata,
+            "metadata": _clean_for_json(self.metadata),
         }
 
 

@@ -187,3 +188,39 @@ class Tracer(BaseComponent, ABC):
 @dataclass
 class TracerContext:
     span_id: Optional[str] = None
+
+
+def _clean_for_json(data: Optional[str, Any] = None):
+    if not data:
+        return None
+    if isinstance(data, dict):
+        cleaned_dict = {}
+        for key, value in data.items():
+            # Try to clean the sub-items
+            cleaned_value = _clean_for_json(value)
+            if cleaned_value is not None:
+                # Only add to the cleaned dict if it's not None
+                try:
+                    json.dumps({key: cleaned_value})
+                    cleaned_dict[key] = cleaned_value
+                except TypeError:
+                    # Skip this key-value pair if it can't be serialized
+                    pass
+        return cleaned_dict
+    elif isinstance(data, list):
+        cleaned_list = []
+        for item in data:
+            cleaned_item = _clean_for_json(item)
+            if cleaned_item is not None:
+                try:
+                    json.dumps(cleaned_item)
+                    cleaned_list.append(cleaned_item)
+                except TypeError:
+                    pass
+        return cleaned_list
+    else:
+        try:
+            json.dumps(data)
+            return data
+        except TypeError:
+            return None

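Note: span metadata may now carry values that json.dumps cannot handle, so _clean_for_json() keeps only what serializes and silently drops the rest (falsy values such as 0 or "" are dropped too, because of the `if not data` guard). An illustrative call:

metadata = {
    "doc": "report.pdf",
    "chunks": 42,
    "embedding_fn": lambda text: [0.0],  # not JSON-serializable
}
print(_clean_for_json(metadata))
# {'doc': 'report.pdf', 'chunks': 42}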