refactor:merge tuyang_end_06

This commit is contained in:
aries_ckt 2023-07-03 18:59:37 +08:00
commit 07281d6bb5
11 changed files with 141 additions and 80 deletions

View File

@ -452,3 +452,5 @@ class Database:
return [ return [
(table_comment[0], table_comment[1]) for table_comment in table_comments (table_comment[0], table_comment[1]) for table_comment in table_comments
] ]

View File

@ -77,10 +77,18 @@ CFG = Config()
# print(__extract_json(ss)) # print(__extract_json(ss))
if __name__ == "__main__": if __name__ == "__main__":
test1 = Test1() # test1 = Test1()
test2 = Test2() # test2 = Test2()
test1.write() # test1.write()
test1.test() # test1.test()
test2.write() # test2.write()
test1.test() # test1.test()
test2.test() # test2.test()
# 定义包含元组的列表
data = [('key1', 'value1'), ('key2', 'value2'), ('key3', 'value3')]
# 使用字典解析将列表转换为字典
result = {k: v for k, v in data}
print(result)

View File

@ -228,6 +228,8 @@ async def chat_completions(dialogue: ConversationVo = Body()):
chat_param.update({"db_name": dialogue.select_param}) chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatDashboard.value == dialogue.chat_mode: elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param}) chat_param.update({"db_name": dialogue.select_param})
## DEFAULT
chat_param.update({"report_name": "sales_report"})
elif ChatScene.ChatExecution.value == dialogue.chat_mode: elif ChatScene.ChatExecution.value == dialogue.chat_mode:
chat_param.update({"plugin_selector": dialogue.select_param}) chat_param.update({"plugin_selector": dialogue.select_param})
elif ChatScene.ChatKnowledge.value == dialogue.chat_mode: elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:

View File

@ -120,17 +120,45 @@ class BaseOutputParser(ABC):
raise ValueError("Model server error!code=" + respObj_ex["error_code"]) raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def __extract_json(self, s): def __extract_json(self, s):
i = s.index("{")
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 temp_json = self.__json_interception(s, True)
for j, c in enumerate(s[i + 1 :], start=i + 1): if not temp_json:
if c == "}": temp_json = self.__json_interception(s)
count -= 1 try:
elif c == "{": json.loads(temp_json)
count += 1 return temp_json
if count == 0: except Exception as e:
break raise ValueError("Failed to find a valid json response" + temp_json)
assert count == 0 # 检查是否找到最后一个'}'
return s[i : j + 1] def __json_interception(self, s, is_json_array: bool = False):
if is_json_array:
i = s.index("[")
if i <0:
return None
count = 1
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == "]":
count -= 1
elif c == "[":
count += 1
if count == 0:
break
assert count == 0
return s[i: j + 1]
else:
i = s.index("{")
if i <0:
return None
count = 1
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
count += 1
if count == 0:
break
assert count == 0
return s[i: j + 1]
def parse_prompt_response(self, model_out_text) -> T: def parse_prompt_response(self, model_out_text) -> T:
""" """
@ -147,9 +175,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output: # if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```") # cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"): if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json") :] cleaned_output = cleaned_output[len("```json"):]
if cleaned_output.startswith("```"): if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```") :] cleaned_output = cleaned_output[len("```"):]
if cleaned_output.endswith("```"): if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip() cleaned_output = cleaned_output.strip()
@ -158,9 +186,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output) cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = ( cleaned_output = (
cleaned_output.strip() cleaned_output.strip()
.replace("\n", " ") .replace("\n", " ")
.replace("\\n", " ") .replace("\\n", " ")
.replace("\\", " ") .replace("\\", " ")
) )
return cleaned_output return cleaned_output

View File

