mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat:add document summary
This commit is contained in:
parent
7a8c34346a
commit
d7f677a59a
@ -13,7 +13,6 @@ from string import Formatter
|
||||
from typing import Callable, List, Optional, Sequence
|
||||
|
||||
from pydantic import Field, PrivateAttr, BaseModel
|
||||
from llama_index.prompts import BasePromptTemplate
|
||||
|
||||
from pilot.common.global_helper import globals_helper
|
||||
from pilot.common.llm_metadata import LLMMetadata
|
||||
|
@ -79,4 +79,3 @@ def split_by_phrase_regex() -> Callable[[str], List[str]]:
|
||||
"""
|
||||
regex = "[^,.;。]+[,.;。]?"
|
||||
return split_by_regex(regex)
|
||||
|
||||
|
@ -174,9 +174,6 @@ class BaseChat(ABC):
|
||||
def stream_plugin_call(self, text):
|
||||
return text
|
||||
|
||||
# def knowledge_reference_call(self, text):
|
||||
# return text
|
||||
|
||||
async def check_iterator_end(iterator):
|
||||
try:
|
||||
await asyncio.anext(iterator)
|
||||
@ -218,7 +215,6 @@ class BaseChat(ABC):
|
||||
view_msg = view_msg.replace("\n", "\\n")
|
||||
yield view_msg
|
||||
self.current_message.add_ai_message(msg)
|
||||
# view_msg = self.knowledge_reference_call(msg)
|
||||
self.current_message.add_view_message(view_msg)
|
||||
span.end()
|
||||
except Exception as e:
|
||||
|
@ -153,12 +153,17 @@ async def document_upload(
|
||||
request.content = os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
|
||||
)
|
||||
space_res = knowledge_space_service.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
|
||||
if len(space_res) == 0:
|
||||
# create default space
|
||||
if "default" != space_name:
|
||||
raise Exception(f"you have not create your knowledge space.")
|
||||
knowledge_space_service.create_knowledge_space(KnowledgeSpaceRequest(name=space_name, desc="first db-gpt rag application", owner="dbgpt"))
|
||||
return Result.succ(
|
||||
knowledge_space_service.create_knowledge_document(
|
||||
space=space_name, request=request
|
||||
)
|
||||
)
|
||||
# return Result.succ([])
|
||||
return Result.failed(code="E000X", msg=f"doc_file is None")
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document add error {e}")
|
||||
@ -240,7 +245,7 @@ async def document_summary(request: DocumentSummaryRequest):
|
||||
# )
|
||||
# return Result.succ([])
|
||||
except Exception as e:
|
||||
return Result.faild(code="E000X", msg=f"document add error {e}")
|
||||
return Result.faild(code="E000X", msg=f"document summary error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/entity/extract")
|
||||
|
@ -83,7 +83,7 @@ class DocumentChunkDao(BaseDao):
|
||||
DocumentChunkEntity.meta_info == query.meta_info
|
||||
)
|
||||
|
||||
document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
|
||||
document_chunks = document_chunks.order_by(DocumentChunkEntity.id.asc())
|
||||
document_chunks = document_chunks.offset((page - 1) * page_size).limit(
|
||||
page_size
|
||||
)
|
||||
|
@ -114,6 +114,7 @@ class DocumentSummaryRequest(BaseModel):
|
||||
"""doc_ids: doc ids"""
|
||||
doc_id: int
|
||||
model_name: str
|
||||
conv_uid: str
|
||||
|
||||
|
||||
class EntityExtractRequest(BaseModel):
|
||||
@ -121,3 +122,4 @@ class EntityExtractRequest(BaseModel):
|
||||
|
||||
text: str
|
||||
model_name: str
|
||||
|
||||
|
@ -66,11 +66,7 @@ class KnowledgeService:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
from pilot.graph_engine.graph_engine import RAGGraphEngine
|
||||
|
||||
# source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
|
||||
|
||||
# pass
|
||||
pass
|
||||
|
||||
def create_knowledge_space(self, request: KnowledgeSpaceRequest):
|
||||
"""create knowledge space
|
||||
@ -286,7 +282,6 @@ class KnowledgeService:
|
||||
executor = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
# executor.submit(self.async_document_summary, chunk_docs, doc)
|
||||
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
|
||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||
# save chunk details
|
||||
@ -326,7 +321,7 @@ class KnowledgeService:
|
||||
|
||||
chunk_docs = [Document(page_content=chunk.content) for chunk in chunks]
|
||||
return await self.async_document_summary(
|
||||
model_name=request.model_name, chunk_docs=chunk_docs, doc=document
|
||||
model_name=request.model_name, chunk_docs=chunk_docs, doc=document, conn_uid=request.conv_uid
|
||||
)
|
||||
|
||||
def update_knowledge_space(
|
||||
@ -441,7 +436,7 @@ class KnowledgeService:
|
||||
logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
|
||||
return knowledge_document_dao.update_knowledge_document(doc)
|
||||
|
||||
async def async_document_summary(self, model_name, chunk_docs, doc):
|
||||
async def async_document_summary(self, model_name, chunk_docs, doc, conn_uid):
|
||||
"""async document extract summary
|
||||
Args:
|
||||
- model_name: str
|
||||
@ -458,8 +453,17 @@ class KnowledgeService:
|
||||
logger.info(
|
||||
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
|
||||
)
|
||||
summary = await self._mapreduce_extract_summary(texts, model_name, 10, 3)
|
||||
return await self._llm_extract_summary(summary, model_name)
|
||||
space_context = self.get_space_context(doc.space)
|
||||
if space_context and space_context.get("summary"):
|
||||
summary = await self._mapreduce_extract_summary(
|
||||
docs=texts,
|
||||
model_name=model_name,
|
||||
max_iteration=space_context["summary"]["max_iteration"],
|
||||
concurrency_limit=space_context["summary"]["concurrency_limit"],
|
||||
)
|
||||
else:
|
||||
summary = await self._mapreduce_extract_summary(docs=texts, model_name=model_name)
|
||||
return await self._llm_extract_summary(summary, conn_uid, model_name)
|
||||
|
||||
def async_doc_embedding(self, client, chunk_docs, doc):
|
||||
"""async document embedding into vector db
|
||||
@ -504,6 +508,10 @@ class KnowledgeService:
|
||||
"scene": PROMPT_SCENE_DEFINE,
|
||||
"template": _DEFAULT_TEMPLATE,
|
||||
},
|
||||
"summary": {
|
||||
"max_iteration": 5,
|
||||
"concurrency_limit": 3,
|
||||
},
|
||||
}
|
||||
context_template_string = json.dumps(context_template, indent=4)
|
||||
return context_template_string
|
||||
@ -525,13 +533,13 @@ class KnowledgeService:
|
||||
return json.loads(spaces[0].context)
|
||||
return None
|
||||
|
||||
async def _llm_extract_summary(self, doc: str, model_name: str = None):
|
||||
async def _llm_extract_summary(self, doc: str, conn_uid:str, model_name: str = None):
|
||||
"""Extract triplets from text by llm"""
|
||||
from pilot.scene.base import ChatScene
|
||||
import uuid
|
||||
|
||||
chat_param = {
|
||||
"chat_session_id": uuid.uuid1(),
|
||||
"chat_session_id": conn_uid,
|
||||
"current_user_input": "",
|
||||
"select_param": doc,
|
||||
"model_name": model_name,
|
||||
@ -554,10 +562,10 @@ class KnowledgeService:
|
||||
docs,
|
||||
model_name: str = None,
|
||||
max_iteration: int = 5,
|
||||
concurrency_limit: int = None,
|
||||
concurrency_limit: int = 3,
|
||||
):
|
||||
"""Extract summary by mapreduce mode
|
||||
map -> multi async thread generate summary
|
||||
map -> multi async call llm to generate summary
|
||||
reduce -> merge the summaries by map process
|
||||
Args:
|
||||
docs:List[str]
|
||||
|
Loading…
Reference in New Issue
Block a user