多场景对架构一期0525

This commit is contained in:
yhjun1026
2023-05-25 10:04:25 +08:00
parent 2fc62c16ef
commit ff99723014
12 changed files with 190 additions and 63 deletions

View File

@@ -2,10 +2,15 @@ import markdown2
import pandas as pd
def datas_to_table_html(data):
df = pd.DataFrame(data)
table_style = """\n<style>\n table {\n border-collapse: collapse;\n width: 100%;\n }\n th, td {\n border: 1px solid blue;\n padding: 5px;\n text-align: left;\n }\n th {\n background-color: #f2f2f2;\n }\n</style>\n"""
html_table = df.to_html(index=False, header=False, border = True)
return table_style + html_table
df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style>
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
return html.replace("\n", " ")
@@ -43,6 +48,9 @@ def generate_htm_table(data):
if __name__ == "__main__":
mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | None |\n"
# mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | None |\n"
# print(generate_htm_table(mk_text))
print(generate_htm_table(mk_text))
table_style = """<style>\n table {\n border-collapse: collapse;\n width: 100%;\n }\n th, td {\n border: 1px solid #ddd;\n padding: 8px;\n text-align: center;\n line-height: 150px; \n }\n th {\n background-color: #f2f2f2;\n color: #333;\n font-weight: bold;\n }\n tr:nth-child(even) {\n background-color: #f9f9f9;\n }\n tr:hover {\n background-color: #f2f2f2;\n }\n </style>"""
print(table_style.replace("\n", " "))

View File