@ -1,4 +1,6 @@
import json import json
import os
import uuid
from typing import Dict, NamedTuple, List from typing import Dict, NamedTuple, List
from pilot.scene.base_message import ( from pilot.scene.base_message import (
HumanMessage, HumanMessage,
@ -11,7 +13,7 @@ from pilot.configs.config import Config
from pilot.common.markdown_text import ( from pilot.common.markdown_text import (
generate_htm_table, generate_htm_table,
) )
from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import ( from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData, ChartData,
ReportData, ReportData,
@ -28,19 +30,31 @@ class ChatDashboard(BaseChat):
def __init__(self, chat_session_id, db_name, user_input, report_name): def __init__(self, chat_session_id, db_name, user_input, report_name):
""" """ """ """
super().__init__( super().__init__(
chat_mode=ChatScene.ChatWithDbExecute, chat_mode=ChatScene.ChatDashboard,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input=user_input, current_user_input=user_input,
) )
if not db_name: if not db_name:
raise ValueError( raise ValueError(
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!" f"{ChatScene.ChatDashboard.value} mode should chose db!"
) )
self.db_name = db_name
self.report_name = report_name self.report_name = report_name
self.database = CFG.local_db self.database = CFG.local_db
# 准备DB信息(拿到指定库的链接) # 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name) self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5 self.top_k: int = 5
self.dashboard_template = self.__load_dashboard_template(report_name)
def __load_dashboard_template(self, template_name):
current_dir = os.getcwd()
print(current_dir)
current_dir = os.path.dirname(os.path.abspath(__file__))
with open(f"{current_dir}/template/{template_name}/dashboard.json", 'r') as f:
data = f.read()
return json.loads(data)
def generate_input_values(self): def generate_input_values(self):
try: try:
@ -52,34 +66,28 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input, "input": self.current_user_input,
"dialect": self.database.dialect, "dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect), "table_info": self.database.table_simple_info(self.db_connect),
"supported_chat_type": "" # TODO "supported_chat_type": self.dashboard_template['supported_chart_type']
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
} }
return input_values return input_values
def do_action(self, prompt_response): def do_action(self, prompt_response):
### TODO 记录整体信息,处理成功的,和未成功的分开记录处理 ### TODO 记录整体信息,处理成功的,和未成功的分开记录处理
report_data: ReportData = ReportData()
chart_datas: List[ChartData] = [] chart_datas: List[ChartData] = []
for chart_item in prompt_response: for chart_item in prompt_response:
try: try:
datas = self.database.run(self.db_connect, chart_item.sql) datas = self.database.run(self.db_connect, chart_item.sql)
chart_data: ChartData = ChartData() chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()),
chart_data.chart_sql = chart_item["sql"] chart_name=chart_item.title,
chart_data.chart_type = chart_item["showcase"] chart_type=chart_item.showcase,
chart_data.chart_name = chart_item["title"] chart_desc=chart_item.thoughts,
chart_data.chart_desc = chart_item["thoughts"] chart_sql=chart_item.sql,
chart_data.column_name = datas[0] column_name=datas[0],
chart_data.values = datas values=datas))
except Exception as e: except Exception as e:
# TODO 修复流程 # TODO 修复流程
print(str(e)) print(str(e))
chart_datas.append(chart_data) return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None,
charts=chart_datas)
report_data.conv_uid = self.chat_session_id
report_data.template_name = self.report_name
report_data.template_introduce = None
report_data.charts = chart_datas
return report_data

View File

@ -1,6 +1,7 @@
import json
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any from typing import TypeVar, Union, List, Generic, Any
from dataclasses import dataclass, asdict
class ChartData(BaseModel): class ChartData(BaseModel):
chart_uid: str chart_uid: str
@ -10,11 +11,30 @@ class ChartData(BaseModel):
chart_sql: str chart_sql: str
column_name: List column_name: List
values: List values: List
style: Any style: Any = None
def dict(self, *args, **kwargs):
return {
"chart_uid": self.chart_uid,
"chart_name": self.chart_name,
"chart_type": self.chart_type,
"chart_desc": self.chart_desc,
"chart_sql": self.chart_sql,
"column_name": [str(item) for item in self.column_name],
"values": [[str(item) for item in sublist] for sublist in self.values],
"style": self.style
}
class ReportData(BaseModel): class ReportData(BaseModel):
conv_uid: str conv_uid: str
template_name: str template_name: str
template_introduce: str template_introduce: str = None
charts: List[ChartData] charts: List[ChartData]
def prepare_dict(self):
return {
"conv_uid": self.conv_uid,
"template_name": self.template_name,
"template_introduce": self.template_introduce,
"charts": [chart.dict() for chart in self.charts]
}

