refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Aries-ckt
2024-01-03 09:45:26 +08:00
committed by GitHub
parent 90775aad50
commit 9ad70a2961
206 changed files with 5766 additions and 2419 deletions

View File

@@ -14,10 +14,10 @@ from dbgpt.app.knowledge.request.request import (
DocumentQueryRequest,
)
from dbgpt.rag.embedding_engine.knowledge_type import KnowledgeType
from dbgpt.app.knowledge.request.request import DocumentSyncRequest
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.rag.knowledge.base import KnowledgeType
HTTP_HEADERS = {"Content-Type": "application/json"}

View File

@@ -2,6 +2,7 @@ import os
import shutil
import tempfile
import logging
from typing import List
from fastapi import APIRouter, File, UploadFile, Form
@@ -13,10 +14,10 @@ from dbgpt.configs.model_config import (
from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.app.knowledge.request.request import (
KnowledgeQueryRequest,
KnowledgeQueryResponse,
@@ -27,9 +28,14 @@ from dbgpt.app.knowledge.request.request import (
SpaceArgumentRequest,
EntityExtractRequest,
DocumentSummaryRequest,
KnowledgeSyncRequest,
)
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.rag.knowledge.base import ChunkStrategy
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.tracer import root_tracer, SpanType
logger = logging.getLogger(__name__)
@@ -103,6 +109,39 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
return Result.failed(code="E000X", msg=f"document add error {e}")
@router.get("/knowledge/document/chunkstrategies")
def chunk_strategies():
    """Get chunk strategies.

    Returns:
        ``Result.succ`` wrapping one dict per ``ChunkStrategy`` member with keys:
        ``strategy`` (the member name), ``name`` (``strategy.value[2]``),
        ``description`` (``strategy.value[3]``), ``parameters``
        (``strategy.value[1]``), ``suffix`` (document type values of the
        knowledge classes that support the strategy and declare a document
        type) and ``type`` (the set of knowledge type values of the supporting
        classes). If anything raises, ``Result.failed`` with code ``E000X``.
    """
    print("/document/chunkstrategies:")
    try:
        # The subclass list does not depend on the strategy: fetch it once
        # instead of twice per strategy. ``list`` guards against a one-shot
        # iterable being exhausted after the first pass.
        knowledge_classes = list(KnowledgeFactory.subclasses())
        return Result.succ(
            [
                {
                    "strategy": strategy.name,
                    "name": strategy.value[2],
                    "description": strategy.value[3],
                    "parameters": strategy.value[1],
                    "suffix": [
                        knowledge.document_type().value
                        for knowledge in knowledge_classes
                        if strategy in knowledge.support_chunk_strategy()
                        and knowledge.document_type() is not None
                    ],
                    "type": {
                        knowledge.type().value
                        for knowledge in knowledge_classes
                        if strategy in knowledge.support_chunk_strategy()
                    },
                }
                for strategy in ChunkStrategy
            ]
        )
    except Exception as e:
        return Result.failed(code="E000X", msg=f"chunk strategies error {e}")
@router.post("/knowledge/{space_name}/document/list")
def document_list(space_name: str, query_request: DocumentQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
@@ -189,6 +228,18 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
return Result.failed(code="E000X", msg=f"document sync error {e}")
@router.post("/knowledge/{space_name}/document/sync_batch")
def batch_document_sync(space_name: str, request: List[KnowledgeSyncRequest]):
    """Start a sync for each requested document of a knowledge space.

    Args:
        space_name: Knowledge space name taken from the URL path.
        request: One ``KnowledgeSyncRequest`` per document to sync.

    Returns:
        ``Result.succ`` with ``{"tasks": <ids returned by the service>}``, or
        ``Result.failed`` with code ``E000X`` when the service call raises.
    """
    logger.info(f"Received params: {space_name}, {request}")
    try:
        synced_ids = knowledge_space_service.batch_document_sync(
            space_name=space_name,
            sync_requests=request,
        )
        payload = {"tasks": synced_ids}
        return Result.succ(payload)
    except Exception as e:
        return Result.failed(code="E000X", msg=f"document sync error {e}")
@router.post("/knowledge/{space_name}/chunk/list")
def document_list(space_name: str, query_request: ChunkQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
@@ -204,15 +255,23 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={"vector_store_name": space_name},
embedding_factory=embedding_factory,
config = VectorStoreConfig(
name=space_name,
embedding_fn=embedding_factory.create(
EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
),
)
docs = client.similar_search(query_request.query, query_request.top_k)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
retriever = EmbeddingRetriever(
top_k=query_request.top_k, vector_store_connector=vector_store_connector
)
chunks = retriever.retrieve(query_request.query)
res = [
KnowledgeQueryResponse(text=d.page_content, source=d.metadata["source"])
for d in docs
KnowledgeQueryResponse(text=d.content, source=d.metadata["source"])
for d in chunks
]
return {"response": res}
@@ -254,7 +313,7 @@ async def entity_extract(request: EntityExtractRequest):
logger.info(f"Received params: {request}")
try:
from dbgpt.app.scene import ChatScene
from dbgpt._private.chat_util import llm_chat_response_nostream
from dbgpt.util.chat_util import llm_chat_response_nostream
import uuid
chat_param = {

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, func
@@ -51,6 +52,12 @@ class KnowledgeDocumentDao(BaseDao):
return doc_id
def get_knowledge_documents(self, query, page=1, page_size=20):
"""Get a list of documents that match the given query.
Args:
query: A KnowledgeDocumentEntity object containing the query parameters.
page: The page number to return.
page_size: The number of documents to return per page.
"""
session = self.get_raw_session()
print(f"current session:{session}")
knowledge_documents = session.query(KnowledgeDocumentEntity)
@@ -85,6 +92,23 @@ class KnowledgeDocumentDao(BaseDao):
session.close()
return result
def documents_by_ids(self, ids) -> List[KnowledgeDocumentEntity]:
    """Get a list of documents by their IDs.

    Args:
        ids: A list of document IDs. ``None`` or an empty list yields an
            empty result without opening a session.

    Returns:
        A list of KnowledgeDocumentEntity objects whose ``id`` is in ``ids``.
    """
    if not ids:
        # Nothing can match an empty IN (...) clause; skip the round trip.
        return []
    session = self.get_raw_session()
    try:
        print(f"current session:{session}")
        return (
            session.query(KnowledgeDocumentEntity)
            .filter(KnowledgeDocumentEntity.id.in_(ids))
            .all()
        )
    finally:
        # Release the session even when building or running the query raises.
        session.close()
def get_documents(self, query):
session = self.get_raw_session()
print(f"current session:{session}")

View File

@@ -3,6 +3,8 @@ from typing import List, Optional
from dbgpt._private.pydantic import BaseModel
from fastapi import UploadFile
from dbgpt.rag.chunk_manager import ChunkParameters
class KnowledgeQueryRequest(BaseModel):
"""query: knowledge query"""
@@ -43,6 +45,8 @@ class DocumentQueryRequest(BaseModel):
"""doc_name: doc path"""
doc_name: str = None
"""doc_ids: doc ids"""
doc_ids: Optional[List] = None
"""doc_type: doc type"""
doc_type: str = None
"""status: status"""
@@ -76,6 +80,20 @@ class DocumentSyncRequest(BaseModel):
chunk_overlap: Optional[int] = None
class KnowledgeSyncRequest(BaseModel):
    """Sync request"""

    # doc_id: id of the single document this request syncs (required)
    doc_id: int

    # model_name: model name; optional, None when omitted
    model_name: Optional[str] = None

    # chunk_parameters: chunk parameters used for this document (required)
    chunk_parameters: ChunkParameters
class ChunkQueryRequest(BaseModel):
"""id: id"""

View File

@@ -1,13 +1,26 @@
import json
import logging
from datetime import datetime
from typing import List
from dbgpt.model import DefaultLLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.text_splitter.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt._private.config import Config
from dbgpt.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from dbgpt.component import ComponentType
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
@@ -32,6 +45,7 @@ from dbgpt.app.knowledge.request.request import (
SpaceArgumentRequest,
DocumentSyncRequest,
DocumentSummaryRequest,
KnowledgeSyncRequest,
)
from enum import Enum
@@ -106,7 +120,10 @@ class KnowledgeService:
content=request.content,
result="",
)
return knowledge_document_dao.create_knowledge_document(document)
doc_id = knowledge_document_dao.create_knowledge_document(document)
if doc_id is None:
raise Exception(f"create document failed, {request.doc_name}")
return doc_id
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
"""get knowledge space
@@ -171,36 +188,75 @@ class KnowledgeService:
Args:
- space: Knowledge Space Name
- request: DocumentQueryRequest
Returns:
- res DocumentQueryResponse
"""
query = KnowledgeDocumentEntity(
doc_name=request.doc_name,
doc_type=request.doc_type,
space=space,
status=request.status,
)
res = DocumentQueryResponse()
res.data = knowledge_document_dao.get_knowledge_documents(
query, page=request.page, page_size=request.page_size
)
res.total = knowledge_document_dao.get_knowledge_documents_count(query)
res.page = request.page
if request.doc_ids and len(request.doc_ids) > 0:
res.data = knowledge_document_dao.documents_by_ids(request.doc_ids)
else:
query = KnowledgeDocumentEntity(
doc_name=request.doc_name,
doc_type=request.doc_type,
space=space,
status=request.status,
)
res.data = knowledge_document_dao.get_knowledge_documents(
query, page=request.page, page_size=request.page_size
)
res.total = knowledge_document_dao.get_knowledge_documents_count(query)
res.page = request.page
return res
def batch_document_sync(
    self, space_name, sync_requests: List[KnowledgeSyncRequest]
) -> List[int]:
    """Batch sync knowledge document chunks into the vector store.

    Requests are handled one by one, in order. A failure on a later request
    does not undo syncs already started for earlier ones.

    Args:
        space_name: Knowledge Space Name
        sync_requests: List[KnowledgeSyncRequest]

    Returns:
        List[int]: ids of the documents whose sync was started.

    Raises:
        Exception: if a requested document does not exist, or its status is
            already RUNNING or FINISHED.
    """
    doc_ids = []
    for sync_request in sync_requests:
        docs = knowledge_document_dao.documents_by_ids([sync_request.doc_id])
        if not docs:
            raise Exception(
                f"there is no document with doc_id: {sync_request.doc_id}"
            )
        doc = docs[0]
        if doc.status in (SyncStatus.RUNNING.name, SyncStatus.FINISHED.name):
            raise Exception(
                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )
        chunk_parameters = sync_request.chunk_parameters
        if "Automatic" == chunk_parameters.chunk_strategy:
            # The automatic strategy takes its sizes from the space context
            # when one exists, otherwise from the global configuration.
            space_context = self.get_space_context(space_name)
            if space_context is None:
                chunk_parameters.chunk_size = CFG.KNOWLEDGE_CHUNK_SIZE
                chunk_parameters.chunk_overlap = CFG.KNOWLEDGE_CHUNK_OVERLAP
            else:
                chunk_parameters.chunk_size = int(
                    space_context["embedding"]["chunk_size"]
                )
                chunk_parameters.chunk_overlap = int(
                    space_context["embedding"]["chunk_overlap"]
                )
        self._sync_knowledge_document(space_name, doc, chunk_parameters)
        doc_ids.append(doc.id)
    return doc_ids
def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
"""sync knowledge document chunk into vector store
Args:
- space: Knowledge Space Name
- sync_request: DocumentSyncRequest
"""
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding_engine.pre_text_splitter import PreTextSplitter
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
# import langchain is very very slow!!!
from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter
doc_ids = sync_request.doc_ids
self.model_name = sync_request.model_name or CFG.LLM_MODEL
@@ -234,6 +290,11 @@ class KnowledgeService:
if sync_request.chunk_overlap:
chunk_overlap = sync_request.chunk_overlap
separators = sync_request.separators or None
from dbgpt.rag.chunk_manager import ChunkParameters
chunk_parameters = ChunkParameters(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
if CFG.LANGUAGE == "en":
text_splitter = RecursiveCharacterTextSplitter(
separators=separators,
@@ -244,7 +305,7 @@ class KnowledgeService:
else:
if separators and len(separators) > 1:
raise ValueError(
"SpacyTextSplitter do not support multiple separators"
"SpacyTextSplitter do not support multipsle separators"
)
try:
separator = "\n\n" if not separators else separators[0]
@@ -266,48 +327,51 @@ class KnowledgeService:
pre_separator=sync_request.pre_separator,
text_splitter_impl=text_splitter,
)
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={
"vector_store_name": space_name,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
},
text_splitter=text_splitter,
embedding_factory=embedding_factory,
)
chunk_docs = client.read()
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [
DocumentChunkEntity(
doc_name=doc.doc_name,
doc_type=doc.doc_type,
document_id=doc.id,
content=chunk_doc.page_content,
meta_info=str(chunk_doc.metadata),
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
for chunk_doc in chunk_docs
]
document_chunk_dao.create_documents_chunks(chunk_entities)
chunk_parameters.text_splitter = text_splitter
self._sync_knowledge_document(space_name, doc, chunk_parameters)
return doc.id
def _sync_knowledge_document(
    self,
    space_name,
    doc: KnowledgeDocumentEntity,
    chunk_parameters: ChunkParameters,
) -> List[Chunk]:
    """Sync knowledge document chunks into the vector store.

    Splits the document into chunks now, marks it as RUNNING, and submits
    ``self.async_doc_embedding`` to the default executor.

    Args:
        space_name: Knowledge space name, used as the vector store name.
        doc: The document entity to sync; its ``status``, ``chunk_size`` and
            ``gmt_modified`` are updated and saved through the document DAO.
        chunk_parameters: Chunk parameters handed to the embedding assembler.

    Returns:
        The chunks produced by the assembler for this document.
    """
    embedding_factory = CFG.SYSTEM_APP.get_component(
        "embedding_factory", EmbeddingFactory
    )
    embedding_fn = embedding_factory.create(
        model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
    )
    # VectorStoreConfig is already imported at module level; no local import.
    config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
    vector_store_connector = VectorStoreConnector(
        vector_store_type=CFG.VECTOR_STORE_TYPE,
        vector_store_config=config,
    )
    knowledge = KnowledgeFactory.create(
        datasource=doc.content,
        knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
    )
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=chunk_parameters,
        vector_store_connector=vector_store_connector,
    )
    chunk_docs = assembler.get_chunks()
    # Record the RUNNING state before the embedding job is handed off.
    doc.status = SyncStatus.RUNNING.name
    doc.chunk_size = len(chunk_docs)
    doc.gmt_modified = datetime.now()
    knowledge_document_dao.update_knowledge_document(doc)
    executor = CFG.SYSTEM_APP.get_component(
        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
    ).create()
    executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc)
    logger.info(f"begin save document chunks, doc:{doc.doc_name}")
    return chunk_docs
async def document_summary(self, request: DocumentSummaryRequest):
"""get document summary
Args:
@@ -318,20 +382,46 @@ class KnowledgeService:
if len(documents) != 1:
raise Exception(f"can not found document for {request.doc_id}")
document = documents[0]
query = DocumentChunkEntity(
document_id=request.doc_id,
)
chunks = document_chunk_dao.get_document_chunks(query, page=1, page_size=100)
if len(chunks) == 0:
raise Exception(f"can not found chunks for {request.doc_id}")
from langchain.schema import Document
from dbgpt.model.cluster import WorkerManagerFactory
chunk_docs = [Document(page_content=chunk.content) for chunk in chunks]
return await self.async_document_summary(
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
chunk_parameters = ChunkParameters(
chunk_strategy="CHUNK_BY_SIZE",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=CFG.KNOWLEDGE_CHUNK_OVERLAP,
)
chunk_entities = document_chunk_dao.get_document_chunks(
DocumentChunkEntity(document_id=document.id)
)
if (
document.status not in [SyncStatus.RUNNING.name]
and len(chunk_entities) == 0
):
self._sync_knowledge_document(
space_name=document.space,
doc=document,
chunk_parameters=chunk_parameters,
)
knowledge = KnowledgeFactory.create(
datasource=document.content,
knowledge_type=KnowledgeType.get_by_value(document.doc_type),
)
assembler = SummaryAssembler(
knowledge=knowledge,
model_name=request.model_name,
chunk_docs=chunk_docs,
doc=document,
conn_uid=request.conv_uid,
llm_client=DefaultLLMClient(worker_manager=worker_manager),
language=CFG.LANGUAGE,
chunk_parameters=chunk_parameters,
)
summary = await assembler.generate_summary()
if len(assembler.get_chunks()) == 0:
raise Exception(f"can not found chunks for {request.doc_id}")
return await self._llm_extract_summary(
summary, request.conv_uid, request.model_name
)
def update_knowledge_space(
@@ -354,15 +444,13 @@ class KnowledgeService:
if len(spaces) == 0:
raise Exception(f"delete error, no space name:{space_name} in database")
space = spaces[0]
vector_config = {}
vector_config["vector_store_name"] = space.name
vector_config["vector_store_type"] = CFG.VECTOR_STORE_TYPE
vector_config["chroma_persist_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
vector_client = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE, ctx=vector_config
config = VectorStoreConfig(name=space.name)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
# delete vectors
vector_client.delete_vector_name(space.name)
vector_store_connector.delete_vector_name(space.name)
document_query = KnowledgeDocumentEntity(space=space.name)
# delete chunks
documents = knowledge_document_dao.get_documents(document_query)
@@ -385,15 +473,13 @@ class KnowledgeService:
raise Exception(f"there are no or more than one document called {doc_name}")
vector_ids = documents[0].vector_ids
if vector_ids is not None:
vector_config = {}
vector_config["vector_store_name"] = space_name
vector_config["vector_store_type"] = CFG.VECTOR_STORE_TYPE
vector_config["chroma_persist_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
vector_client = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE, ctx=vector_config
config = VectorStoreConfig(name=space_name)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
# delete vector by ids
vector_client.delete_by_ids(vector_ids)
vector_store_connector.delete_by_ids(vector_ids)
# delete chunks
document_chunk_dao.raw_delete(documents[0].id)
# delete document
@@ -432,7 +518,7 @@ class KnowledgeService:
f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
)
try:
from dbgpt.rag.graph_engine.graph_factory import RAGGraphFactory
from dbgpt.rag.graph.graph_factory import RAGGraphFactory
rag_engine = CFG.SYSTEM_APP.get_component(
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
@@ -446,54 +532,38 @@ class KnowledgeService:
logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
return knowledge_document_dao.update_knowledge_document(doc)
async def async_document_summary(self, model_name, chunk_docs, doc, conn_uid):
    """Generate a summary for a document from its chunks.

    The chunk texts are repacked against the summary prompt template, reduced
    with ``_mapreduce_extract_summary`` and the result is passed to
    ``_llm_extract_summary``.

    Args:
        - model_name: str
        - chunk_docs: List[Document]
        - doc: KnowledgeDocumentEntity
        - conn_uid: forwarded unchanged to ``_llm_extract_summary``
    """
    page_texts = [chunk.page_content for chunk in chunk_docs]
    from dbgpt.util.prompt_util import PromptHelper

    helper = PromptHelper()
    from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt

    page_texts = helper.repack(
        prompt_template=prompt.template, text_chunks=page_texts
    )
    logger.info(
        f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(page_texts)}, begin generate summary"
    )
    # Limits configured on the space are forwarded only when a "summary"
    # section exists; otherwise the callee's defaults apply.
    limits = {}
    space_context = self.get_space_context(doc.space)
    if space_context and space_context.get("summary"):
        limits = {
            "max_iteration": int(space_context["summary"]["max_iteration"]),
            "concurrency_limit": int(
                space_context["summary"]["concurrency_limit"]
            ),
        }
    summary = await self._mapreduce_extract_summary(
        docs=page_texts, model_name=model_name, **limits
    )
    return await self._llm_extract_summary(summary, conn_uid, model_name)
def async_doc_embedding(self, client, chunk_docs, doc):
def async_doc_embedding(self, assembler, chunk_docs, doc):
"""async document embedding into vector db
Args:
- client: EmbeddingEngine Client
- chunk_docs: List[Document]
- doc: KnowledgeDocumentEntity
"""
logger.info(
f"async doc sync, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
)
try:
vector_ids = client.knowledge_embedding_batch(chunk_docs)
vector_ids = assembler.persist()
doc.status = SyncStatus.FINISHED.name
doc.result = "document embedding success"
if vector_ids is not None:
doc.vector_ids = ",".join(vector_ids)
logger.info(f"async document embedding, success:{doc.doc_name}")
# save chunk details
chunk_entities = [
DocumentChunkEntity(
doc_name=doc.doc_name,
doc_type=doc.doc_type,
document_id=doc.id,
content=chunk_doc.content,
meta_info=str(chunk_doc.metadata),
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
for chunk_doc in chunk_docs
]
document_chunk_dao.create_documents_chunks(chunk_entities)
except Exception as e:
doc.status = SyncStatus.FAILED.name
doc.result = "document embedding failed" + str(e)
@@ -577,65 +647,3 @@ class KnowledgeService:
**{"chat_param": chat_param},
)
return chat
async def _mapreduce_extract_summary(
self,
docs,
model_name: str = None,
max_iteration: int = 5,
concurrency_limit: int = 3,
):
"""Extract summary by mapreduce mode
map -> multi async call llm to generate summary
reduce -> merge the summaries by map process
Args:
docs:List[str]
model_name:model name str
max_iteration:max iteration will call llm to summary
concurrency_limit:the max concurrency threads to call llm
Returns:
Document: refine summary context document.
"""
from dbgpt.app.scene import ChatScene
from dbgpt._private.chat_util import llm_chat_response_nostream
import uuid
tasks = []
if len(docs) == 1:
return docs[0]
else:
max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
for doc in docs[0:max_iteration]:
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": "",
"select_param": doc,
"model_name": model_name,
"model_cache_enable": True,
}
tasks.append(
llm_chat_response_nostream(
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)
)
from dbgpt._private.chat_util import run_async_tasks
summary_iters = await run_async_tasks(
tasks=tasks, concurrency_limit=concurrency_limit
)
summary_iters = list(
filter(
lambda content: "LLMServer Generate Error" not in content,
summary_iters,
)
)
from dbgpt.util.prompt_util import PromptHelper
from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
prompt_helper = PromptHelper()
summary_iters = prompt_helper.repack(
prompt_template=prompt.template, text_chunks=summary_iters
)
return await self._mapreduce_extract_summary(
summary_iters, model_name, max_iteration, concurrency_limit
)