@@ -21,13 +21,23 @@ from pilot.prompts.base import PromptValue
T = TypeVar("T")
class BaseOutputParser(BaseModel, ABC, Generic[T]):
class BaseOutputParser(ABC):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.
"""
def parse_model_nostream_resp(self, response, sep: str):
def __init__(self,sep:str, is_stream_out:bool):
self.sep = sep
self.is_stream_out = is_stream_out
# TODO 后续和模型绑定
def _parse_model_stream_resp(self, response, sep: str):
pass
def _parse_model_nostream_resp(self, response, sep: str):
text = response.text.strip()
text = text.rstrip()
respObj = json.loads(text)
@@ -52,35 +62,44 @@ class BaseOutputParser(BaseModel, ABC, Generic[T]):
else:
raise ValueError("Model server error!code=" + respObj_ex['error_code']);
@abstractmethod
def parse(self, text: str) -> T:
"""Parse the output of an LLM call.
A method which takes in a string (assumed output of language model )
and parses it into some structure.
def parse_model_server_out(self, response)->str:
"""
parse the model server http response
Args:
text: output of language model
response:
Returns:
structured output
"""
if self.is_stream_out:
self._parse_model_nostream_resp(response, self.sep)
else:
### TODO
self._parse_model_stream_resp(response, self.sep)
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
"""Optional method to parse the output of an LLM call with a prompt.
The prompt is largely provided in the event the OutputParser wants
to retry or fix the output in some way, and needs information from
the prompt to do so.
def parse_prompt_response(self, model_out_text)->T:
"""
parse model out text to prompt define response
Args:
completion: output of language model
prompt: prompt value
model_out_text:
Returns:
structured output
"""
return self.parse(completion)
pass
def parse_view_response(self, ai_text)->str:
"""
parse the ai response info to user view
Args:
text:
Returns:
"""
pass
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from pilot.common.formatting import formatter
from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
@@ -36,7 +36,17 @@ class PromptTemplate(BaseModel, ABC):
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_format:str
"""default use stream out"""
stream_out: bool = True
""""""
output_parser: BaseOutputParser = None
""""""
sep:str = SeparatorStyle.SINGLE.value
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""

View File

@@ -32,16 +32,14 @@ from pilot.configs.config import Config
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"}
CFG = Config()
class BaseChat(ABC):
class BaseChat( ABC):
chat_scene: str = None
memory: BaseChatHistoryMemory
llm_model: Any = None
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
temperature: float = 0.6
max_new_tokens: int = 1024
# By default, keep the last two rounds of conversation records as the context
chat_retention_rounds: int = 2
sep = SeparatorStyle.SINGLE.value
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True

View File

@@ -61,6 +61,17 @@ class AIMessage(BaseMessage):
return "ai"
class ViewMessage(BaseMessage):
"""Type of message that is spoken by the AI."""
example: bool = False
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "view"
class SystemMessage(BaseMessage):
"""Type of message that is a system message."""
@@ -132,6 +143,8 @@ def _message_from_dict(message: dict) -> BaseMessage:
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "view":
return ViewMessage(**message["data"])
else:
raise ValueError(f"Got unexpected type: {_type}")

View File

@@ -13,7 +13,7 @@ from sqlalchemy import (
)
from typing import Any, Iterable, List, Optional
from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage
from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage, ViewMessage
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@@ -26,13 +26,15 @@ from pilot.utils import (
)
from pilot.common.markdown_text import generate_markdown_table,generate_htm_table,datas_to_table_html
from pilot.scene.chat_db.prompt import chat_db_prompt
from pilot.out_parser.base import BaseOutputParser
from pilot.scene.chat_db.out_parser import DbChatOutputParser
CFG = Config()
class ChatWithDb(BaseChat):
chat_scene: str = ChatScene.ChatWithDb.value
"""Number of results to return from the query"""
def __init__(self, chat_session_id, db_name, user_input):
@@ -77,17 +79,22 @@ class ChatWithDb(BaseChat):
"prompt": self.generate_llm_text(),
"temperature": float(self.temperature),
"max_new_tokens": int(self.max_new_tokens),
"stop": self.sep_style.value,
"stop": self.prompt_template.sep,
}
logger.info(f"Requert: \n{payload}")
try:
### 走非流式的模型服务接口
# TODO - TEST
# # TODO - TEST
# response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"), headers=headers, json=payload, timeout=120)
# clear_response = self.prompt_template.output_parser.parse_model_nostream_resp(response, self.sep_style)
# sql_action = self.prompt_template.output_parser.parse(clear_response)
# ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
#
# prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
# self.current_message.add_ai_message(json.dumps(prompt_define_response._asdict()))
# result = self.database.run(self.db_connect, prompt_define_response.SQL)
resp_test = {
"SQL": "select * from users",
"thoughts": {
@@ -100,12 +107,10 @@ class ChatWithDb(BaseChat):
}
sql_action = SqlAction(**resp_test)
# self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
result = self.database.run(self.db_connect, sql_action.SQL)
self.current_message.add_ai_message(f"{datas_to_table_html(result)}")
self.current_message.add_view_message(self.prompt_template.output_parser.parse_view_response(result))
except Exception as e:
logger.error("model response parase faild" + str(e))
@@ -118,17 +123,19 @@ class ChatWithDb(BaseChat):
ret = []
# 单论对话只能有一次User 记录 和一次 AI 记录
# TODO 推理过程前端展示。。。
for message in enumerate(self.current_message.messages):
for message in self.current_message.messages:
if (isinstance(message, HumanMessage)):
ret[-1][-2] = message.content
if (isinstance(message, AIMessage)):
# 是否展示推理过程
if (isinstance(message, ViewMessage)):
ret[-1][-1] = message.content
return ret
# 暂时为了兼容前端
def current_ai_response(self)->str:
for message in self.current_message.messages:
if message.type == 'ai':
if message.type == 'view':
return message.content
return None
@@ -137,28 +144,31 @@ class ChatWithDb(BaseChat):
text = ""
### 线处理历史信息
if (len(self.history_message) > self.chat_retention_rounds):
### 使用历史信息的第一轮和最后一轮数据合并成历史对话记录
### 使用历史信息的第一轮和最后一轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉
for first_message in self.history_message[0].messages:
text += first_message.type + ":" + first_message.content + self.sep
if not isinstance(first_message, ViewMessage):
text += first_message.type + ":" + first_message.content + self.prompt_template.sep
index = self.chat_retention_rounds - 1
for last_message in self.history_message[-index:].messages:
text += last_message.type + ":" + last_message.content + self.sep
if not isinstance(last_message, ViewMessage):
text += last_message.type + ":" + last_message.content + self.prompt_template.sep
else:
### 直接历史记录拼接
for conversation in self.history_message:
for message in conversation.messages:
text += message.type + ":" + message.content + self.sep
if not isinstance(message, ViewMessage):
text += message.type + ":" + message.content + self.prompt_template.sep
### current conversation
for now_message in self.current_message.messages:
text += now_message.type + ":" + now_message.content + self.sep
text += now_message.type + ":" + now_message.content + self.prompt_template.sep
return text
@property
@classmethod
def chat_type(self) -> str:
return ChatScene.ChatWithDb.value

View File

View File

@@ -11,8 +11,9 @@ from typing import (
TypeVar,
Union,
)
import pandas as pd
from pilot.out_parser.base import BaseOutputParser
from pilot.out_parser.base import BaseOutputParser, T
class SqlAction(NamedTuple):
@@ -22,8 +23,15 @@ class SqlAction(NamedTuple):
class DbChatOutputParser(BaseOutputParser):
def parse(self, text: str) -> SqlAction:
cleaned_output = text.rstrip()
def __init__(self, sep:str, is_stream_out: bool):
super().__init__(sep=sep, is_stream_out=is_stream_out )
def parse_model_server_out(self, response) -> str:
return super().parse_model_server_out(response)
def parse_prompt_response(self, model_out_text):
cleaned_output = model_out_text.rstrip()
if "```json" in cleaned_output:
_, cleaned_output = cleaned_output.split("```json")
if "```" in cleaned_output:
@@ -40,6 +48,16 @@ class DbChatOutputParser(BaseOutputParser):
return SqlAction(sql, thoughts)
def parse_view_response(self, data) -> str:
### tool out data to table view
df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style>
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
return html.replace("\n", " ")
@property
def _type(self) -> str:
return "sql_chat"

View File

@@ -2,7 +2,8 @@ import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.out_parser import DbChatOutputParser
from pilot.scene.chat_db.out_parser import DbChatOutputParser,SqlAction
from pilot.common.schema import SeparatorStyle
CFG = Config()
@@ -21,6 +22,15 @@ You can order the results by a relevant column to return the most interesting ex
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
"""
_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".
"""
PROMPT_RESPONSE = """You should only respond in JSON format as following format:
@@ -37,17 +47,18 @@ RESPONSE_FORMAT = {
"SQL": "SQL Query to run"
}
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
chat_db_prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDb.value,
input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
output_parser=DbChatOutputParser()
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE + PROMPT_SUFFIX,
output_parser=DbChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
)
CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})
if __name__ == "__main__":
resp = chat_db_prompt.format(input="查询用户信息", table_info="user(a,b,c,d)", dialect="mysql", top_k=10)
print(resp)

