feat: add summary

aries_ckt 2023-10-31 13:47:19 +08:00
parent 07ad8fac67
commit dca3ddb931
10 changed files with 315 additions and 5 deletions

pilot/common/chat_util.py (View File)

@@ -1,4 +1,5 @@
import asyncio
from typing import Coroutine, List, Any
from starlette.responses import StreamingResponse
@@ -18,3 +19,37 @@ async def llm_chat_response_nostream(chat_scene: str, **chat_param):
async def llm_chat_response(chat_scene: str, **chat_param):
    chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    return chat.stream_call()


def run_async_tasks(
    tasks: List[Coroutine],
    show_progress: bool = False,
    progress_bar_desc: str = "Running async tasks",
) -> List[Any]:
    """Run a list of async tasks."""
    tasks_to_execute: List[Any] = tasks
    if show_progress:
        try:
            import nest_asyncio
            from tqdm.asyncio import tqdm

            nest_asyncio.apply()
            loop = asyncio.get_event_loop()

            async def _tqdm_gather() -> List[Any]:
                return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)

            tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
            return tqdm_outputs
        except Exception:
            # Fall back to running without tqdm; tqdm.asyncio (or nest_asyncio)
            # may be unavailable or unsupported in some environments.
            pass

    async def _gather() -> List[Any]:
        return await asyncio.gather(*tasks_to_execute)

    outputs: List[Any] = asyncio.run(_gather())
    return outputs
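
For orientation, a minimal usage sketch of run_async_tasks; the _double coroutine and values are illustrative, not part of this commit. Note that it must be called from synchronous code, since it starts its own event loop:

import asyncio
from typing import List

from pilot.common.chat_util import run_async_tasks

async def _double(x: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for real async work, e.g. an LLM call
    return x * 2

# gather() preserves task order, so results line up with the inputs.
results: List[int] = run_async_tasks([_double(i) for i in range(3)])
assert results == [0, 2, 4]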

pilot/scene/chat_knowledge/refine_summary/chat.py (View File)

@@ -0,0 +1,37 @@
from typing import Dict

from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.scene.chat_knowledge.refine_summary.prompt import prompt

CFG = Config()


class ExtractRefineSummary(BaseChat):
    """Refine an existing summary with new context via the LLM."""

    chat_scene: str = ChatScene.ExtractRefineSummary.value()

    def __init__(self, chat_param: Dict):
        chat_param["chat_mode"] = ChatScene.ExtractRefineSummary
        super().__init__(
            chat_param=chat_param,
        )
        self.user_input = chat_param["current_user_input"]
        self.existing_answer = chat_param["select_param"]
        # self.extract_mode = chat_param["select_param"]

    def generate_input_values(self):
        input_values = {
            "context": self.user_input,
            "existing_answer": self.existing_answer,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.ExtractRefineSummary.value()
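
A hedged sketch of how this scene is driven, mirroring the KnowledgeService changes later in this commit (the input strings are illustrative):

import uuid

from pilot.common.chat_util import llm_chat_response_nostream
from pilot.scene.base import ChatScene

chat_param = {
    "chat_session_id": uuid.uuid1(),
    "current_user_input": "next chunk of the document",  # becomes {context}
    "select_param": "the summary produced so far",  # becomes {existing_answer}
    "model_name": CFG.LLM_MODEL,
}
# Awaited from an event loop, as KnowledgeService does:
# refined = await llm_chat_response_nostream(
#     ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
# )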

pilot/scene/chat_knowledge/refine_summary/out_parser.py (View File)

@@ -0,0 +1,57 @@
import json
import logging
import re
from typing import List, Tuple

from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.config import Config

CFG = Config()

logger = logging.getLogger(__name__)


class ExtractRefineSummaryParser(BaseOutputParser):
    def __init__(self, sep: str, is_stream_out: bool):
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, response, max_length: int = 128) -> str:
        # clean_str = super().parse_prompt_response(response)
        logger.info(f"clean prompt response: {response}")
        # Legacy triplet-parsing logic, kept commented out for reference:
        # if response.startswith("Triplets:"):
        #     response = response[len("Triplets:") :]
        # pattern = r"\([^()]+\)"
        # response = re.findall(pattern, response)
        # # response = response.strip().split("\n")
        # print("parse prompt response:", response)
        # results = []
        # for text in response:
        #     if not text or text[0] != "(" or text[-1] != ")":
        #         # skip empty lines and non-triplets
        #         continue
        #     tokens = text[1:-1].split(",")
        #     if len(tokens) != 3:
        #         continue
        #
        #     if any(len(s.encode("utf-8")) > max_length for s in tokens):
        #         # We count byte-length instead of len() for UTF-8 chars,
        #         # will skip if any of the tokens are too long.
        #         # This is normally due to a poorly formatted triplet
        #         # extraction, in more serious KG building cases
        #         # we'll need NLP models to better extract triplets.
        #         continue
        #
        #     subject, predicate, obj = map(str.strip, tokens)
        #     if not subject or not predicate or not obj:
        #         # skip partial triplets
        #         continue
        #     results.append((subject.lower(), predicate.lower(), obj.lower()))
        return response

    def parse_view_response(self, speak, data) -> str:
        # Return the summary data directly for the view layer.
        return data
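
With the triplet logic commented out, the parser is effectively a pass-through; a quick illustrative check (the sep value here is arbitrary):

parser = ExtractRefineSummaryParser(sep="\n", is_stream_out=False)
assert parser.parse_prompt_response("a refined summary") == "a refined summary"
assert parser.parse_view_response(None, "a refined summary") == "a refined summary"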

pilot/scene/chat_knowledge/refine_summary/prompt.py (View File)

@@ -0,0 +1,40 @@
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_knowledge.refine_summary.out_parser import ExtractRefineSummaryParser

CFG = Config()

PROMPT_SCENE_DEFINE = """Your job is to produce a final summary."""

_DEFAULT_TEMPLATE = """
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary.\nIf the context isn't useful, return the original summary.
Please answer in the original language of the context.
"""

PROMPT_RESPONSE = """"""

RESPONSE_FORMAT = """"""

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ExtractRefineSummary.value(),
    input_variables=["existing_answer", "context"],
    response_format="",
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=ExtractRefineSummaryParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
    ),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
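
For reference, a sketch of what the rendered refine template looks like; this is plain str.format with sample values, not part of the commit:

print(
    _DEFAULT_TEMPLATE.format(
        existing_answer="The document introduces a knowledge-base chat framework.",
        context="A new section explains how document chunks are summarized.",
    )
)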

pilot/scene/chat_knowledge/summary/chat.py (View File)

@@ -0,0 +1,35 @@
from typing import Dict

from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.scene.chat_knowledge.summary.prompt import prompt

CFG = Config()


class ExtractSummary(BaseChat):
    """Get a summary of the input context via the LLM."""

    chat_scene: str = ChatScene.ExtractSummary.value()

    def __init__(self, chat_param: Dict):
        chat_param["chat_mode"] = ChatScene.ExtractSummary
        super().__init__(
            chat_param=chat_param,
        )
        self.user_input = chat_param["current_user_input"]
        # self.extract_mode = chat_param["select_param"]

    def generate_input_values(self):
        input_values = {
            "context": self.user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.ExtractSummary.value()

pilot/scene/chat_knowledge/summary/out_parser.py (View File)

@@ -0,0 +1,28 @@
import json
import logging
import re
from typing import List, Tuple

from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.config import Config

CFG = Config()

logger = logging.getLogger(__name__)


class ExtractSummaryParser(BaseOutputParser):
    def __init__(self, sep: str, is_stream_out: bool):
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, response, max_length: int = 128) -> str:
        # The summary is passed through unchanged; no structured parsing needed.
        logger.info(f"clean prompt response: {response}")
        return response

    def parse_view_response(self, speak, data) -> str:
        # Return the summary data directly for the view layer.
        return data

pilot/scene/chat_knowledge/summary/prompt.py (View File)

@@ -0,0 +1,47 @@
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_knowledge.summary.out_parser import ExtractSummaryParser

CFG = Config()

# PROMPT_SCENE_DEFINE = """You are an expert Q&A system that is trusted around the world.\nAlways answer the query using the provided context information, and not prior knowledge.\nSome rules to follow:\n1. Never directly reference the given context in your answer.\n2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines."""

PROMPT_SCENE_DEFINE = """Your job is to produce a final summary."""

# _DEFAULT_TEMPLATE = """
# Context information from multiple sources is below.\n---------------------\n
# {context}
# Given the information from multiple sources and not prior knowledge, answer the query.\nQuery: Describe what the provided text is about. Also describe some of the questions that this text can answer. \nAnswer: "
# """

_DEFAULT_TEMPLATE = """
Write a concise summary of the following context:
{context}
Please answer in the original language of the context.
"""

PROMPT_RESPONSE = """"""

RESPONSE_FORMAT = """"""

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ExtractSummary.value(),
    input_variables=["context"],
    response_format="",
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=ExtractSummaryParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
    ),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
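
Similarly, an approximation of the text this template yields for one chunk; the real prompt assembly happens inside PromptTemplate and BaseChat, so this is a sketch only:

sample_prompt = PROMPT_SCENE_DEFINE + (_DEFAULT_TEMPLATE + PROMPT_RESPONSE).format(
    context="First chunk of the ingested document ..."
)
print(sample_prompt)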

pilot/server/knowledge/service.py (View File)

@@ -288,7 +288,7 @@ class KnowledgeService:
        executor = CFG.SYSTEM_APP.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()
-       executor.submit(self.async_document_summary, chunk_docs, doc)
+       # executor.submit(self.async_document_summary, chunk_docs, doc)
        executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
        logger.info(f"begin save document chunks, doc:{doc.doc_name}")
        # save chunk details
@@ -431,7 +431,8 @@ class KnowledgeService:
        texts = [doc.page_content for doc in chunk_docs]
        prompt_helper = PromptHelper()
        texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
-       summary = self._llm_extract_summary(chunk_docs[0])
+       summary = self._llm_extract_summary(texts[0])
+       # summaries = self._mapreduce_extract_summary(texts)
        outputs, summary = self._refine_extract_summary(texts[1:], summary)
        logger.info(
            f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
@@ -452,6 +453,7 @@ class KnowledgeService:
        )
        try:
            vector_ids = client.knowledge_embedding_batch(chunk_docs)
+           self.async_document_summary(chunk_docs, doc)
            doc.status = SyncStatus.FINISHED.name
            doc.result = "document embedding success"
            if vector_ids is not None:
@@ -512,9 +514,9 @@ class KnowledgeService:
        chat_param = {
            "chat_session_id": uuid.uuid1(),
-           "current_user_input": doc.page_content,
+           "current_user_input": doc,
            "select_param": "summary",
-           "model_name": "proxyllm",
+           "model_name": CFG.LLM_MODEL,
        }
        from pilot.utils import utils

        loop = utils.get_or_create_event_loop()
@@ -535,7 +537,7 @@ class KnowledgeService:
                "chat_session_id": uuid.uuid1(),
                "current_user_input": doc,
                "select_param": summary,
-               "model_name": "proxyllm",
+               "model_name": CFG.LLM_MODEL,
            }
            from pilot.utils import utils

            loop = utils.get_or_create_event_loop()
@@ -547,4 +549,33 @@ class KnowledgeService:
            outputs.append(summary)
        return outputs, summary

    def _mapreduce_extract_summary(self, docs):
        """Extract a summary for each chunk concurrently (map step) by LLM."""
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        outputs = []
        tasks = []
        for doc in docs:
            chat_param = {
                "chat_session_id": uuid.uuid1(),
                "current_user_input": doc,
                "select_param": "summary",
                "model_name": CFG.LLM_MODEL,
            }
            tasks.append(
                llm_chat_response_nostream(
                    ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
                )
            )
        from pilot.common.chat_util import run_async_tasks

        summaries = run_async_tasks(tasks)
        # from pilot.utils import utils
        # loop = utils.get_or_create_event_loop()
        # summary = loop.run_until_complete(
        #     llm_chat_response_nostream(
        #         ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
        #     )
        # )
        # outputs.append(summary)
        return summaries
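
The commit keeps both summarization strategies side by side: async_document_summary seeds a summary from the first chunk and folds the remaining chunks in sequentially (refine), while _mapreduce_extract_summary summarizes all chunks concurrently through run_async_tasks (map). A hypothetical composition of the two, assuming the per-chunk summaries are plain strings (not part of this commit):

# Inside KnowledgeService, sketch only:
# map: one summary per repacked chunk, produced concurrently
summaries = self._mapreduce_extract_summary(texts)
# reduce: fold the per-chunk summaries into a single final summary
_, final_summary = self._refine_extract_summary(summaries[1:], summaries[0])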