diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index cde2b7bb7..81e1dbdcc 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -429,14 +429,15 @@ class KnowledgeService:
         from llama_index import PromptHelper
         from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
 
         texts = [doc.page_content for doc in chunk_docs]
-        prompt_helper = PromptHelper()
+        prompt_helper = PromptHelper(context_window=5000)
         texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
+        logger.info(
+            f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
+        )
         summary = self._llm_extract_summary(texts[0])
         # summaries = self._mapreduce_extract_summary(texts)
         outputs, summary = self._refine_extract_summary(texts[1:], summary)
-        logger.info(
-            f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
-        )
+
         doc.summary = summary
         return knowledge_document_dao.update_knowledge_document(doc)
@@ -525,14 +526,18 @@ class KnowledgeService:
                 ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
             )
         )
+        logger.info(
+            f"initialize summary is :{summary}"
+        )
         return summary
-    def _refine_extract_summary(self, docs, summary: str):
+    def _refine_extract_summary(self, docs, summary: str, max_iteration: int = 5):
         """Extract refine summary by llm"""
         from pilot.scene.base import ChatScene
         from pilot.common.chat_util import llm_chat_response_nostream
         import uuid
         outputs = []
-        for doc in docs:
+        max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
+        for doc in docs[0:max_iteration]:
             chat_param = {
                 "chat_session_id": uuid.uuid1(),
                 "current_user_input": doc,
@@ -547,6 +552,9 @@ class KnowledgeService:
                 )
             )
             outputs.append(summary)
+            logger.info(
+                f"iterator is {len(outputs)} current summary is :{summary}"
+            )
         return outputs, summary
 
     def _mapreduce_extract_summary(self, docs):
@@ -567,7 +575,8 @@ class KnowledgeService:
                 ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
             ))
         from pilot.common.chat_util import run_async_tasks
-        summaries = run_async_tasks(tasks)
+        summary_iters = run_async_tasks(tasks)
+        summary = self._llm_extract_summary(" ".join(summary_iters))
         # from pilot.utils import utils
         # loop = utils.get_or_create_event_loop()
         # summary = loop.run_until_complete(
@@ -576,6 +585,6 @@ class KnowledgeService:
         #     )
         # )
         # outputs.append(summary)
-        return summaries
+        return summary
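
Reviewer note (not part of the patch): with this change, _refine_extract_summary folds the ExtractSummary LLM call over at most max_iteration chunks, seeded by the summary of the first chunk, and _mapreduce_extract_summary now reduces the per-chunk summaries with one final _llm_extract_summary pass instead of returning the raw list. A minimal self-contained sketch of the capped refine pattern, with a synchronous refine_fn standing in for the llm_chat_response_nostream(ChatScene.ExtractSummary.value(), ...) round trip; all names below are illustrative, not from the patch:

    from typing import Callable, List, Tuple

    def refine_summary(
        chunks: List[str],
        initial_summary: str,
        refine_fn: Callable[[str, str], str],  # (chunk, prior_summary) -> refined summary
        max_iteration: int = 5,
    ) -> Tuple[List[str], str]:
        """Fold refine_fn over at most max_iteration chunks, keeping intermediates."""
        summary = initial_summary
        outputs: List[str] = []
        # Mirrors docs[0:max_iteration] in the patch: cap the number of LLM rounds.
        for chunk in chunks[:max_iteration]:
            summary = refine_fn(chunk, summary)
            outputs.append(summary)
        return outputs, summary

Capping the refine loop bounds the number of sequential LLM calls (and so latency and cost) on long documents, at the price of ignoring chunks beyond the cap; the map-reduce path avoids that trade-off by summarizing chunks concurrently via run_async_tasks and reducing the results once.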