View File

@ -1,12 +1,13 @@
import json import json
import re import re
from dataclasses import dataclass, asdict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List from typing import Dict, NamedTuple, List
import pandas as pd import pandas as pd
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.scene.base import ChatScene
class ChartItem(NamedTuple): class ChartItem(NamedTuple):
sql: str sql: str
@ -26,7 +27,7 @@ class ChatDashboardOutputParser(BaseOutputParser):
clean_str = super().parse_prompt_response(model_out_text) clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str) print("clean prompt response:", clean_str)
response = json.loads(clean_str) response = json.loads(clean_str)
chart_items = List[ChartItem] chart_items: List[ChartItem] = []
for item in response: for item in response:
chart_items.append( chart_items.append(
ChartItem( ChartItem(
@ -36,10 +37,8 @@ class ChatDashboardOutputParser(BaseOutputParser):
return chart_items return chart_items
def parse_view_response(self, speak, data) -> str: def parse_view_response(self, speak, data) -> str:
### TODO return json.dumps(data.prepare_dict())
return data
@property @property
def _type(self) -> str: def _type(self) -> str:
return "chat_dashboard" return ChatScene.ChatDashboard.value

View File

@ -3,24 +3,26 @@ import importlib
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction from pilot.scene.chat_dashboard.out_parser import ChatDashboardOutputParser, ChartItem
from pilot.common.schema import SeparatorStyle from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations""" PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations"""
PROMPT_SCENE_DEFINE = None
_DEFAULT_TEMPLATE = """ _DEFAULT_TEMPLATE = """
According to the structure definition in the following tables: According to the structure definition in the following tables:
{table_info} {table_info}
Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions. Provide professional data analysis, use as few dimensions as possible, but no less than three, and no more than eight dimensions.
Used to support goal: {input} Used to support goal: {input}
Use the chart display method in the following range: Pay attention to the length of the output content of the analysis result, do not exceed 4000tokens
According to the characteristics of the analyzed data, choose the best one from the charts provided below to displaychart types:
{supported_chat_type} {supported_chat_type}
give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format:
Give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format:
{response} {response}
Do not use unprovided fields and do not use unprovided data in the where condition of sql.
Ensure the response is correct json and can be parsed by Python json.loads Ensure the response is correct json and can be parsed by Python json.loads
""" """
@ -38,13 +40,13 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate( prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value, template_scene=ChatScene.ChatDashboard.value,
input_variables=["input", "table_info", "dialect", "supported_chat_type"], input_variables=["input", "table_info", "dialect", "supported_chat_type"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4), response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE, template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT, stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=DbChatOutputParser( output_parser=ChatDashboardOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
), ),
) )

View File

@ -6,6 +6,7 @@ from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge

View File

@ -168,7 +168,7 @@ async def api_generate_stream(request: Request):
@app.post("/generate") @app.post("/generate")
def generate(prompt_request: PromptRequest): def generate(prompt_request: PromptRequest)->str:
params = { params = {
"prompt": prompt_request.prompt, "prompt": prompt_request.prompt,
"temperature": prompt_request.temperature, "temperature": prompt_request.temperature,

View File

@ -690,9 +690,6 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--model_list_mode", type=str, default="once", choices=["once", "reload"] "--model_list_mode", type=str, default="once", choices=["once", "reload"]
) )
parser.add_argument(
"-new", "--new", action="store_true", help="enable new http mode"
)
# old version server config # old version server config
parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--host", type=str, default="0.0.0.0")
@ -704,20 +701,14 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
server_init(args) server_init(args)
dbs = CFG.local_db.get_database_list() dbs = CFG.local_db.get_database_list()
if args.new: demo = build_webdemo()
import uvicorn demo.queue(
concurrency_count=args.concurrency_count,
uvicorn.run(app, host="0.0.0.0", port=5000) status_update_rate=10,
else: api_open=False,
### Compatibility mode starts the old version server by default ).launch(
demo = build_webdemo() server_name=args.host,
demo.queue( server_port=args.port,
concurrency_count=args.concurrency_count, share=args.share,
status_update_rate=10, max_threads=200,
api_open=False, )
).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)