perf(rag): Support loading large documents (#1233)

This commit is contained in:
Fangyin Cheng
2024-03-01 22:33:52 +08:00
committed by GitHub
parent ed4df23e4e
commit 505bc32775
13 changed files with 231 additions and 37 deletions

View File

@@ -75,6 +75,10 @@ EMBEDDING_MODEL=text2vec
#EMBEDDING_MODEL=bge-large-zh
KNOWLEDGE_CHUNK_SIZE=500
KNOWLEDGE_SEARCH_TOP_SIZE=5
## Maximum number of chunks to load at once. If a single document is very large,
## you can raise this value for better performance.
## If you run out of memory when loading a large document, lower this value.
# KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD=10
#KNOWLEDGE_CHUNK_OVERLAP=50
# Control whether to display the source document of knowledge on the front end.
KNOWLEDGE_CHAT_SHOW_RELATIONS=False
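As a purely illustrative example (the value below is not a recommendation), a machine with ample memory could raise the batch size when syncing very large documents:

## Illustrative only: write 50 chunks per batch when loading large documents
KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD=50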

View File

@@ -233,6 +233,9 @@ class Config(metaclass=Singleton):
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
self.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD = int(
os.getenv("KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD", 10)
)
# default recall similarity score, between 0 and 1
self.KNOWLEDGE_SEARCH_RECALL_SCORE = float(
os.getenv("KNOWLEDGE_SEARCH_RECALL_SCORE", 0.3)

View File

@@ -43,6 +43,7 @@ from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace
knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
@@ -335,7 +336,11 @@ class KnowledgeService:
)
from dbgpt.storage.vector_store.base import VectorStoreConfig
config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
config = VectorStoreConfig(
name=space_name,
embedding_fn=embedding_fn,
max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
@@ -499,6 +504,7 @@ class KnowledgeService:
res.page = request.page
return res
@trace("async_doc_embedding")
def async_doc_embedding(self, assembler, chunk_docs, doc):
"""async document embedding into vector db
Args:
@@ -511,7 +517,11 @@ class KnowledgeService:
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
)
try:
vector_ids = assembler.persist()
with root_tracer.start_span(
"app.knowledge.assembler.persist",
metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)},
):
vector_ids = assembler.persist()
doc.status = SyncStatus.FINISHED.name
doc.result = "document embedding success"
if vector_ids is not None:
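For orientation, a minimal sketch of the tracing pattern this commit applies across the service layer: the `@trace` decorator names the outer span, and `root_tracer.start_span` opens nested spans carrying metadata. The class, span names, and metadata below are hypothetical placeholders, not part of the commit.

from dbgpt.util.tracer import root_tracer, trace

class ExampleService:
    @trace("example.sync_documents")  # hypothetical span name
    def sync_documents(self, docs):
        # Nested span carrying metadata; non-serializable values are dropped on export
        with root_tracer.start_span(
            "example.persist", metadata={"doc_count": len(docs)}
        ):
            pass  # persist docs to the vector store here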

View File

@@ -11,7 +11,6 @@ from dbgpt.app.knowledge.document_db import (
)
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import (
ChatPromptTemplate,
@@ -19,10 +18,8 @@ from dbgpt.core import (
MessagesPlaceholder,
SystemPromptTemplate,
)
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.tracer import trace
from dbgpt.util.tracer import root_tracer, trace
CFG = Config()
@@ -226,6 +223,9 @@ class ChatKnowledge(BaseChat):
async def execute_similar_search(self, query):
"""execute similarity search"""
return await self.embedding_retriever.aretrieve_with_scores(
query, self.recall_score
)
with root_tracer.start_span(
"execute_similar_search", metadata={"query": query}
):
return await self.embedding_retriever.aretrieve_with_scores(
query, self.recall_score
)

View File

@@ -167,6 +167,8 @@ EMBEDDING_MODEL_CONFIG = {
# https://huggingface.co/BAAI/bge-large-zh
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
"gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",

View File

@@ -1,12 +1,14 @@
try:
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.cluster.client import DefaultLLMClient, RemoteLLMClient
except ImportError as exc:
# logging.warning("Can't import dbgpt.model.DefaultLLMClient")
DefaultLLMClient = None
RemoteLLMClient = None
_exports = []
if DefaultLLMClient:
_exports.append("DefaultLLMClient")
if RemoteLLMClient:
_exports.append("RemoteLLMClient")
__ALL__ = _exports

View File

@@ -104,3 +104,60 @@ class DefaultLLMClient(LLMClient):
async def count_token(self, model: str, prompt: str) -> int:
return await self.worker_manager.count_token({"model": model, "prompt": prompt})
@register_resource(
label="Remote LLM Client",
name="remote_llm_client",
category=ResourceCategory.LLM_CLIENT,
description="Remote LLM client(Connect to the remote DB-GPT model serving)",
parameters=[
Parameter.build_from(
"Controller Address",
name="controller_address",
type=str,
optional=True,
default="http://127.0.0.1:8000",
description="Model controller address",
),
Parameter.build_from(
"Auto Convert Message",
name="auto_convert_message",
type=bool,
optional=True,
default=False,
description="Whether to auto convert the messages that are not supported "
"by the LLM to a compatible format",
),
],
)
class RemoteLLMClient(DefaultLLMClient):
"""Remote LLM client implementation.
Connect to the remote worker manager and send the request to the remote worker manager.
Args:
controller_address (str): model controller address
auto_convert_message (bool, optional): auto convert the message to
ModelRequest. Defaults to False.
If you start DB-GPT model cluster, the controller address is the address of the
Model Controller(`dbgpt start controller`, the default port of model controller
is 8000).
Otherwise, if you already have a running DB-GPT server(start it by
`dbgpt start webserver --port ${remote_port}`), you can use the address of the
`http://${remote_ip}:${remote_port}`.
"""
def __init__(
self,
controller_address: str = "http://127.0.0.1:8000",
auto_convert_message: bool = False,
):
"""Initialize the RemoteLLMClient."""
from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager
model_registry_client = ModelRegistryClient(controller_address)
worker_manager = RemoteWorkerManager(model_registry_client)
super().__init__(worker_manager, auto_convert_message)
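A minimal usage sketch, assuming a Model Controller is running at the default address; the model name and prompt are placeholders:

import asyncio

from dbgpt.model import RemoteLLMClient

async def main():
    client = RemoteLLMClient(controller_address="http://127.0.0.1:8000")
    # count_token is inherited from DefaultLLMClient (signature shown above)
    tokens = await client.count_token("vicuna-13b-v1.5", "Hello, DB-GPT")  # placeholder model name
    print(tokens)

asyncio.run(main())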

View File

@@ -7,6 +7,7 @@ from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.util.tracer import root_tracer
class EmbeddingRetriever(BaseRetriever):
@@ -129,23 +130,45 @@ class EmbeddingRetriever(BaseRetriever):
"""
queries = [query]
if self._query_rewrite:
candidates_tasks = [self._similarity_search(query) for query in queries]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
with root_tracer.start_span(
"EmbeddingRetriever.query_rewrite.similarity_search",
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_tasks = [self._similarity_search(query) for query in queries]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
with root_tracer.start_span(
"EmbeddingRetriever.query_rewrite.rewrite",
metadata={"query": query, "context": context, "nums": 1},
):
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
)
queries.extend(new_queries)
with root_tracer.start_span(
"EmbeddingRetriever.similarity_search_with_score",
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_with_score = [
self._similarity_search_with_score(query, score_threshold)
for query in queries
]
candidates_with_score = await run_async_tasks(
tasks=candidates_with_score, concurrency_limit=1
)
queries.extend(new_queries)
candidates_with_score = [
self._similarity_search_with_score(query, score_threshold)
for query in queries
]
candidates_with_score = await run_async_tasks(
tasks=candidates_with_score, concurrency_limit=1
)
candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
candidates_with_score = self._rerank.rank(candidates_with_score)
return candidates_with_score
candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
with root_tracer.start_span(
"EmbeddingRetriever.rerank",
metadata={
"query": query,
"score_threshold": score_threshold,
"rerank_cls": self._rerank.__class__.__name__,
},
):
candidates_with_score = self._rerank.rank(candidates_with_score)
return candidates_with_score
async def _similarity_search(self, query) -> List[Chunk]:
"""Similar search."""

View File

@@ -1,6 +1,5 @@
import copy
import logging
import re
from abc import ABC, abstractmethod
from typing import (
Any,
@@ -66,10 +65,14 @@ class TextSplitter(ABC):
chunks.append(new_doc)
return chunks
def split_documents(self, documents: List[Document], **kwargs) -> List[Chunk]:
def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
"""Split documents."""
texts = [doc.content for doc in documents]
metadatas = [doc.metadata for doc in documents]
texts = []
metadatas = []
for doc in documents:
# An Iterable may only be consumed once, so collect content and metadata in a single pass
texts.append(doc.content)
metadatas.append(doc.metadata)
return self.create_documents(texts, metadatas, **kwargs)
def _join_docs(self, docs: List[str], separator: str, **kwargs) -> Optional[str]:
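A hedged sketch of what the relaxed `Iterable[Document]` signature allows: feeding a generator so the full document list never has to sit in memory. The import location of `Document` (assumed to live next to `Chunk` in `dbgpt.rag.chunk`) and the helper below are assumptions for illustration.

from typing import Iterable

from dbgpt.rag.chunk import Document  # assumed import path

def iter_documents(paths) -> Iterable[Document]:
    """Yield documents one at a time instead of materializing them all."""
    for path in paths:
        with open(path, encoding="utf-8") as f:
            yield Document(content=f.read(), metadata={"source": path})

# Any concrete TextSplitter can now consume the generator directly, e.g.:
# chunks = splitter.split_documents(iter_documents(["a.md", "b.md"]))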

View File

@@ -6,6 +6,7 @@ from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.util.tracer import root_tracer, trace
class BaseAssembler(ABC):
@@ -30,12 +31,25 @@ class BaseAssembler(ABC):
knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
)
self._chunks = None
self.load_knowledge(self._knowledge)
metadata = {
"knowledge_cls": self._knowledge.__class__.__name__
if self._knowledge
else None,
"knowledge_type": self._knowledge.type().value if self._knowledge else None,
"path": self._knowledge._path
if self._knowledge and hasattr(self._knowledge, "_path")
else None,
"chunk_parameters": self._chunk_parameters.dict(),
}
with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
self.load_knowledge(self._knowledge)
def load_knowledge(self, knowledge) -> None:
"""Load knowledge Pipeline."""
documents = knowledge.load()
self._chunks = self._chunk_manager.split(documents)
with root_tracer.start_span("BaseAssembler.knowledge.load"):
documents = knowledge.load()
with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
self._chunks = self._chunk_manager.split(documents)
@abstractmethod
def as_retriever(self, **kwargs: Any) -> BaseRetriever:

View File

@@ -1,4 +1,6 @@
import logging
import math
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional
@@ -6,6 +8,8 @@ from pydantic import BaseModel, Field
from dbgpt.rag.chunk import Chunk
logger = logging.getLogger(__name__)
class VectorStoreConfig(BaseModel):
"""Vector store config."""
@@ -26,6 +30,12 @@ class VectorStoreConfig(BaseModel):
default=None,
description="The embedding function of vector store, if not set, will use the default embedding function.",
)
max_chunks_once_load: int = Field(
default=10,
description="The max number of chunks to load at once. If your document is "
"large, you can set this value to a larger number to speed up the loading "
"process. Default is 10.",
)
class VectorStoreBase(ABC):
@@ -41,6 +51,33 @@ class VectorStoreBase(ABC):
"""
pass
def load_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10
) -> List[str]:
"""load document in vector database with limit.
Args:
chunks: document chunks.
max_chunks_once_load: Max number of chunks to load at once.
Return:
"""
# Group the chunks into batches of at most max_chunks_once_load chunks
chunk_groups = [
chunks[i : i + max_chunks_once_load]
for i in range(0, len(chunks), max_chunks_once_load)
]
logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
ids = []
loaded_cnt = 0
start_time = time.time()
for chunk_group in chunk_groups:
ids.extend(self.load_document(chunk_group))
loaded_cnt += len(chunk_group)
logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
logger.info(
f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
)
return ids
@abstractmethod
def similar_search(self, text, topk) -> List[Chunk]:
"""similar search in vector database.

View File

@@ -64,7 +64,9 @@ class VectorStoreConnector:
- chunks: document chunks.
Return chunk ids.
"""
return self.client.load_document(chunks)
return self.client.load_document_with_limit(
chunks, self._vector_store_config.max_chunks_once_load
)
def similar_search(self, doc: str, topk: int) -> List[Chunk]:
"""similar search in vector database.

View File

@@ -1,11 +1,12 @@
from __future__ import annotations
import json
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
from dbgpt.component import BaseComponent, ComponentType, SystemApp
@@ -95,7 +96,7 @@ class Span:
"end_time": None
if not self.end_time
else self.end_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
"metadata": self.metadata,
"metadata": _clean_for_json(self.metadata),
}
@@ -187,3 +188,39 @@ class Tracer(BaseComponent, ABC):
@dataclass
class TracerContext:
span_id: Optional[str] = None
def _clean_for_json(data: Optional[Any] = None):
if not data:
return None
if isinstance(data, dict):
cleaned_dict = {}
for key, value in data.items():
# Try to clean the sub-items
cleaned_value = _clean_for_json(value)
if cleaned_value is not None:
# Only add to the cleaned dict if it's not None
try:
json.dumps({key: cleaned_value})
cleaned_dict[key] = cleaned_value
except TypeError:
# Skip this key-value pair if it can't be serialized
pass
return cleaned_dict
elif isinstance(data, list):
cleaned_list = []
for item in data:
cleaned_item = _clean_for_json(item)
if cleaned_item is not None:
try:
json.dumps(cleaned_item)
cleaned_list.append(cleaned_item)
except TypeError:
pass
return cleaned_list
else:
try:
json.dumps(data)
return data
except TypeError:
return None
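For illustration, a self-contained snippet that mirrors what _clean_for_json does for a flat dict: values that json.dumps cannot serialize (such as an embedding function stored in span metadata) are silently dropped before export. The metadata below is a hypothetical example.

import json

metadata = {"doc": "report.pdf", "chunks": 12, "embedding_fn": lambda text: [0.0]}

cleaned = {}
for key, value in metadata.items():
    try:
        json.dumps({key: value})  # probe serializability, mirroring _clean_for_json
        cleaned[key] = value
    except TypeError:
        pass  # skip values that cannot be serialized

print(cleaned)  # {'doc': 'report.pdf', 'chunks': 12}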