View File

@@ -0,0 +1,31 @@
import builtins
def stream_write_and_read(lst):
# 对lst使用yield from进行可迭代对象的扁平化
yield from lst
while True:
val = yield
lst.append(val)
if __name__ == "__main__":
# 创建一个空列表
my_list = []
# 使用生成器写入数据
stream_writer = stream_write_and_read(my_list)
next(stream_writer)
stream_writer.send(10)
print(1)
stream_writer.send(20)
print(2)
stream_writer.send(30)
print(3)
# 使用生成器读取数据
stream_reader = stream_write_and_read(my_list)
next(stream_reader)
print(stream_reader.send(None))
print(stream_reader.send(None))
print(stream_reader.send(None))

View File

@@ -9,8 +9,7 @@ from typing import (
List,
)
from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, messages_to_dict, \
messages_from_dict
from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, ViewMessage, messages_to_dict, messages_from_dict
class OnceConversation:
@@ -27,12 +26,23 @@ class OnceConversation:
def add_user_message(self, message: str) -> None:
"""Add a user message to the store"""
has_message = any(isinstance(instance, HumanMessage) for instance in self.messages)
if has_message:
raise ValueError("Already Have Human message")
self.messages.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Add an AI message to the store"""
has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
if has_message:
raise ValueError("Already Have Ai message")
self.messages.append(AIMessage(content=message))
""" """
def add_view_message(self, message: str) -> None:
"""Add an AI message to the store"""
self.messages.append(ViewMessage(content=message))
""" """
def add_system_message(self, message: str) -> None:
"""Add an AI message to the store"""

View File

@@ -218,7 +218,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
# state.append_message(state.roles[1], chat.current_ai_response())
state.messages[-1][-1] = f"{chat.current_ai_response()}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5