feat:document summary

This commit is contained in:
aries_ckt
2023-10-31 15:09:11 +08:00
parent 16dd8e3ef5
commit 523838fb79

View File

@@ -429,14 +429,15 @@ class KnowledgeService:
from llama_index import PromptHelper from llama_index import PromptHelper
from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
texts = [doc.page_content for doc in chunk_docs] texts = [doc.page_content for doc in chunk_docs]
prompt_helper = PromptHelper() prompt_helper = PromptHelper(context_window=5000)
texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts) texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
)
summary = self._llm_extract_summary(texts[0]) summary = self._llm_extract_summary(texts[0])
# summaries = self._mapreduce_extract_summary(texts) # summaries = self._mapreduce_extract_summary(texts)
outputs, summary = self._refine_extract_summary(texts[1:], summary) outputs, summary = self._refine_extract_summary(texts[1:], summary)
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
)
doc.summary = summary doc.summary = summary
return knowledge_document_dao.update_knowledge_document(doc) return knowledge_document_dao.update_knowledge_document(doc)
@@ -525,14 +526,18 @@ class KnowledgeService:
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param} ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
) )
) )
logger.info(
f"initialize summary is :{summary}"
)
return summary return summary
def _refine_extract_summary(self, docs, summary: str): def _refine_extract_summary(self, docs, summary: str, max_iteration:int = 5):
"""Extract refine summary by llm""" """Extract refine summary by llm"""
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream from pilot.common.chat_util import llm_chat_response_nostream
import uuid import uuid
outputs = [] outputs = []
for doc in docs: max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
for doc in docs[0:max_iteration]:
chat_param = { chat_param = {
"chat_session_id": uuid.uuid1(), "chat_session_id": uuid.uuid1(),
"current_user_input": doc, "current_user_input": doc,
@@ -547,6 +552,9 @@ class KnowledgeService:
) )
) )
outputs.append(summary) outputs.append(summary)
logger.info(
f"iterator is {len(outputs)} current summary is :{summary}"
)
return outputs, summary return outputs, summary
def _mapreduce_extract_summary(self, docs): def _mapreduce_extract_summary(self, docs):
@@ -567,7 +575,8 @@ class KnowledgeService:
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param} ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)) ))
from pilot.common.chat_util import run_async_tasks from pilot.common.chat_util import run_async_tasks
summaries = run_async_tasks(tasks) summary_iters = run_async_tasks(tasks)
summary = self._llm_extract_summary(" ".join(summary_iters))
# from pilot.utils import utils # from pilot.utils import utils
# loop = utils.get_or_create_event_loop() # loop = utils.get_or_create_event_loop()
# summary = loop.run_until_complete( # summary = loop.run_until_complete(
@@ -576,6 +585,6 @@ class KnowledgeService:
# ) # )
# ) # )
# outputs.append(summary) # outputs.append(summary)
return summaries return summary