refactor: merge tuyang_end_06

This commit is contained in:
aries_ckt 2023-07-03 18:59:37 +08:00
commit 07281d6bb5
11 changed files with 141 additions and 80 deletions

View File

@ -452,3 +452,5 @@ class Database:
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]

View File

@ -77,10 +77,18 @@ CFG = Config()
# print(__extract_json(ss))
if __name__ == "__main__":
test1 = Test1()
test2 = Test2()
test1.write()
test1.test()
test2.write()
test1.test()
test2.test()
# test1 = Test1()
# test2 = Test2()
# test1.write()
# test1.test()
# test2.write()
# test1.test()
# test2.test()
# Define a list of tuples
data = [('key1', 'value1'), ('key2', 'value2'), ('key3', 'value3')]
# Use a dict comprehension to convert the list into a dict
result = {k: v for k, v in data}
print(result)

View File

@ -228,6 +228,8 @@ async def chat_completions(dialogue: ConversationVo = Body()):
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
## DEFAULT
chat_param.update({"report_name": "sales_report"})
elif ChatScene.ChatExecution.value == dialogue.chat_mode:
chat_param.update({"plugin_selector": dialogue.select_param})
elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:

View File

@ -120,17 +120,45 @@ class BaseOutputParser(ABC):
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def __extract_json(self, s):
i = s.index("{")
count = 1  # current nesting depth, i.e. the number of '{' not yet closed
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
count += 1
if count == 0:
break
assert count == 0  # verify that the final matching '}' was found
return s[i : j + 1]
temp_json = self.__json_interception(s, True)
if not temp_json:
temp_json = self.__json_interception(s)
try:
json.loads(temp_json)
return temp_json
except Exception as e:
raise ValueError("Failed to find a valid json response" + temp_json)
def __json_interception(self, s, is_json_array: bool = False):
if is_json_array:
i = s.index("[")
if i < 0:
return None
count = 1
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == "]":
count -= 1
elif c == "[":
count += 1
if count == 0:
break
assert count == 0
return s[i: j + 1]
else:
i = s.index("{")
if i < 0:
return None
count = 1
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
count += 1
if count == 0:
break
assert count == 0
return s[i: j + 1]
def parse_prompt_response(self, model_out_text) -> T:
"""
@ -147,9 +175,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json") :]
cleaned_output = cleaned_output[len("```json"):]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```") :]
cleaned_output = cleaned_output[len("```"):]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
@ -158,9 +186,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = (
cleaned_output.strip()
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
)
return cleaned_output
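
For reference, the bracket-matching idea behind __extract_json / __json_interception can be sketched as a standalone helper. This is a minimal sketch, not the class method itself; it uses str.find (which returns -1) so a missing opening bracket yields None instead of the ValueError that str.index would raise:

def extract_json_block(s: str, is_json_array: bool = False):
    """Return the first balanced JSON object/array substring of s, or None."""
    open_ch, close_ch = ("[", "]") if is_json_array else ("{", "}")
    i = s.find(open_ch)
    if i < 0:
        return None
    depth = 0
    for j, c in enumerate(s[i:], start=i):
        if c == open_ch:
            depth += 1
        elif c == close_ch:
            depth -= 1
            if depth == 0:
                return s[i : j + 1]
    return None  # brackets never balanced

# Example: pull the JSON payload out of surrounding model output
print(extract_json_block('Here is the result: {"sql": "SELECT 1", "title": "demo"} -- done'))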

View File

@ -1,4 +1,6 @@
import json
import os
import uuid
from typing import Dict, NamedTuple, List
from pilot.scene.base_message import (
HumanMessage,
@ -11,7 +13,7 @@ from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
@ -28,19 +30,31 @@ class ChatDashboard(BaseChat):
def __init__(self, chat_session_id, db_name, user_input, report_name):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbExecute,
chat_mode=ChatScene.ChatDashboard,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
if not db_name:
raise ValueError(
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
f"{ChatScene.ChatDashboard.value} mode should chose db!"
)
self.db_name = db_name
self.report_name = report_name
self.database = CFG.local_db
# Prepare DB info (get a connection to the specified database)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5
self.dashboard_template = self.__load_dashboard_template(report_name)
def __load_dashboard_template(self, template_name):
current_dir = os.getcwd()
print(current_dir)
current_dir = os.path.dirname(os.path.abspath(__file__))
with open(f"{current_dir}/template/{template_name}/dashboard.json", 'r') as f:
data = f.read()
return json.loads(data)
def generate_input_values(self):
try:
@ -52,34 +66,28 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input,
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect),
"supported_chat_type": "" # TODO
"supported_chat_type": self.dashboard_template['supported_chart_type']
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
return input_values
def do_action(self, prompt_response):
### TODO record the overall info; track successful and failed charts separately
report_data: ReportData = ReportData()
chart_datas: List[ChartData] = []
for chart_item in prompt_response:
try:
datas = self.database.run(self.db_connect, chart_item.sql)
chart_data: ChartData = ChartData()
chart_data.chart_sql = chart_item["sql"]
chart_data.chart_type = chart_item["showcase"]
chart_data.chart_name = chart_item["title"]
chart_data.chart_desc = chart_item["thoughts"]
chart_data.column_name = datas[0]
chart_data.values = datas
chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()),
chart_name=chart_item.title,
chart_type=chart_item.showcase,
chart_desc=chart_item.thoughts,
chart_sql=chart_item.sql,
column_name=datas[0],
values=datas))
except Exception as e:
# TODO repair flow
print(str(e))
chart_datas.append(chart_data)
report_data.conv_uid = self.chat_session_id
report_data.template_name = self.report_name
report_data.template_introduce = None
report_data.charts = chart_datas
return report_data
return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None,
charts=chart_datas)
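
The template loader above reads template/<report_name>/dashboard.json relative to the module, and generate_input_values only consumes its "supported_chart_type" key. A minimal, hypothetical template file could therefore be produced like this (the chart type names are illustrative, not taken from the repo):

import json
import os

# Hypothetical minimal dashboard template; only "supported_chart_type" is read
# by generate_input_values(). Real templates shipped with the project may
# contain additional keys.
template = {"supported_chart_type": ["LineChart", "BarChart", "PieChart"]}

target_dir = os.path.join("template", "sales_report")
os.makedirs(target_dir, exist_ok=True)
with open(os.path.join(target_dir, "dashboard.json"), "w") as f:
    json.dump(template, f, indent=2)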

View File

@ -1,6 +1,7 @@
import json
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
from dataclasses import dataclass, asdict
class ChartData(BaseModel):
chart_uid: str
@ -10,11 +11,30 @@ class ChartData(BaseModel):
chart_sql: str
column_name: List
values: List
style: Any
style: Any = None
def dict(self, *args, **kwargs):
return {
"chart_uid": self.chart_uid,
"chart_name": self.chart_name,
"chart_type": self.chart_type,
"chart_desc": self.chart_desc,
"chart_sql": self.chart_sql,
"column_name": [str(item) for item in self.column_name],
"values": [[str(item) for item in sublist] for sublist in self.values],
"style": self.style
}
class ReportData(BaseModel):
conv_uid: str
template_name: str
template_introduce: str
template_introduce: str = None
charts: List[ChartData]
def prepare_dict(self):
return {
"conv_uid": self.conv_uid,
"template_name": self.template_name,
"template_introduce": self.template_introduce,
"charts": [chart.dict() for chart in self.charts]
}
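
A quick usage sketch of the serialization path added here; all field values below are made up for illustration, and the output of prepare_dict() is what the dashboard out-parser returns from parse_view_response:

import json
import uuid

chart = ChartData(
    chart_uid=str(uuid.uuid1()),
    chart_name="orders_by_month",  # illustrative values only
    chart_type="BarChart",
    chart_desc="monthly order counts",
    chart_sql="SELECT month, COUNT(*) FROM orders GROUP BY month",
    column_name=["month", "count"],
    values=[["2023-06", 42], ["2023-07", 57]],
)
report = ReportData(conv_uid="conv-1", template_name="sales_report", charts=[chart])
print(json.dumps(report.prepare_dict()))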

View File

@ -1,12 +1,13 @@
import json
import re
from dataclasses import dataclass, asdict
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.scene.base import ChatScene
class ChartItem(NamedTuple):
sql: str
@ -26,7 +27,7 @@ class ChatDashboardOutputParser(BaseOutputParser):
clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str)
response = json.loads(clean_str)
chart_items = List[ChartItem]
chart_items: List[ChartItem] = []
for item in response:
chart_items.append(
ChartItem(
@ -36,10 +37,8 @@ class ChatDashboardOutputParser(BaseOutputParser):
return chart_items
def parse_view_response(self, speak, data) -> str:
### TODO
return data
return json.dumps(data.prepare_dict())
@property
def _type(self) -> str:
return "chat_dashboard"
return ChatScene.ChatDashboard.value

View File

@ -3,24 +3,26 @@ import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.scene.chat_dashboard.out_parser import ChatDashboardOutputParser, ChartItem
from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations"""
PROMPT_SCENE_DEFINE = None
_DEFAULT_TEMPLATE = """
According to the structure definition in the following tables:
{table_info}
Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions.
Provide a professional data analysis using as few dimensions as possible, but no fewer than three and no more than eight.
Used to support goal: {input}
Use the chart display method in the following range:
Keep the length of the analysis output in check; do not exceed 4000 tokens.
According to the characteristics of the analyzed data, choose the most suitable chart type from those provided below:
{supported_chat_type}
give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format:
Give {dialect} data analysis SQL, analysis title, display method and analytical thinking; respond in the following json format:
{response}
Do not use unprovided fields and do not use unprovided data in the where condition of sql.
Ensure the response is correct json and can be parsed by Python json.loads
"""
@ -38,13 +40,13 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value,
template_scene=ChatScene.ChatDashboard.value,
input_variables=["input", "table_info", "dialect", "supported_chat_type"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=DbChatOutputParser(
output_parser=ChatDashboardOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)

View File

@ -6,6 +6,7 @@ from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge

View File

@ -168,7 +168,7 @@ async def api_generate_stream(request: Request):
@app.post("/generate")
def generate(prompt_request: PromptRequest):
def generate(prompt_request: PromptRequest) -> str:
params = {
"prompt": prompt_request.prompt,
"temperature": prompt_request.temperature,

View File

@ -690,9 +690,6 @@ if __name__ == "__main__":
parser.add_argument(
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument(
"-new", "--new", action="store_true", help="enable new http mode"
)
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
@ -704,20 +701,14 @@ if __name__ == "__main__":
args = parser.parse_args()
server_init(args)
dbs = CFG.local_db.get_database_list()
if args.new:
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
else:
### Compatibility mode starts the old version server by default
demo = build_webdemo()
demo.queue(
concurrency_count=args.concurrency_count,
status_update_rate=10,
api_open=False,
).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)
demo = build_webdemo()
demo.queue(
concurrency_count=args.concurrency_count,
status_update_rate=10,
api_open=False,
).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)