perf(rag): Support load large document (#1233)
@@ -75,6 +75,10 @@ EMBEDDING_MODEL=text2vec
 #EMBEDDING_MODEL=bge-large-zh
 KNOWLEDGE_CHUNK_SIZE=500
 KNOWLEDGE_SEARCH_TOP_SIZE=5
+## Maximum number of chunks to load at once. If your single document is very large,
+## you can set this to a higher value for better performance.
+## If you run out of memory when loading a large document, set this to a lower value.
+# KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD=10
 #KNOWLEDGE_CHUNK_OVERLAP=50
 # Control whether to display the source document of knowledge on the front end.
 KNOWLEDGE_CHAT_SHOW_RELATIONS=False
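For example, on a machine with plenty of memory you might uncomment the new setting and raise it; the value 50 below is only illustrative, not a recommendation from the commit:

## Example: allow up to 50 chunks per batch when memory is not a concern
KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD=50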
@@ -233,6 +233,9 @@ class Config(metaclass=Singleton):
         self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
         self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
         self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
+        self.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD = int(
+            os.getenv("KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD", 10)
+        )
         # default recall similarity score, between 0 and 1
         self.KNOWLEDGE_SEARCH_RECALL_SCORE = float(
             os.getenv("KNOWLEDGE_SEARCH_RECALL_SCORE", 0.3)
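A minimal sketch of the environment-driven default used above, standalone and stdlib-only (not part of the commit), mirroring how Config falls back to 10 when the variable is unset:

import os

# Same pattern as Config: read the variable, fall back to 10 chunks per batch.
max_chunks_once_load = int(os.getenv("KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD", 10))
print(max_chunks_once_load)  # 10 unless KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD is exported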
@@ -43,6 +43,7 @@ from dbgpt.serve.rag.assembler.summary import SummaryAssembler
 from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
+from dbgpt.util.tracer import root_tracer, trace

 knowledge_space_dao = KnowledgeSpaceDao()
 knowledge_document_dao = KnowledgeDocumentDao()
@@ -335,7 +336,11 @@ class KnowledgeService:
         )
         from dbgpt.storage.vector_store.base import VectorStoreConfig

-        config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
+        config = VectorStoreConfig(
+            name=space_name,
+            embedding_fn=embedding_fn,
+            max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
+        )
         vector_store_connector = VectorStoreConnector(
             vector_store_type=CFG.VECTOR_STORE_TYPE,
             vector_store_config=config,
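As a rough usage sketch of how the new field flows from the config into the connector: the space name and the "Chroma" store type below are placeholders, and the embedding function is omitted so the default one is used, per the field's description.

from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

# max_chunks_once_load caps how many chunks are written to the store per batch.
config = VectorStoreConfig(
    name="my_space",                 # placeholder space name
    max_chunks_once_load=20,         # raise for speed, lower if memory is tight
)
connector = VectorStoreConnector(
    vector_store_type="Chroma",      # example store type
    vector_store_config=config,
)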
@@ -499,6 +504,7 @@ class KnowledgeService:
         res.page = request.page
         return res

+    @trace("async_doc_embedding")
     def async_doc_embedding(self, assembler, chunk_docs, doc):
         """async document embedding into vector db
         Args:
@@ -511,7 +517,11 @@
             f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
         )
         try:
-            vector_ids = assembler.persist()
+            with root_tracer.start_span(
+                "app.knowledge.assembler.persist",
+                metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)},
+            ):
+                vector_ids = assembler.persist()
             doc.status = SyncStatus.FINISHED.name
             doc.result = "document embedding success"
             if vector_ids is not None:
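The `root_tracer.start_span(...)` context manager used above records a named span with optional metadata around the wrapped block. A minimal sketch of the same pattern; the helper name and metadata values below are illustrative, not from the commit:

from dbgpt.util.tracer import root_tracer

def persist_with_tracing(assembler, doc_name: str, chunk_count: int):
    # The span is closed automatically when the block exits, even on error.
    with root_tracer.start_span(
        "app.knowledge.assembler.persist",
        metadata={"doc": doc_name, "chunks": chunk_count},
    ):
        return assembler.persist()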
@@ -11,7 +11,6 @@ from dbgpt.app.knowledge.document_db import (
 )
 from dbgpt.app.knowledge.service import KnowledgeService
 from dbgpt.app.scene import BaseChat, ChatScene
-from dbgpt.component import ComponentType
 from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
 from dbgpt.core import (
     ChatPromptTemplate,
@@ -19,10 +18,8 @@ from dbgpt.core import (
     MessagesPlaceholder,
     SystemPromptTemplate,
 )
-from dbgpt.model import DefaultLLMClient
-from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.rag.retriever.rewrite import QueryRewrite
-from dbgpt.util.tracer import trace
+from dbgpt.util.tracer import root_tracer, trace

 CFG = Config()

@@ -226,6 +223,9 @@ class ChatKnowledge(BaseChat):

     async def execute_similar_search(self, query):
         """execute similarity search"""
-        return await self.embedding_retriever.aretrieve_with_scores(
-            query, self.recall_score
-        )
+        with root_tracer.start_span(
+            "execute_similar_search", metadata={"query": query}
+        ):
+            return await self.embedding_retriever.aretrieve_with_scores(
+                query, self.recall_score
+            )
@@ -167,6 +167,8 @@ EMBEDDING_MODEL_CONFIG = {
     # https://huggingface.co/BAAI/bge-large-zh
     "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
    "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
+    "gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
+    "gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "proxy_openai": "proxy_openai",
     "proxy_azure": "proxy_azure",
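Entries in this mapping resolve an embedding model name to a local model path (or a proxy identifier). A small lookup sketch; the hard-coded model name is only for illustration and assumes the model has been downloaded under MODEL_PATH:

from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG

# Resolve the on-disk path (or proxy id) for the newly added gte-large-zh entry.
model_name = "gte-large-zh"
model_path = EMBEDDING_MODEL_CONFIG[model_name]
print(model_name, "->", model_path)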
@@ -1,12 +1,14 @@
 try:
-    from dbgpt.model.cluster.client import DefaultLLMClient
+    from dbgpt.model.cluster.client import DefaultLLMClient, RemoteLLMClient
 except ImportError as exc:
     # logging.warning("Can't import dbgpt.model.DefaultLLMClient")
     DefaultLLMClient = None
+    RemoteLLMClient = None


 _exports = []
 if DefaultLLMClient:
     _exports.append("DefaultLLMClient")
+if RemoteLLMClient:
+    _exports.append("RemoteLLMClient")

 __ALL__ = _exports
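The pattern here is optional exports: a symbol is only re-exported when its import succeeds. A generic sketch of the same idea for a hypothetical module; the names `fast_json` and `FastEncoder` are made up:

try:
    from fast_json import FastEncoder  # optional dependency; may be absent
except ImportError:
    FastEncoder = None

_exports = []
if FastEncoder:
    _exports.append("FastEncoder")

# Consumers of `from module import *` only see what imported cleanly.
__all__ = _exports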
@@ -104,3 +104,60 @@ class DefaultLLMClient(LLMClient):

     async def count_token(self, model: str, prompt: str) -> int:
         return await self.worker_manager.count_token({"model": model, "prompt": prompt})
+
+
+@register_resource(
+    label="Remote LLM Client",
+    name="remote_llm_client",
+    category=ResourceCategory.LLM_CLIENT,
+    description="Remote LLM client(Connect to the remote DB-GPT model serving)",
+    parameters=[
+        Parameter.build_from(
+            "Controller Address",
+            name="controller_address",
+            type=str,
+            optional=True,
+            default="http://127.0.0.1:8000",
+            description="Model controller address",
+        ),
+        Parameter.build_from(
+            "Auto Convert Message",
+            name="auto_convert_message",
+            type=bool,
+            optional=True,
+            default=False,
+            description="Whether to auto convert the messages that are not supported "
+            "by the LLM to a compatible format",
+        ),
+    ],
+)
+class RemoteLLMClient(DefaultLLMClient):
+    """Remote LLM client implementation.
+
+    Connect to the remote worker manager and send requests to it.
+
+    Args:
+        controller_address (str): model controller address
+        auto_convert_message (bool, optional): auto convert the message to
+            ModelRequest. Defaults to False.
+
+    If you start a DB-GPT model cluster, the controller address is the address of the
+    Model Controller (`dbgpt start controller`; the default port of the model controller
+    is 8000).
+    Otherwise, if you already have a running DB-GPT server (started with
+    `dbgpt start webserver --port ${remote_port}`), you can use the address
+    `http://${remote_ip}:${remote_port}`.
+    """
+
+    def __init__(
+        self,
+        controller_address: str = "http://127.0.0.1:8000",
+        auto_convert_message: bool = False,
+    ):
+        """Initialize the RemoteLLMClient."""
+        from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager
+
+        model_registry_client = ModelRegistryClient(controller_address)
+        worker_manager = RemoteWorkerManager(model_registry_client)
+        super().__init__(worker_manager, auto_convert_message)
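A hedged usage sketch of the new client: it assumes a model cluster is reachable at the default controller address and that a model named "vicuna-13b-v1.5" is deployed; both are illustrative. `count_token` is inherited from DefaultLLMClient, as shown in the hunk above.

import asyncio

from dbgpt.model import RemoteLLMClient

async def main():
    # Connects to the remote worker manager via the model controller.
    client = RemoteLLMClient(controller_address="http://127.0.0.1:8000")
    tokens = await client.count_token("vicuna-13b-v1.5", "Hello, DB-GPT!")
    print(tokens)

asyncio.run(main())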
@@ -7,6 +7,7 @@ from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
 from dbgpt.rag.retriever.rewrite import QueryRewrite
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 from dbgpt.util.chat_util import run_async_tasks
+from dbgpt.util.tracer import root_tracer


 class EmbeddingRetriever(BaseRetriever):
@@ -129,23 +130,45 @@ class EmbeddingRetriever(BaseRetriever):
         """
         queries = [query]
         if self._query_rewrite:
-            candidates_tasks = [self._similarity_search(query) for query in queries]
-            chunks = await self._run_async_tasks(candidates_tasks)
-            context = "\n".join([chunk.content for chunk in chunks])
-            new_queries = await self._query_rewrite.rewrite(
-                origin_query=query, context=context, nums=1
-            )
-            queries.extend(new_queries)
-        candidates_with_score = [
-            self._similarity_search_with_score(query, score_threshold)
-            for query in queries
-        ]
-        candidates_with_score = await run_async_tasks(
-            tasks=candidates_with_score, concurrency_limit=1
-        )
-        candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
-        candidates_with_score = self._rerank.rank(candidates_with_score)
-        return candidates_with_score
+            with root_tracer.start_span(
+                "EmbeddingRetriever.query_rewrite.similarity_search",
+                metadata={"query": query, "score_threshold": score_threshold},
+            ):
+                candidates_tasks = [self._similarity_search(query) for query in queries]
+                chunks = await self._run_async_tasks(candidates_tasks)
+                context = "\n".join([chunk.content for chunk in chunks])
+            with root_tracer.start_span(
+                "EmbeddingRetriever.query_rewrite.rewrite",
+                metadata={"query": query, "context": context, "nums": 1},
+            ):
+                new_queries = await self._query_rewrite.rewrite(
+                    origin_query=query, context=context, nums=1
+                )
+            queries.extend(new_queries)
+
+        with root_tracer.start_span(
+            "EmbeddingRetriever.similarity_search_with_score",
+            metadata={"query": query, "score_threshold": score_threshold},
+        ):
+            candidates_with_score = [
+                self._similarity_search_with_score(query, score_threshold)
+                for query in queries
+            ]
+            candidates_with_score = await run_async_tasks(
+                tasks=candidates_with_score, concurrency_limit=1
+            )
+            candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
+
+        with root_tracer.start_span(
+            "EmbeddingRetriever.rerank",
+            metadata={
+                "query": query,
+                "score_threshold": score_threshold,
+                "rerank_cls": self._rerank.__class__.__name__,
+            },
+        ):
+            candidates_with_score = self._rerank.rank(candidates_with_score)
+            return candidates_with_score

     async def _similarity_search(self, query) -> List[Chunk]:
         """Similar search."""
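The batched similarity searches above run through `run_async_tasks(tasks=..., concurrency_limit=1)`. A generic, self-contained sketch of concurrency-limited task execution using a semaphore; this illustrates the idea, not the actual `dbgpt.util.chat_util` implementation:

import asyncio
from typing import Awaitable, List, TypeVar

T = TypeVar("T")

async def run_limited(tasks: List[Awaitable[T]], concurrency_limit: int = 1) -> List[T]:
    # A semaphore caps how many coroutines are in flight at once.
    sem = asyncio.Semaphore(concurrency_limit)

    async def _run(task: Awaitable[T]) -> T:
        async with sem:
            return await task

    return await asyncio.gather(*(_run(t) for t in tasks))

async def demo():
    async def fake_search(q: str) -> List[str]:
        await asyncio.sleep(0.01)
        return [f"chunk for {q}"]

    results = await run_limited([fake_search("a"), fake_search("b")], concurrency_limit=1)
    print(results)

asyncio.run(demo())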
@@ -1,6 +1,5 @@
 import copy
 import logging
 import re
 from abc import ABC, abstractmethod
 from typing import (
     Any,
@@ -66,10 +65,14 @@ class TextSplitter(ABC):
             chunks.append(new_doc)
         return chunks

-    def split_documents(self, documents: List[Document], **kwargs) -> List[Chunk]:
+    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
         """Split documents."""
-        texts = [doc.content for doc in documents]
-        metadatas = [doc.metadata for doc in documents]
+        texts = []
+        metadatas = []
+        for doc in documents:
+            # Iterable just supports one iteration
+            texts.append(doc.content)
+            metadatas.append(doc.metadata)
         return self.create_documents(texts, metadatas, **kwargs)

     def _join_docs(self, docs: List[str], separator: str, **kwargs) -> Optional[str]:
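The switch to a single `for` loop matters because `documents` may now be a generator, and a generator can only be consumed once; iterating it twice (once for `texts`, once for `metadatas`) would silently yield nothing the second time. A small standalone illustration:

def gen_docs():
    yield {"content": "a", "metadata": {"n": 1}}
    yield {"content": "b", "metadata": {"n": 2}}

docs = gen_docs()
texts = [d["content"] for d in docs]       # consumes the generator
metadatas = [d["metadata"] for d in docs]  # second pass sees nothing
print(texts)      # ['a', 'b']
print(metadatas)  # []  <- the bug the single-pass loop avoids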
@@ -6,6 +6,7 @@ from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
 from dbgpt.rag.extractor.base import Extractor
 from dbgpt.rag.knowledge.base import Knowledge
 from dbgpt.rag.retriever.base import BaseRetriever
+from dbgpt.util.tracer import root_tracer, trace


 class BaseAssembler(ABC):

@@ -30,12 +31,25 @@
             knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
         )
         self._chunks = None
-        self.load_knowledge(self._knowledge)
+        metadata = {
+            "knowledge_cls": self._knowledge.__class__.__name__
+            if self._knowledge
+            else None,
+            "knowledge_type": self._knowledge.type().value if self._knowledge else None,
+            "path": self._knowledge._path
+            if self._knowledge and hasattr(self._knowledge, "_path")
+            else None,
+            "chunk_parameters": self._chunk_parameters.dict(),
+        }
+        with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
+            self.load_knowledge(self._knowledge)

     def load_knowledge(self, knowledge) -> None:
         """Load knowledge Pipeline."""
-        documents = knowledge.load()
-        self._chunks = self._chunk_manager.split(documents)
+        with root_tracer.start_span("BaseAssembler.knowledge.load"):
+            documents = knowledge.load()
+        with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
+            self._chunks = self._chunk_manager.split(documents)

     @abstractmethod
     def as_retriever(self, **kwargs: Any) -> BaseRetriever:
@@ -1,4 +1,6 @@
+import logging
 import math
+import time
 from abc import ABC, abstractmethod
 from typing import Any, Callable, List, Optional

@@ -6,6 +8,8 @@ from pydantic import BaseModel, Field

 from dbgpt.rag.chunk import Chunk

+logger = logging.getLogger(__name__)
+

 class VectorStoreConfig(BaseModel):
     """Vector store config."""

@@ -26,6 +30,12 @@
         default=None,
         description="The embedding function of vector store, if not set, will use the default embedding function.",
     )
+    max_chunks_once_load: int = Field(
+        default=10,
+        description="The max number of chunks to load at once. If your document is "
+        "large, you can set this value to a larger number to speed up the loading "
+        "process. Default is 10.",
+    )


 class VectorStoreBase(ABC):

@@ -41,6 +51,33 @@
         """
         pass

+    def load_document_with_limit(
+        self, chunks: List[Chunk], max_chunks_once_load: int = 10
+    ) -> List[str]:
+        """Load document in vector database with limit.
+        Args:
+            chunks: document chunks.
+            max_chunks_once_load: Max number of chunks to load at once.
+        Return: chunk ids.
+        """
+        # Group the chunks into batches of size max_chunks_once_load
+        chunk_groups = [
+            chunks[i : i + max_chunks_once_load]
+            for i in range(0, len(chunks), max_chunks_once_load)
+        ]
+        logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
+        ids = []
+        loaded_cnt = 0
+        start_time = time.time()
+        for chunk_group in chunk_groups:
+            ids.extend(self.load_document(chunk_group))
+            loaded_cnt += len(chunk_group)
+            logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
+        logger.info(
+            f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
+        )
+        return ids
+
     @abstractmethod
     def similar_search(self, text, topk) -> List[Chunk]:
         """similar search in vector database.
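The core of the change is the batching idiom: slice the chunk list into fixed-size groups and load one group at a time, so memory use is bounded by the batch size rather than by the document size. A standalone sketch, where `fake_load` stands in for a real `load_document` call:

def batched(items, batch_size):
    # Yield successive slices of at most batch_size items.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

def fake_load(batch):
    # Placeholder for a real vector-store write; returns one id per chunk.
    return [f"id-{c}" for c in batch]

chunks = list(range(25))
ids = []
for group in batched(chunks, batch_size=10):
    ids.extend(fake_load(group))
print(len(ids))  # 25, loaded in 3 groups of 10, 10, and 5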
@@ -64,7 +64,9 @@ class VectorStoreConnector:
         - chunks: document chunks.
         Return chunk ids.
         """
-        return self.client.load_document(chunks)
+        return self.client.load_document_with_limit(
+            chunks, self._vector_store_config.max_chunks_once_load
+        )

     def similar_search(self, doc: str, topk: int) -> List[Chunk]:
         """similar search in vector database.
@@ -1,11 +1,12 @@
 from __future__ import annotations

+import json
 import uuid
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 from dbgpt.component import BaseComponent, ComponentType, SystemApp


@@ -95,7 +96,7 @@ class Span:
             "end_time": None
             if not self.end_time
             else self.end_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
-            "metadata": self.metadata,
+            "metadata": _clean_for_json(self.metadata),
         }
@@ -187,3 +188,39 @@ class Tracer(BaseComponent, ABC):
 @dataclass
 class TracerContext:
     span_id: Optional[str] = None
+
+
+def _clean_for_json(data: Optional[str, Any] = None):
+    if not data:
+        return None
+    if isinstance(data, dict):
+        cleaned_dict = {}
+        for key, value in data.items():
+            # Try to clean the sub-items
+            cleaned_value = _clean_for_json(value)
+            if cleaned_value is not None:
+                # Only add to the cleaned dict if it's not None
+                try:
+                    json.dumps({key: cleaned_value})
+                    cleaned_dict[key] = cleaned_value
+                except TypeError:
+                    # Skip this key-value pair if it can't be serialized
+                    pass
+        return cleaned_dict
+    elif isinstance(data, list):
+        cleaned_list = []
+        for item in data:
+            cleaned_item = _clean_for_json(item)
+            if cleaned_item is not None:
+                try:
+                    json.dumps(cleaned_item)
+                    cleaned_list.append(cleaned_item)
+                except TypeError:
+                    pass
+        return cleaned_list
+    else:
+        try:
+            json.dumps(data)
+            return data
+        except TypeError:
+            return None
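The helper drops metadata values that cannot be JSON-serialized so `Span.to_dict()` never fails on exotic objects. A hedged sketch of its effect; the metadata content is illustrative, and the import path assumes the helper lives in the tracer base module shown above:

import json

from dbgpt.util.tracer.base import _clean_for_json

# A lambda is not JSON-serializable, so it is dropped while plain values are kept.
metadata = {"doc": "report.pdf", "chunks": 12, "embedding_fn": lambda x: x}
cleaned = _clean_for_json(metadata)
print(json.dumps(cleaned))  # {"doc": "report.pdf", "chunks": 12}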