WEB API independent

This commit is contained in:
tuyang.yhj 2023-07-03 16:17:32 +08:00
parent 1605ce53bf
commit a6097c4cb4
9 changed files with 124 additions and 57 deletions

View File

@ -452,3 +452,5 @@ class Database:
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]

View File

@ -77,10 +77,18 @@ CFG = Config()
# print(__extract_json(ss))
if __name__ == "__main__":
test1 = Test1()
test2 = Test2()
test1.write()
test1.test()
test2.write()
test1.test()
test2.test()
# test1 = Test1()
# test2 = Test2()
# test1.write()
# test1.test()
# test2.write()
# test1.test()
# test2.test()
# 定义包含元组的列表
data = [('key1', 'value1'), ('key2', 'value2'), ('key3', 'value3')]
# 使用字典解析将列表转换为字典
result = {k: v for k, v in data}
print(result)

View File

@ -228,6 +228,8 @@ async def chat_completions(dialogue: ConversationVo = Body()):
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
## DEFAULT
chat_param.update({"report_name": "sales_report"})
elif ChatScene.ChatExecution.value == dialogue.chat_mode:
chat_param.update({"plugin_selector": dialogue.select_param})
elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:

View File

@ -120,17 +120,45 @@ class BaseOutputParser(ABC):
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def __extract_json(self, s):
i = s.index("{")
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
count += 1
if count == 0:
break
assert count == 0 # 检查是否找到最后一个'}'
return s[i : j + 1]
temp_json = self.__json_interception(s, True)
if not temp_json:
temp_json = self.__json_interception(s)
try:
json.loads(temp_json)
return temp_json
except Exception as e:
raise ValueError("Failed to find a valid json response" + temp_json)
def __json_interception(self, s, is_json_array: bool = False):
if is_json_array:
i = s.index("[")
if i <0:
return None
count = 1
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == "]":
count -= 1
elif c == "[":
count += 1
if count == 0:
break
assert count == 0
return s[i: j + 1]
else:
i = s.index("{")
if i <0:
return None
count = 1
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
count += 1
if count == 0:
break
assert count == 0
return s[i: j + 1]
def parse_prompt_response(self, model_out_text) -> T:
"""
@ -147,9 +175,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json") :]
cleaned_output = cleaned_output[len("```json"):]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```") :]
cleaned_output = cleaned_output[len("```"):]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
@ -158,9 +186,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = (
cleaned_output.strip()
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
)
return cleaned_output

View File

@ -1,4 +1,6 @@
import json
import os
import uuid
from typing import Dict, NamedTuple, List
from pilot.scene.base_message import (
HumanMessage,
@ -11,7 +13,7 @@ from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
@ -28,19 +30,31 @@ class ChatDashboard(BaseChat):
def __init__(self, chat_session_id, db_name, user_input, report_name):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbExecute,
chat_mode=ChatScene.ChatDashboard,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
if not db_name:
raise ValueError(
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
f"{ChatScene.ChatDashboard.value} mode should chose db!"
)
self.db_name = db_name
self.report_name = report_name
self.database = CFG.local_db
# 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5
self.dashboard_template = self.__load_dashboard_template(report_name)
def __load_dashboard_template(self, template_name):
current_dir = os.getcwd()
print(current_dir)
current_dir = os.path.dirname(os.path.abspath(__file__))
with open(f"{current_dir}/template/{template_name}/dashboard.json", 'r') as f:
data = f.read()
return json.loads(data)
def generate_input_values(self):
try:
@ -52,34 +66,28 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input,
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect),
"supported_chat_type": "" # TODO
"supported_chat_type": self.dashboard_template['supported_chart_type']
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
return input_values
def do_action(self, prompt_response):
### TODO 记录整体信息,处理成功的,和未成功的分开记录处理
report_data: ReportData = ReportData()
chart_datas: List[ChartData] = []
for chart_item in prompt_response:
try:
datas = self.database.run(self.db_connect, chart_item.sql)
chart_data: ChartData = ChartData()
chart_data.chart_sql = chart_item["sql"]
chart_data.chart_type = chart_item["showcase"]
chart_data.chart_name = chart_item["title"]
chart_data.chart_desc = chart_item["thoughts"]
chart_data.column_name = datas[0]
chart_data.values = datas
chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()),
chart_name=chart_item.title,
chart_type=chart_item.showcase,
chart_desc=chart_item.thoughts,
chart_sql=chart_item.sql,
column_name=datas[0],
values=datas))
except Exception as e:
# TODO 修复流程
print(str(e))
chart_datas.append(chart_data)
report_data.conv_uid = self.chat_session_id
report_data.template_name = self.report_name
report_data.template_introduce = None
report_data.charts = chart_datas
return report_data
return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None,
charts=chart_datas)

View File

@ -1,6 +1,7 @@
import json
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
from dataclasses import dataclass, asdict
class ChartData(BaseModel):
chart_uid: str
@ -10,11 +11,30 @@ class ChartData(BaseModel):
chart_sql: str
column_name: List
values: List
style: Any
style: Any = None
def dict(self, *args, **kwargs):
return {
"chart_uid": self.chart_uid,
"chart_name": self.chart_name,
"chart_type": self.chart_type,
"chart_desc": self.chart_desc,
"chart_sql": self.chart_sql,
"column_name": [str(item) for item in self.column_name],
"values": [[str(item) for item in sublist] for sublist in self.values],
"style": self.style
}
class ReportData(BaseModel):
conv_uid: str
template_name: str
template_introduce: str
template_introduce: str = None
charts: List[ChartData]
def prepare_dict(self):
return {
"conv_uid": self.conv_uid,
"template_name": self.template_name,
"template_introduce": self.template_introduce,
"charts": [chart.dict() for chart in self.charts]
}

View File

@ -1,12 +1,13 @@
import json
import re
from dataclasses import dataclass, asdict
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.scene.base import ChatScene
class ChartItem(NamedTuple):
sql: str
@ -26,7 +27,7 @@ class ChatDashboardOutputParser(BaseOutputParser):
clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str)
response = json.loads(clean_str)
chart_items = List[ChartItem]
chart_items: List[ChartItem] = []
for item in response:
chart_items.append(
ChartItem(
@ -36,10 +37,8 @@ class ChatDashboardOutputParser(BaseOutputParser):
return chart_items
def parse_view_response(self, speak, data) -> str:
### TODO
return data
return json.dumps(data.prepare_dict())
@property
def _type(self) -> str:
return "chat_dashboard"
return ChatScene.ChatDashboard.value

View File

@ -3,18 +3,17 @@ import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.scene.chat_dashboard.out_parser import ChatDashboardOutputParser, ChartItem
from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations"""
PROMPT_SCENE_DEFINE = None
_DEFAULT_TEMPLATE = """
According to the structure definition in the following tables:
{table_info}
Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions.
Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 5 dimensions.
Used to support goal: {input}
Use the chart display method in the following range:
@ -38,13 +37,13 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value,
template_scene=ChatScene.ChatDashboard.value,
input_variables=["input", "table_info", "dialect", "supported_chat_type"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=DbChatOutputParser(
output_parser=ChatDashboardOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)

View File

@ -6,6 +6,7 @@ from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge