diff --git a/pilot/common/chat_util.py b/pilot/common/chat_util.py
index 0de0b9bda..ae0ce73ed 100644
--- a/pilot/common/chat_util.py
+++ b/pilot/common/chat_util.py
@@ -1,4 +1,5 @@
 import asyncio
+from typing import Coroutine, List, Any
 
 from starlette.responses import StreamingResponse
 
@@ -18,3 +19,37 @@ async def llm_chat_response_nostream(chat_scene: str, **chat_param):
 async def llm_chat_response(chat_scene: str, **chat_param):
     chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
     return chat.stream_call()
+
+
+def run_async_tasks(
+    tasks: List[Coroutine],
+    show_progress: bool = False,
+    progress_bar_desc: str = "Running async tasks",
+) -> List[Any]:
+    """Run a list of async tasks."""
+
+    tasks_to_execute: List[Any] = tasks
+    if show_progress:
+        try:
+            import nest_asyncio
+            from tqdm.asyncio import tqdm
+
+            nest_asyncio.apply()
+            loop = asyncio.get_event_loop()
+
+            async def _tqdm_gather() -> List[Any]:
+                return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)
+
+            tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
+            return tqdm_outputs
+        # fall back to a plain gather without tqdm if anything above fails;
+        # this may happen in environments where tqdm.asyncio or
+        # nest_asyncio is not supported
+        except Exception:
+            pass
+
+    async def _gather() -> List[Any]:
+        return await asyncio.gather(*tasks_to_execute)
+
+    outputs: List[Any] = asyncio.run(_gather())
+    return outputs
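For reviewers, a minimal usage sketch of the new run_async_tasks helper. The fake_llm_call coroutine and its inputs are invented for illustration; only the helper itself comes from this change.

import asyncio

from pilot.common.chat_util import run_async_tasks


async def fake_llm_call(chunk: str) -> str:
    # stand-in for an async LLM request such as llm_chat_response_nostream
    await asyncio.sleep(0.1)
    return f"summary of: {chunk}"


# results are returned in the same order as the submitted coroutines
coros = [fake_llm_call(text) for text in ["chunk one", "chunk two"]]
results = run_async_tasks(coros, show_progress=False)

Note that the no-progress path uses asyncio.run, so the helper is meant to be called from synchronous code without an already-running event loop.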
diff --git a/pilot/scene/chat_knowledge/refine_summary/__init__.py b/pilot/scene/chat_knowledge/refine_summary/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_knowledge/refine_summary/chat.py b/pilot/scene/chat_knowledge/refine_summary/chat.py
new file mode 100644
index 000000000..b3a934dd5
--- /dev/null
+++ b/pilot/scene/chat_knowledge/refine_summary/chat.py
@@ -0,0 +1,37 @@
+from typing import Dict
+
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.base import ChatScene
+from pilot.configs.config import Config
+
+from pilot.scene.chat_knowledge.refine_summary.prompt import prompt
+
+CFG = Config()
+
+
+class ExtractRefineSummary(BaseChat):
+    chat_scene: str = ChatScene.ExtractRefineSummary.value()
+
+    """refine an existing summary with new context by llm"""
+
+    def __init__(self, chat_param: Dict):
+        """ """
+        chat_param["chat_mode"] = ChatScene.ExtractRefineSummary
+        super().__init__(
+            chat_param=chat_param,
+        )
+
+        self.user_input = chat_param["current_user_input"]
+        self.existing_answer = chat_param["select_param"]
+        # self.extract_mode = chat_param["select_param"]
+
+    def generate_input_values(self):
+        input_values = {
+            "context": self.user_input,
+            "existing_answer": self.existing_answer,
+        }
+        return input_values
+
+    @property
+    def chat_type(self) -> str:
+        return ChatScene.ExtractRefineSummary.value()
diff --git a/pilot/scene/chat_knowledge/refine_summary/out_parser.py b/pilot/scene/chat_knowledge/refine_summary/out_parser.py
new file mode 100644
index 000000000..104419e88
--- /dev/null
+++ b/pilot/scene/chat_knowledge/refine_summary/out_parser.py
@@ -0,0 +1,57 @@
+import json
+import logging
+import re
+from typing import List, Tuple
+
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.config import Config
+
+CFG = Config()
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExtractRefineSummaryParser(BaseOutputParser):
+    def __init__(self, sep: str, is_stream_out: bool):
+        super().__init__(sep=sep, is_stream_out=is_stream_out)
+
+    def parse_prompt_response(
+        self, response, max_length: int = 128
+    ) -> str:
+        # clean_str = super().parse_prompt_response(response)
+        print("clean prompt response:", response)
+
+        # if response.startswith("Triplets:"):
+        #     response = response[len("Triplets:") :]
+        # pattern = r"\([^()]+\)"
+        # response = re.findall(pattern, response)
+        # # response = response.strip().split("\n")
+        # print("parse prompt response:", response)
+        # results = []
+        # for text in response:
+        #     if not text or text[0] != "(" or text[-1] != ")":
+        #         # skip empty lines and non-triplets
+        #         continue
+        #     tokens = text[1:-1].split(",")
+        #     if len(tokens) != 3:
+        #         continue
+        #
+        #     if any(len(s.encode("utf-8")) > max_length for s in tokens):
+        #         # We count byte-length instead of len() for UTF-8 chars,
+        #         # will skip if any of the tokens are too long.
+        #         # This is normally due to a poorly formatted triplet
+        #         # extraction, in more serious KG building cases
+        #         # we'll need NLP models to better extract triplets.
+        #         continue
+        #
+        #     subject, predicate, obj = map(str.strip, tokens)
+        #     if not subject or not predicate or not obj:
+        #         # skip partial triplets
+        #         continue
+        #     results.append((subject.lower(), predicate.lower(), obj.lower()))
+        return response
+
+    def parse_view_response(self, speak, data) -> str:
+        # return the summary text to the view layer as-is
+        return data
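A rough sketch of how the new ExtractRefineSummary scene is expected to be driven through llm_chat_response_nostream. The refine_once wrapper and its arguments are illustrative only; the real call site is KnowledgeService._refine_extract_summary further down, and running this requires a configured model service.

import asyncio
import uuid

from pilot.common.chat_util import llm_chat_response_nostream
from pilot.configs.config import Config
from pilot.scene.base import ChatScene

CFG = Config()


async def refine_once(existing_summary: str, new_context: str) -> str:
    # "current_user_input" carries the new chunk, "select_param" carries the
    # summary so far, matching ExtractRefineSummary.__init__ above
    chat_param = {
        "chat_session_id": uuid.uuid1(),
        "current_user_input": new_context,
        "select_param": existing_summary,
        "model_name": CFG.LLM_MODEL,
    }
    return await llm_chat_response_nostream(
        ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
    )


refined = asyncio.run(refine_once("summary so far", "next chunk of text"))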
diff --git a/pilot/scene/chat_knowledge/refine_summary/prompt.py b/pilot/scene/chat_knowledge/refine_summary/prompt.py
new file mode 100644
index 000000000..0161cee35
--- /dev/null
+++ b/pilot/scene/chat_knowledge/refine_summary/prompt.py
@@ -0,0 +1,40 @@
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+
+from pilot.scene.chat_knowledge.refine_summary.out_parser import ExtractRefineSummaryParser
+
+CFG = Config()
+
+
+PROMPT_SCENE_DEFINE = """Your job is to produce a final summary."""
+
+_DEFAULT_TEMPLATE = """
+We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary.\nIf the context isn't useful, return the original summary.
+
+Please write the summary in the same language as the original context.
+"""
+PROMPT_RESPONSE = """"""
+
+
+RESPONSE_FORMAT = """"""
+
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+prompt = PromptTemplate(
+    template_scene=ChatScene.ExtractRefineSummary.value(),
+    input_variables=["existing_answer", "context"],
+    response_format="",
+    template_define=PROMPT_SCENE_DEFINE,
+    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
+    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+    output_parser=ExtractRefineSummaryParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
+)
+
+CFG.prompt_template_registry.register(prompt, is_default=True)
diff --git a/pilot/scene/chat_knowledge/summary/__init__.py b/pilot/scene/chat_knowledge/summary/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_knowledge/summary/chat.py b/pilot/scene/chat_knowledge/summary/chat.py
new file mode 100644
index 000000000..f887bde82
--- /dev/null
+++ b/pilot/scene/chat_knowledge/summary/chat.py
@@ -0,0 +1,35 @@
+from typing import Dict
+
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.base import ChatScene
+from pilot.configs.config import Config
+
+from pilot.scene.chat_knowledge.summary.prompt import prompt
+
+CFG = Config()
+
+
+class ExtractSummary(BaseChat):
+    chat_scene: str = ChatScene.ExtractSummary.value()
+
+    """get summary by llm"""
+
+    def __init__(self, chat_param: Dict):
+        """ """
+        chat_param["chat_mode"] = ChatScene.ExtractSummary
+        super().__init__(
+            chat_param=chat_param,
+        )
+
+        self.user_input = chat_param["current_user_input"]
+        # self.extract_mode = chat_param["select_param"]
+
+    def generate_input_values(self):
+        input_values = {
+            "context": self.user_input,
+        }
+        return input_values
+
+    @property
+    def chat_type(self) -> str:
+        return ChatScene.ExtractSummary.value()
diff --git a/pilot/scene/chat_knowledge/summary/out_parser.py b/pilot/scene/chat_knowledge/summary/out_parser.py
new file mode 100644
index 000000000..5626d0d4a
--- /dev/null
+++ b/pilot/scene/chat_knowledge/summary/out_parser.py
@@ -0,0 +1,28 @@
+import json
+import logging
+import re
+from typing import List, Tuple
+
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.config import Config
+
+CFG = Config()
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExtractSummaryParser(BaseOutputParser):
+    def __init__(self, sep: str, is_stream_out: bool):
+        super().__init__(sep=sep, is_stream_out=is_stream_out)
+
+    def parse_prompt_response(
+        self, response, max_length: int = 128
+    ) -> str:
+        # clean_str = super().parse_prompt_response(response)
+        print("clean prompt response:", response)
+        return response
+
+    def parse_view_response(self, speak, data) -> str:
+        # return the summary text to the view layer as-is
+        return data
diff --git a/pilot/scene/chat_knowledge/summary/prompt.py b/pilot/scene/chat_knowledge/summary/prompt.py
new file mode 100644
index 000000000..cbf452c99
--- /dev/null
+++ b/pilot/scene/chat_knowledge/summary/prompt.py
@@ -0,0 +1,47 @@
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+
+from pilot.scene.chat_knowledge.summary.out_parser import ExtractSummaryParser
+
+CFG = Config()
+
+# PROMPT_SCENE_DEFINE = """You are an expert Q&A system that is trusted around the world.\nAlways answer the query using the provided context information, and not prior knowledge.\nSome rules to follow:\n1. Never directly reference the given context in your answer.\n2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines."""
+
+PROMPT_SCENE_DEFINE = """Your job is to produce a final summary."""
+
+# _DEFAULT_TEMPLATE = """
+# Context information from multiple sources is below.\n---------------------\n
+# {context}
+# Given the information from multiple sources and not prior knowledge, answer the query.\nQuery: Describe what the provided text is about. Also describe some of the questions that this text can answer. \nAnswer: "
+# """
+
+_DEFAULT_TEMPLATE = """
+Write a concise summary of the following context:
+{context}
+Please write the summary in the same language as the original context.
+"""
+PROMPT_RESPONSE = """"""


+RESPONSE_FORMAT = """"""


+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+prompt = PromptTemplate(
+    template_scene=ChatScene.ExtractSummary.value(),
+    input_variables=["context"],
+    response_format="",
+    template_define=PROMPT_SCENE_DEFINE,
+    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
+    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+    output_parser=ExtractSummaryParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
+)
+
+CFG.prompt_template_registry.register(prompt, is_default=True)
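To make the prompt wiring easier to review, here is a rough illustration of the text each _DEFAULT_TEMPLATE produces once its variables are filled in. The context strings are invented, and it is assumed these modules can be imported outside the running server (they only build and register templates at import time).

from pilot.scene.chat_knowledge.summary.prompt import _DEFAULT_TEMPLATE as SUMMARY_TEMPLATE
from pilot.scene.chat_knowledge.refine_summary.prompt import (
    _DEFAULT_TEMPLATE as REFINE_TEMPLATE,
)

# map step: one concise-summary prompt per repacked chunk
print(SUMMARY_TEMPLATE.format(context="some chunk of document text ..."))

# refine step: fold the next chunk into the summary produced so far
print(
    REFINE_TEMPLATE.format(
        existing_answer="summary built from the earlier chunks ...",
        context="the next chunk of document text ...",
    )
)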
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 017fef3ec..cde2b7bb7 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -288,7 +288,7 @@ class KnowledgeService:
         executor = CFG.SYSTEM_APP.get_component(
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
         ).create()
-        executor.submit(self.async_document_summary, chunk_docs, doc)
+        # executor.submit(self.async_document_summary, chunk_docs, doc)
         executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
         logger.info(f"begin save document chunks, doc:{doc.doc_name}")
         # save chunk details
@@ -431,7 +431,8 @@ class KnowledgeService:
         texts = [doc.page_content for doc in chunk_docs]
         prompt_helper = PromptHelper()
         texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
-        summary = self._llm_extract_summary(chunk_docs[0])
+        summary = self._llm_extract_summary(texts[0])
+        # summaries = self._mapreduce_extract_summary(texts)
         outputs, summary = self._refine_extract_summary(texts[1:], summary)
         logger.info(
             f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
@@ -452,6 +453,7 @@ class KnowledgeService:
         )
         try:
             vector_ids = client.knowledge_embedding_batch(chunk_docs)
+            self.async_document_summary(chunk_docs, doc)
             doc.status = SyncStatus.FINISHED.name
             doc.result = "document embedding success"
             if vector_ids is not None:
@@ -512,9 +514,9 @@ class KnowledgeService:
 
         chat_param = {
             "chat_session_id": uuid.uuid1(),
-            "current_user_input": doc.page_content,
+            "current_user_input": doc,
             "select_param": "summary",
-            "model_name": "proxyllm",
+            "model_name": CFG.LLM_MODEL,
         }
         from pilot.utils import utils
         loop = utils.get_or_create_event_loop()
@@ -535,7 +537,7 @@ class KnowledgeService:
                 "chat_session_id": uuid.uuid1(),
                 "current_user_input": doc,
                 "select_param": summary,
-                "model_name": "proxyllm",
+                "model_name": CFG.LLM_MODEL,
             }
             from pilot.utils import utils
             loop = utils.get_or_create_event_loop()
@@ -547,4 +549,33 @@ class KnowledgeService:
             outputs.append(summary)
         return outputs, summary
 
 
+    def _mapreduce_extract_summary(self, docs):
+        """Extract a summary for each doc chunk concurrently (map step) by llm"""
+        from pilot.scene.base import ChatScene
+        from pilot.common.chat_util import llm_chat_response_nostream
+        import uuid
+        outputs = []
+        tasks = []
+        for doc in docs:
+            chat_param = {
+                "chat_session_id": uuid.uuid1(),
+                "current_user_input": doc,
+                "select_param": "summary",
+                "model_name": CFG.LLM_MODEL,
+            }
+            tasks.append(llm_chat_response_nostream(
+                ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
+            ))
+        from pilot.common.chat_util import run_async_tasks
+        summaries = run_async_tasks(tasks)
+        # from pilot.utils import utils
+        # loop = utils.get_or_create_event_loop()
+        # summary = loop.run_until_complete(
+        #     llm_chat_response_nostream(
+        #         ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
+        #     )
+        # )
+        # outputs.append(summary)
+        return summaries
+
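Taken together, a rough sketch of how the summary pieces compose inside KnowledgeService. The texts list stands in for the repacked chunks produced by PromptHelper.repack, both paths need a configured LLM service, and the map-reduce call is still commented out in async_document_summary above.

from pilot.server.knowledge.service import KnowledgeService

service = KnowledgeService()
texts = ["first repacked chunk ...", "second repacked chunk ...", "third repacked chunk ..."]

# current path: summarize the first chunk, then refine with the remaining chunks
summary = service._llm_extract_summary(texts[0])
outputs, summary = service._refine_extract_summary(texts[1:], summary)

# alternative map step added in this change: summarize every chunk concurrently
# via run_async_tasks; combining the per-chunk summaries is left to the caller
chunk_summaries = service._mapreduce_extract_summary(texts)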