feat(awel): New AWEL RAG example

FangYin Cheng 2023-11-21 14:33:56 +08:00
parent e67d62a785
commit 1801138b62
16 changed files with 548 additions and 179 deletions

View File

@@ -0,0 +1,70 @@
"""AWEL: Simple RAG example.

    Example:

    .. code-block:: shell

        curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag \
        -H "Content-Type: application/json" -d '{
            "conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
            "model_name": "proxyllm",
            "chat_mode": "chat_knowledge",
            "user_input": "What is DB-GPT?",
            "select_param": "default"
        }'
"""
from pilot.awel import HttpTrigger, DAG, MapOperator
from pilot.scene.operator._experimental import (
    ChatContext,
    PromptManagerOperator,
    ChatHistoryStorageOperator,
    ChatHistoryOperator,
    EmbeddingEngingOperator,
    BaseChatOperator,
)
from pilot.scene.base import ChatScene
from pilot.openapi.api_view_model import ConversationVo
from pilot.model.base import ModelOutput
from pilot.model.operator.model_operator import ModelOperator


class RequestParseOperator(MapOperator[ConversationVo, ChatContext]):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ConversationVo) -> ChatContext:
        # Convert the HTTP request body into the internal chat context.
        return ChatContext(
            current_user_input=input_value.user_input,
            model_name=input_value.model_name,
            chat_session_id=input_value.conv_uid,
            select_param=input_value.select_param,
            chat_scene=ChatScene.ChatKnowledge,
        )


with DAG("simple_rag_example") as dag:
    trigger_task = HttpTrigger(
        "/examples/simple_rag", methods="POST", request_body=ConversationVo
    )
    req_parse_task = RequestParseOperator()
    prompt_task = PromptManagerOperator()
    history_storage_task = ChatHistoryStorageOperator()
    history_task = ChatHistoryOperator()
    embedding_task = EmbeddingEngingOperator()
    chat_task = BaseChatOperator()
    model_task = ModelOperator()
    # Extract the generated text from the model output.
    output_parser_task = MapOperator(lambda out: out.to_dict()["text"])
    (
        trigger_task
        >> req_parse_task
        >> prompt_task
        >> history_storage_task
        >> history_task
        >> embedding_task
        >> chat_task
        >> model_task
        >> output_parser_task
    )
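
For reference, the trigger defined above can also be exercised from Python rather than curl; a minimal client sketch (assuming the webserver is listening on 127.0.0.1:5000 as in the docstring, and that the `requests` package is installed):

import requests

# Hypothetical client for the /examples/simple_rag trigger defined above.
payload = {
    "conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
    "model_name": "proxyllm",
    "chat_mode": "chat_knowledge",
    "user_input": "What is DB-GPT?",
    "select_param": "default",
}
resp = requests.post(
    "http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag",
    json=payload,
)
# The response body is the text extracted by output_parser_task.
print(resp.text)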

View File

@@ -7,6 +7,7 @@ import asyncio
import logging
from collections import deque
from functools import cache
+from concurrent.futures import Executor

from pilot.component import SystemApp
from ..resource.base import ResourceGroup
@@ -102,6 +103,7 @@ class DAGVar:
    _thread_local = threading.local()
    _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
    _system_app: SystemApp = None
+    _executor: Executor = None

    @classmethod
    def enter_dag(cls, dag) -> None:
@@ -157,6 +159,14 @@
        else:
            cls._system_app = system_app

+    @classmethod
+    def get_executor(cls) -> Executor:
+        return cls._executor
+
+    @classmethod
+    def set_executor(cls, executor: Executor) -> None:
+        cls._executor = executor
+

class DAGNode(DependencyMixin, ABC):
    resource_group: Optional[ResourceGroup] = None
@@ -165,9 +175,10 @@ class DAGNode(DependencyMixin, ABC):
    def __init__(
        self,
        dag: Optional["DAG"] = None,
-        node_id: str = None,
-        node_name: str = None,
-        system_app: SystemApp = None,
+        node_id: Optional[str] = None,
+        node_name: Optional[str] = None,
+        system_app: Optional[SystemApp] = None,
+        executor: Optional[Executor] = None,
    ) -> None:
        super().__init__()
        self._upstream: List["DAGNode"] = []
@@ -176,6 +187,7 @@ class DAGNode(DependencyMixin, ABC):
        self._system_app: Optional[SystemApp] = (
            system_app or DAGVar.get_current_system_app()
        )
+        self._executor: Optional[Executor] = executor or DAGVar.get_executor()
        if not node_id and self._dag:
            node_id = self._dag._new_node_id()
        self._node_id: str = node_id
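
Since `DAGVar` now carries a process-wide default `Executor`, a custom pool can be installed before any nodes are constructed; a minimal sketch using only the API added in this hunk (the pool size is an arbitrary illustration):

from concurrent.futures import ThreadPoolExecutor

# Nodes created afterwards pick this up via `executor or DAGVar.get_executor()`.
DAGVar.set_executor(ThreadPoolExecutor(max_workers=4))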

View File

@@ -14,7 +14,13 @@ from typing import (
)
import functools
from inspect import signature

-from pilot.component import SystemApp
+from pilot.component import SystemApp, ComponentType
+from pilot.utils.executor_utils import (
+    ExecutorFactory,
+    DefaultExecutorFactory,
+    blocking_func_to_async,
+    BlockingFunction,
+)

from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import (
@@ -71,6 +77,16 @@ class BaseOperatorMeta(ABCMeta):
        system_app: Optional[SystemApp] = (
            kwargs.get("system_app") or DAGVar.get_current_system_app()
        )
+        executor = kwargs.get("executor") or DAGVar.get_executor()
+        if not executor:
+            if system_app:
+                executor = system_app.get_component(
+                    ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
+                ).create()
+            else:
+                executor = DefaultExecutorFactory().create()
+            DAGVar.set_executor(executor)
        if not task_id and dag:
            task_id = dag._new_node_id()
        runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
@@ -86,6 +102,8 @@ class BaseOperatorMeta(ABCMeta):
        kwargs["runner"] = runner
        if not kwargs.get("system_app"):
            kwargs["system_app"] = system_app
+        if not kwargs.get("executor"):
+            kwargs["executor"] = executor
        real_obj = func(self, *args, **kwargs)
        return real_obj
@@ -177,6 +195,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
        out_ctx = await self._runner.execute_workflow(self, call_data)
        return out_ctx.current_task_context.task_output.output_stream

+    async def blocking_func_to_async(
+        self, func: BlockingFunction, *args, **kwargs
+    ) -> Any:
+        return await blocking_func_to_async(self._executor, func, *args, **kwargs)
+

def initialize_runner(runner: WorkflowRunner):
    global default_runner
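
The `blocking_func_to_async` helper imported from `pilot.utils.executor_utils` is not shown in this diff; a plausible implementation, sketched here only to illustrate the intent (assuming the helper simply bridges a blocking call onto the supplied `Executor`):

import asyncio
from concurrent.futures import Executor
from typing import Any, Callable

BlockingFunction = Callable[..., Any]

async def blocking_func_to_async(
    executor: Executor, func: BlockingFunction, *args, **kwargs
) -> Any:
    # Run the blocking call in the executor so the event loop stays responsive.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, lambda: func(*args, **kwargs))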

View File

@@ -67,7 +67,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
                node_outputs[node.node_id] = task_ctx
                return
            try:
-                logger.info(
+                logger.debug(
                    f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
                )
                await node._run(dag_ctx)
@@ -76,7 +76,7 @@ class DefaultWorkflowRunner(WorkflowRunner):

            if isinstance(node, BranchOperator):
                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
-                logger.info(
+                logger.debug(
                    f"Current is branch operator, skip node names: {skip_nodes}"
                )
                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
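
These per-operator messages now log at DEBUG, so they disappear under a default INFO configuration; during development they can be surfaced with the standard library alone (not specific to this commit):

import logging

# Show DEBUG output from all loggers, including the workflow runner.
logging.basicConfig(level=logging.DEBUG)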

View File

@@ -47,7 +47,7 @@ class DbHistoryMemory(BaseChatHistoryMemory):
            logger.error("init create conversation log error" + str(e))

    def append(self, once_message: OnceConversation) -> None:
-        logger.info(f"db history append: {once_message}")
+        logger.debug(f"db history append: {once_message}")
        chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
            self.chat_seesion_id
        )

View File

@@ -143,9 +143,7 @@ def _build_request(model: ProxyModel, params):
    proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
    payloads["model"] = proxyllm_backend

-    logger.info(
-        f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
-    )
+    logger.info(f"Send request to real model {proxyllm_backend}")

    return history, payloads

View File

@@ -68,7 +68,7 @@ class BaseChat(ABC):
            CFG.prompt_template_registry.get_prompt_template(
                self.chat_mode.value(),
                language=CFG.LANGUAGE,
-                model_name=CFG.LLM_MODEL,
+                model_name=self.llm_model,
                proxyllm_backend=CFG.PROXYLLM_BACKEND,
            )
        )
@@ -141,13 +141,7 @@ class BaseChat(ABC):
        return speak_to_user

    async def __call_base(self):
-        import inspect
-
-        input_values = (
-            await self.generate_input_values()
-            if inspect.isawaitable(self.generate_input_values())
-            else self.generate_input_values()
-        )
+        input_values = await self.generate_input_values()
        ### Chat sequence advance
        self.current_message.chat_order = len(self.history_message) + 1
        self.current_message.add_user_message(self.current_user_input)
@@ -379,16 +373,18 @@ class BaseChat(ABC):
        if self.prompt_template.template_define:
            text += self.prompt_template.template_define + self.prompt_template.sep
        ### Load prompt
-        text += self.__load_system_message()
+        text += _load_system_message(self.current_message, self.prompt_template)
        ### Load examples
-        text += self.__load_example_messages()
+        text += _load_example_messages(self.prompt_template)
        ### Load History
-        text += self.__load_history_messages()
+        text += _load_history_messages(
+            self.prompt_template, self.history_message, self.chat_retention_rounds
+        )
        ### Load User Input
-        text += self.__load_user_message()
+        text += _load_user_message(self.current_message, self.prompt_template)
        return text

    def generate_llm_messages(self) -> List[ModelMessage]:
@@ -406,137 +402,26 @@ class BaseChat(ABC):
            )
        )
        ### Load prompt
-        messages += self.__load_system_message(str_message=False)
+        messages += _load_system_message(
+            self.current_message, self.prompt_template, str_message=False
+        )
        ### Load examples
-        messages += self.__load_example_messages(str_message=False)
+        messages += _load_example_messages(self.prompt_template, str_message=False)
        ### Load History
-        messages += self.__load_history_messages(str_message=False)
+        messages += _load_history_messages(
+            self.prompt_template,
+            self.history_message,
+            self.chat_retention_rounds,
+            str_message=False,
+        )
        ### Load User Input
-        messages += self.__load_user_message(str_message=False)
+        messages += _load_user_message(
+            self.current_message, self.prompt_template, str_message=False
+        )
        return messages
-    def __load_system_message(self, str_message: bool = True):
-        system_convs = self.current_message.get_system_conv()
-        system_text = ""
-        system_messages = []
-        for system_conv in system_convs:
-            system_text += (
-                system_conv.type + ":" + system_conv.content + self.prompt_template.sep
-            )
-            system_messages.append(
-                ModelMessage(role=system_conv.type, content=system_conv.content)
-            )
-        return system_text if str_message else system_messages
-
-    def __load_user_message(self, str_message: bool = True):
-        user_conv = self.current_message.get_user_conv()
-        user_messages = []
-        if user_conv:
-            user_text = (
-                user_conv.type + ":" + user_conv.content + self.prompt_template.sep
-            )
-            user_messages.append(
-                ModelMessage(role=user_conv.type, content=user_conv.content)
-            )
-            return user_text if str_message else user_messages
-        else:
-            raise ValueError("Hi! What do you want to talk about")
-
-    def __load_example_messages(self, str_message: bool = True):
-        example_text = ""
-        example_messages = []
-        if self.prompt_template.example_selector:
-            for round_conv in self.prompt_template.example_selector.examples():
-                for round_message in round_conv["messages"]:
-                    if not round_message["type"] in [
-                        ModelMessageRoleType.VIEW,
-                        ModelMessageRoleType.SYSTEM,
-                    ]:
-                        message_type = round_message["type"]
-                        message_content = round_message["data"]["content"]
-                        example_text += (
-                            message_type
-                            + ":"
-                            + message_content
-                            + self.prompt_template.sep
-                        )
-                        example_messages.append(
-                            ModelMessage(role=message_type, content=message_content)
-                        )
-        return example_text if str_message else example_messages
-
-    def __load_history_messages(self, str_message: bool = True):
-        history_text = ""
-        history_messages = []
-        if self.prompt_template.need_historical_messages:
-            if self.history_message:
-                logger.info(
-                    f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
-                )
-            if len(self.history_message) > self.chat_retention_rounds:
-                for first_message in self.history_message[0]["messages"]:
-                    if not first_message["type"] in [
-                        ModelMessageRoleType.VIEW,
-                        ModelMessageRoleType.SYSTEM,
-                    ]:
-                        message_type = first_message["type"]
-                        message_content = first_message["data"]["content"]
-                        history_text += (
-                            message_type
-                            + ":"
-                            + message_content
-                            + self.prompt_template.sep
-                        )
-                        history_messages.append(
-                            ModelMessage(role=message_type, content=message_content)
-                        )
-                if self.chat_retention_rounds > 1:
-                    index = self.chat_retention_rounds - 1
-                    for round_conv in self.history_message[-index:]:
-                        for round_message in round_conv["messages"]:
-                            if not round_message["type"] in [
-                                ModelMessageRoleType.VIEW,
-                                ModelMessageRoleType.SYSTEM,
-                            ]:
-                                message_type = round_message["type"]
-                                message_content = round_message["data"]["content"]
-                                history_text += (
-                                    message_type
-                                    + ":"
-                                    + message_content
-                                    + self.prompt_template.sep
-                                )
-                                history_messages.append(
-                                    ModelMessage(
-                                        role=message_type, content=message_content
-                                    )
-                                )
-            else:
-                ### use all history
-                for conversation in self.history_message:
-                    for message in conversation["messages"]:
-                        ### history messages do not carry prompt and view info
-                        if not message["type"] in [
-                            ModelMessageRoleType.VIEW,
-                            ModelMessageRoleType.SYSTEM,
-                        ]:
-                            message_type = message["type"]
-                            message_content = message["data"]["content"]
-                            history_text += (
-                                message_type
-                                + ":"
-                                + message_content
-                                + self.prompt_template.sep
-                            )
-                            history_messages.append(
-                                ModelMessage(role=message_type, content=message_content)
-                            )
-        return history_text if str_message else history_messages
    def current_ai_response(self) -> str:
        for message in self.current_message.messages:
            if message.type == "view":
@@ -656,3 +541,127 @@ def _build_model_operator(
    cache_check_branch_node >> cached_node >> join_node

    return join_node
+
+
+def _load_system_message(
+    current_message: OnceConversation,
+    prompt_template: PromptTemplate,
+    str_message: bool = True,
+):
+    system_convs = current_message.get_system_conv()
+    system_text = ""
+    system_messages = []
+    for system_conv in system_convs:
+        system_text += (
+            system_conv.type + ":" + system_conv.content + prompt_template.sep
+        )
+        system_messages.append(
+            ModelMessage(role=system_conv.type, content=system_conv.content)
+        )
+    return system_text if str_message else system_messages
+
+
+def _load_user_message(
+    current_message: OnceConversation,
+    prompt_template: PromptTemplate,
+    str_message: bool = True,
+):
+    user_conv = current_message.get_user_conv()
+    user_messages = []
+    if user_conv:
+        user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep
+        user_messages.append(
+            ModelMessage(role=user_conv.type, content=user_conv.content)
+        )
+        return user_text if str_message else user_messages
+    else:
+        raise ValueError("Hi! What do you want to talk about")
+
+
+def _load_example_messages(prompt_template: PromptTemplate, str_message: bool = True):
+    example_text = ""
+    example_messages = []
+    if prompt_template.example_selector:
+        for round_conv in prompt_template.example_selector.examples():
+            for round_message in round_conv["messages"]:
+                if not round_message["type"] in [
+                    ModelMessageRoleType.VIEW,
+                    ModelMessageRoleType.SYSTEM,
+                ]:
+                    message_type = round_message["type"]
+                    message_content = round_message["data"]["content"]
+                    example_text += (
+                        message_type + ":" + message_content + prompt_template.sep
+                    )
+                    example_messages.append(
+                        ModelMessage(role=message_type, content=message_content)
+                    )
+    return example_text if str_message else example_messages
+
+
+def _load_history_messages(
+    prompt_template: PromptTemplate,
+    history_message: List[OnceConversation],
+    chat_retention_rounds: int,
+    str_message: bool = True,
+):
+    history_text = ""
+    history_messages = []
+    if prompt_template.need_historical_messages:
+        if history_message:
+            logger.info(
+                f"There are already {len(history_message)} rounds of conversations! Will use {chat_retention_rounds} rounds of content as history!"
+            )
+        if len(history_message) > chat_retention_rounds:
+            for first_message in history_message[0]["messages"]:
+                if not first_message["type"] in [
+                    ModelMessageRoleType.VIEW,
+                    ModelMessageRoleType.SYSTEM,
+                ]:
+                    message_type = first_message["type"]
+                    message_content = first_message["data"]["content"]
+                    history_text += (
+                        message_type + ":" + message_content + prompt_template.sep
+                    )
+                    history_messages.append(
+                        ModelMessage(role=message_type, content=message_content)
+                    )
+            if chat_retention_rounds > 1:
+                index = chat_retention_rounds - 1
+                for round_conv in history_message[-index:]:
+                    for round_message in round_conv["messages"]:
+                        if not round_message["type"] in [
+                            ModelMessageRoleType.VIEW,
+                            ModelMessageRoleType.SYSTEM,
+                        ]:
+                            message_type = round_message["type"]
+                            message_content = round_message["data"]["content"]
+                            history_text += (
+                                message_type
+                                + ":"
+                                + message_content
+                                + prompt_template.sep
+                            )
+                            history_messages.append(
+                                ModelMessage(role=message_type, content=message_content)
+                            )
+        else:
+            ### use all history
+            for conversation in history_message:
+                for message in conversation["messages"]:
+                    ### history messages do not carry prompt and view info
+                    if not message["type"] in [
+                        ModelMessageRoleType.VIEW,
+                        ModelMessageRoleType.SYSTEM,
+                    ]:
+                        message_type = message["type"]
+                        message_content = message["data"]["content"]
+                        history_text += (
+                            message_type + ":" + message_content + prompt_template.sep
+                        )
+                        history_messages.append(
+                            ModelMessage(role=message_type, content=message_content)
+                        )
+    return history_text if str_message else history_messages

View File

@@ -6,7 +6,6 @@ import re
import sqlparse
import pandas as pd
import chardet
-import pandas as pd
import numpy as np
from pyparsing import (
    CaselessKeyword,
@@ -27,6 +26,8 @@ from pyparsing import (
from pilot.common.pd_utils import csv_colunm_foramt
from pilot.common.string_utils import is_chinese_include_number

+logger = logging.getLogger(__name__)
+

def excel_colunm_format(old_name: str) -> str:
    new_column = old_name.strip()
@@ -263,7 +264,7 @@ class ExcelReader:
        file_name = os.path.basename(file_path)
        self.file_name_without_extension = os.path.splitext(file_name)[0]
        encoding, confidence = detect_encoding(file_path)
-        logging.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
+        logger.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
        self.excel_file_name = file_name
        self.extension = os.path.splitext(file_name)[1]
        # read excel file
@@ -323,7 +324,7 @@ class ExcelReader:
                colunms.append(descrip[0])
            return colunms, results.fetchall()
        except Exception as e:
-            logging.error("excel sql run error!", e)
+            logger.error(f"excel sql run error!, {str(e)}")
            raise ValueError(f"Data Query Exception!\nSQL[{sql}].\nError:{str(e)}")

    def get_df_by_sql_ex(self, sql):

View File

@@ -37,7 +37,7 @@ class DbChatOutputParser(BaseOutputParser):

    def parse_prompt_response(self, model_out_text):
        clean_str = super().parse_prompt_response(model_out_text)
-        logging.info("clean prompt response:", clean_str)
+        logger.info(f"clean prompt response: {clean_str}")
        # Compatible with community pure sql output model
        if self.is_sql_statement(clean_str):
            return SqlAction(clean_str, "")
@@ -51,7 +51,7 @@ class DbChatOutputParser(BaseOutputParser):
                thoughts = response[key]
            return SqlAction(sql, thoughts)
        except Exception as e:
-            logging.error("json load faild")
+            logger.error("json load failed")
            return SqlAction("", clean_str)

    def parse_view_response(self, speak, data, prompt_response) -> str:

View File

@@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
        self.user_input = chat_param["current_user_input"]
        self.extract_mode = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
        input_values = {
            "text": self.user_input,
        }

View File

@@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
        self.user_input = chat_param["current_user_input"]
        self.extract_mode = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
        input_values = {
            "text": self.user_input,
        }

View File

@@ -23,7 +23,7 @@ class ExtractRefineSummary(BaseChat):
        self.existing_answer = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
        input_values = {
            # "context": self.user_input,
            "existing_answer": self.existing_answer,

View File

@@ -23,7 +23,7 @@ class ExtractSummary(BaseChat):
        self.user_input = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
        input_values = {
            "context": self.user_input,
        }

View File

@@ -104,7 +104,7 @@ class ChatKnowledge(BaseChat):
            self.current_user_input,
            self.top_k,
        )
-        self.sources = self.merge_by_key(
+        self.sources = _merge_by_key(
            list(map(lambda doc: doc.metadata, docs)), "source"
        )
@@ -149,29 +149,6 @@ class ChatKnowledge(BaseChat):
        )
        return html

-    def merge_by_key(self, data, key):
-        result = {}
-        for item in data:
-            if item.get(key):
-                item_key = os.path.basename(item.get(key))
-                if item_key in result:
-                    if "pages" in result[item_key] and "page" in item:
-                        result[item_key]["pages"].append(str(item["page"]))
-                    elif "page" in item:
-                        result[item_key]["pages"] = [
-                            result[item_key]["pages"],
-                            str(item["page"]),
-                        ]
-                else:
-                    if "page" in item:
-                        result[item_key] = {
-                            "source": item_key,
-                            "pages": [str(item["page"])],
-                        }
-                    else:
-                        result[item_key] = {"source": item_key}
-        return list(result.values())
-
    @property
    def chat_type(self) -> str:
        return ChatScene.ChatKnowledge.value()
@@ -179,3 +156,27 @@ class ChatKnowledge(BaseChat):
    def get_space_context(self, space_name):
        service = KnowledgeService()
        return service.get_space_context(space_name)
+
+
+def _merge_by_key(data, key):
+    result = {}
+    for item in data:
+        if item.get(key):
+            item_key = os.path.basename(item.get(key))
+            if item_key in result:
+                if "pages" in result[item_key] and "page" in item:
+                    result[item_key]["pages"].append(str(item["page"]))
+                elif "page" in item:
+                    result[item_key]["pages"] = [
+                        result[item_key]["pages"],
+                        str(item["page"]),
+                    ]
+            else:
+                if "page" in item:
+                    result[item_key] = {
+                        "source": item_key,
+                        "pages": [str(item["page"])],
+                    }
+                else:
+                    result[item_key] = {"source": item_key}
+    return list(result.values())
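
To make the merge behavior concrete, here is a hypothetical input and the result it produces (illustrative values only):

# Metadata dicts as they might come back from the vector store search.
data = [
    {"source": "/docs/a.pdf", "page": 1},
    {"source": "/docs/a.pdf", "page": 3},
    {"source": "/docs/b.md"},
]
print(_merge_by_key(data, "source"))
# [{'source': 'a.pdf', 'pages': ['1', '3']}, {'source': 'b.md'}]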

View File

@@ -0,0 +1,255 @@
from typing import Dict, Optional, List, Any
from dataclasses import dataclass
import datetime
import os

from pilot.awel import MapOperator
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.message import OnceConversation
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory

# TODO move global config
CFG = Config()


@dataclass
class ChatContext:
    current_user_input: str
    model_name: Optional[str]
    chat_session_id: Optional[str] = None
    select_param: Optional[str] = None
    chat_scene: Optional[ChatScene] = ChatScene.ChatNormal
    prompt_template: Optional[PromptTemplate] = None
    chat_retention_rounds: Optional[int] = 0
    history_storage: Optional[BaseChatHistoryMemory] = None
    history_manager: Optional["ChatHistoryManager"] = None
    # The input values for the prompt template
    input_values: Optional[Dict] = None
    echo: Optional[bool] = False

    def build_model_payload(self) -> Dict:
        if not self.input_values:
            raise ValueError("The input value can't be empty")
        llm_messages = self.history_manager._new_chat(self.input_values)
        return {
            "model": self.model_name,
            "prompt": "",
            "messages": llm_messages,
            "temperature": float(self.prompt_template.temperature),
            "max_new_tokens": int(self.prompt_template.max_new_tokens),
            "echo": self.echo,
        }


class ChatHistoryManager:
    def __init__(
        self,
        chat_ctx: ChatContext,
        prompt_template: PromptTemplate,
        history_storage: BaseChatHistoryMemory,
        chat_retention_rounds: Optional[int] = 0,
    ) -> None:
        self._chat_ctx = chat_ctx
        self.chat_retention_rounds = chat_retention_rounds
        self.current_message: OnceConversation = OnceConversation(
            chat_ctx.chat_scene.value()
        )
        self.prompt_template = prompt_template
        self.history_storage: BaseChatHistoryMemory = history_storage
        self.history_message: List[OnceConversation] = history_storage.messages()
        self.current_message.model_name = chat_ctx.model_name
        if chat_ctx.select_param:
            if len(chat_ctx.chat_scene.param_types()) > 0:
                self.current_message.param_type = chat_ctx.chat_scene.param_types()[0]
            self.current_message.param_value = chat_ctx.select_param

    def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
        self.current_message.chat_order = len(self.history_message) + 1
        self.current_message.add_user_message(self._chat_ctx.current_user_input)
        self.current_message.start_date = datetime.datetime.now().strftime(
            "%Y-%m-%d %H:%M:%S"
        )
        self.current_message.tokens = 0
        if self.prompt_template.template:
            current_prompt = self.prompt_template.format(**input_values)
            self.current_message.add_system_message(current_prompt)
        return self._generate_llm_messages()

    def _generate_llm_messages(self) -> List[ModelMessage]:
        from pilot.scene.base_chat import (
            _load_system_message,
            _load_example_messages,
            _load_history_messages,
            _load_user_message,
        )

        messages = []
        ### Load scene setting or character definition as system message
        if self.prompt_template.template_define:
            messages.append(
                ModelMessage(
                    role=ModelMessageRoleType.SYSTEM,
                    content=self.prompt_template.template_define,
                )
            )
        ### Load prompt
        messages += _load_system_message(
            self.current_message, self.prompt_template, str_message=False
        )
        ### Load examples
        messages += _load_example_messages(self.prompt_template, str_message=False)
        ### Load History
        messages += _load_history_messages(
            self.prompt_template,
            self.history_message,
            self.chat_retention_rounds,
            str_message=False,
        )
        ### Load User Input
        messages += _load_user_message(
            self.current_message, self.prompt_template, str_message=False
        )
        return messages


class PromptManagerOperator(MapOperator[ChatContext, ChatContext]):
    def __init__(self, prompt_template: PromptTemplate = None, **kwargs):
        super().__init__(**kwargs)
        self._prompt_template = prompt_template

    async def map(self, input_value: ChatContext) -> ChatContext:
        if not self._prompt_template:
            self._prompt_template: PromptTemplate = (
                CFG.prompt_template_registry.get_prompt_template(
                    input_value.chat_scene.value(),
                    language=CFG.LANGUAGE,
                    model_name=input_value.model_name,
                    proxyllm_backend=CFG.PROXYLLM_BACKEND,
                )
            )
        input_value.prompt_template = self._prompt_template
        return input_value


class ChatHistoryStorageOperator(MapOperator[ChatContext, ChatContext]):
    def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
        super().__init__(**kwargs)
        self._history = history

    async def map(self, input_value: ChatContext) -> ChatContext:
        if self._history:
            return self._history
        chat_history_fac = ChatHistory()
        input_value.history_storage = chat_history_fac.get_store_instance(
            input_value.chat_session_id
        )
        return input_value


class ChatHistoryOperator(MapOperator[ChatContext, ChatContext]):
    def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
        super().__init__(**kwargs)
        self._history = history

    async def map(self, input_value: ChatContext) -> ChatContext:
        history_storage = self._history or input_value.history_storage
        if not history_storage:
            from pilot.memory.chat_history.store_type.mem_history import (
                MemHistoryMemory,
            )

            history_storage = MemHistoryMemory(input_value.chat_session_id)
            input_value.history_storage = history_storage
        input_value.history_manager = ChatHistoryManager(
            input_value,
            input_value.prompt_template,
            history_storage,
            input_value.chat_retention_rounds,
        )
        return input_value


class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ChatContext) -> ChatContext:
        from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
        from pilot.embedding_engine.embedding_engine import EmbeddingEngine
        from pilot.embedding_engine.embedding_factory import EmbeddingFactory
        from pilot.scene.chat_knowledge.v1.chat import _merge_by_key

        # TODO, decompose the current operator into some atomic operators
        knowledge_space = input_value.select_param
        vector_store_config = {
            "vector_store_name": knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        embedding_factory = self.system_app.get_component(
            "embedding_factory", EmbeddingFactory
        )
        knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )
        space_context = await self._get_space_context(knowledge_space)
        top_k = (
            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
            if space_context is None
            else int(space_context["embedding"]["topk"])
        )
        max_token = (
            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
            if space_context is None or space_context.get("prompt") is None
            else int(space_context["prompt"]["max_token"])
        )
        input_value.prompt_template.template_is_strict = False
        if space_context and space_context.get("prompt"):
            input_value.prompt_template.template_define = space_context["prompt"][
                "scene"
            ]
            input_value.prompt_template.template = space_context["prompt"]["template"]
        docs = await self.blocking_func_to_async(
            knowledge_embedding_client.similar_search,
            input_value.current_user_input,
            top_k,
        )
        sources = _merge_by_key(list(map(lambda doc: doc.metadata, docs)), "source")
        if not docs or len(docs) == 0:
            print("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            context = [d.page_content for d in docs]
        context = context[:max_token]
        relations = list(
            set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
        )
        input_value.input_values = {
            "context": context,
            "question": input_value.current_user_input,
            "relations": relations,
        }
        return input_value

    async def _get_space_context(self, space_name):
        from pilot.server.knowledge.service import KnowledgeService

        service = KnowledgeService()
        return await self.blocking_func_to_async(service.get_space_context, space_name)


class BaseChatOperator(MapOperator[ChatContext, Dict]):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ChatContext) -> Dict:
        return input_value.build_model_payload()
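
For orientation, the payload that `build_model_payload` hands to the downstream `ModelOperator` has roughly this shape (illustrative values; `temperature` and `max_new_tokens` come from the resolved prompt template):

{
    "model": "proxyllm",
    "prompt": "",
    "messages": [...],  # List[ModelMessage] assembled by ChatHistoryManager
    "temperature": 0.5,
    "max_new_tokens": 1024,
    "echo": False,
}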