Mirror of https://github.com/csunny/DB-GPT.git
refactor: merge tuyang_end_06
commit 07281d6bb5
@@ -452,3 +452,5 @@ class Database:
        return [
            (table_comment[0], table_comment[1]) for table_comment in table_comments
        ]
@@ -77,10 +77,18 @@ CFG = Config()
# print(__extract_json(ss))


if __name__ == "__main__":
    test1 = Test1()
    test2 = Test2()
    test1.write()
    test1.test()
    test2.write()
    test1.test()
    test2.test()
    # test1 = Test1()
    # test2 = Test2()
    # test1.write()
    # test1.test()
    # test2.write()
    # test1.test()
    # test2.test()

    # Define a list of tuples
    data = [('key1', 'value1'), ('key2', 'value2'), ('key3', 'value3')]

    # Use a dict comprehension to convert the list of tuples into a dict
    result = {k: v for k, v in data}

    print(result)
@@ -228,6 +228,8 @@ async def chat_completions(dialogue: ConversationVo = Body()):
        chat_param.update({"db_name": dialogue.select_param})
    elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
        chat_param.update({"db_name": dialogue.select_param})
        ## DEFAULT
        chat_param.update({"report_name": "sales_report"})
    elif ChatScene.ChatExecution.value == dialogue.chat_mode:
        chat_param.update({"plugin_selector": dialogue.select_param})
    elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
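For orientation, a minimal sketch of what the ChatDashboard branch above contributes to chat_param; the database name is illustrative and stands in for dialogue.select_param:

# Illustrative only: mirrors the ChatScene.ChatDashboard branch above.
# "dbgpt_test" stands in for dialogue.select_param (the database chosen by the user).
chat_param = {}
chat_param.update({"db_name": "dbgpt_test"})
chat_param.update({"report_name": "sales_report"})  # the DEFAULT report for this scene

print(chat_param)  # {'db_name': 'dbgpt_test', 'report_name': 'sales_report'}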
@@ -120,17 +120,45 @@ class BaseOutputParser(ABC):
            raise ValueError("Model server error!code=" + respObj_ex["error_code"])

    def __extract_json(self, s):
-        i = s.index("{")
-        count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1 :], start=i + 1):
-            if c == "}":
-                count -= 1
-            elif c == "{":
-                count += 1
-            if count == 0:
-                break
-        assert count == 0  # check that the closing '}' was found
-        return s[i : j + 1]
+        temp_json = self.__json_interception(s, True)
+        if not temp_json:
+            temp_json = self.__json_interception(s)
+        try:
+            json.loads(temp_json)
+            return temp_json
+        except Exception as e:
+            raise ValueError("Failed to find a valid json response!" + temp_json)
+
+    def __json_interception(self, s, is_json_array: bool = False):
+        if is_json_array:
+            i = s.index("[")
+            if i < 0:
+                return None
+            count = 1
+            for j, c in enumerate(s[i + 1:], start=i + 1):
+                if c == "]":
+                    count -= 1
+                elif c == "[":
+                    count += 1
+                if count == 0:
+                    break
+            assert count == 0
+            return s[i: j + 1]
+        else:
+            i = s.index("{")
+            if i < 0:
+                return None
+            count = 1
+            for j, c in enumerate(s[i + 1:], start=i + 1):
+                if c == "}":
+                    count -= 1
+                elif c == "{":
+                    count += 1
+                if count == 0:
+                    break
+            assert count == 0
+            return s[i: j + 1]

    def parse_prompt_response(self, model_out_text) -> T:
        """
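For a sense of what __json_interception accepts, here is a standalone sketch of the same bracket-matching idea; the function name and the sample model reply are illustrative, not part of the commit:

import json
from typing import Optional


def extract_balanced(s: str, open_ch: str = "{", close_ch: str = "}") -> Optional[str]:
    """Return the first balanced {...} (or [...]) slice of s, or None if absent."""
    i = s.find(open_ch)
    if i < 0:
        return None
    count = 1  # nesting depth: number of openers not yet closed
    j = i
    for j, c in enumerate(s[i + 1:], start=i + 1):
        if c == close_ch:
            count -= 1
        elif c == open_ch:
            count += 1
        if count == 0:
            break
    return s[i:j + 1] if count == 0 else None


raw = 'Sure! Here is the analysis: [{"sql": "SELECT 1", "title": "demo"}] Hope it helps.'
payload = extract_balanced(raw, "[", "]")   # try the JSON-array form first, as above
print(json.loads(payload))                  # [{'sql': 'SELECT 1', 'title': 'demo'}]

Note that str.find returns -1 when the opener is absent, whereas str.index in the committed code raises ValueError instead, so the i < 0 guards there never actually trigger.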
@@ -147,9 +175,9 @@ class BaseOutputParser(ABC):
        # if "```" in cleaned_output:
        #     cleaned_output, _ = cleaned_output.split("```")
        if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json") :]
+            cleaned_output = cleaned_output[len("```json"):]
        if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```") :]
+            cleaned_output = cleaned_output[len("```"):]
        if cleaned_output.endswith("```"):
            cleaned_output = cleaned_output[: -len("```")]
        cleaned_output = cleaned_output.strip()
@@ -158,9 +186,9 @@ class BaseOutputParser(ABC):
            cleaned_output = self.__extract_json(cleaned_output)
        cleaned_output = (
            cleaned_output.strip()
            .replace("\n", " ")
            .replace("\\n", " ")
            .replace("\\", " ")
        )
        return cleaned_output
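Taken together, the steps above strip a ```json fence, pull out the balanced JSON and flatten newlines; a standalone sketch of the same cleanup (not the class method itself, sample reply invented):

import json

# Standalone sketch of the cleanup performed by parse_prompt_response above.
model_out_text = '```json\n{\n  "sql": "SELECT city, SUM(amount) FROM sales GROUP BY city",\n  "title": "Sales by city"\n}\n```'

cleaned = model_out_text.strip()
if cleaned.startswith("```json"):
    cleaned = cleaned[len("```json"):]
if cleaned.endswith("```"):
    cleaned = cleaned[: -len("```")]
cleaned = cleaned.strip().replace("\n", " ").replace("\\n", " ")

print(json.loads(cleaned)["title"])  # Sales by city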
@@ -1,4 +1,6 @@
 import json
+import os
+import uuid
 from typing import Dict, NamedTuple, List
 from pilot.scene.base_message import (
     HumanMessage,
@@ -11,7 +13,7 @@ from pilot.configs.config import Config
from pilot.common.markdown_text import (
    generate_htm_table,
)
-from pilot.scene.chat_db.auto_execute.prompt import prompt
+from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
    ChartData,
    ReportData,
@@ -28,19 +30,31 @@ class ChatDashboard(BaseChat):
    def __init__(self, chat_session_id, db_name, user_input, report_name):
        """ """
        super().__init__(
-            chat_mode=ChatScene.ChatWithDbExecute,
+            chat_mode=ChatScene.ChatDashboard,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
        )
        if not db_name:
            raise ValueError(
-                f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
+                f"{ChatScene.ChatDashboard.value} mode should chose db!"
            )
        self.db_name = db_name
        self.report_name = report_name
        self.database = CFG.local_db
        # Prepare the DB info (get a connection to the specified database)
        self.db_connect = self.database.get_session(self.db_name)
        self.top_k: int = 5
        self.dashboard_template = self.__load_dashboard_template(report_name)

    def __load_dashboard_template(self, template_name):
        current_dir = os.getcwd()
        print(current_dir)

        current_dir = os.path.dirname(os.path.abspath(__file__))
        with open(f"{current_dir}/template/{template_name}/dashboard.json", 'r') as f:
            data = f.read()
        return json.loads(data)

    def generate_input_values(self):
        try:
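The loader above expects template/<report_name>/dashboard.json next to the module. A minimal hypothetical template for the default sales_report, expressed as the Python dict json.loads would return, might look like the sketch below; only the supported_chart_type key is known from this commit, every other detail is an assumption:

# Hypothetical template content; only supported_chart_type is referenced in this commit,
# and the chart type names are invented for illustration.
dashboard_template = {
    "supported_chart_type": ["Table", "LineChart", "BarChart", "PieChart"]
}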
@@ -52,34 +66,28 @@ class ChatDashboard(BaseChat):
            "input": self.current_user_input,
            "dialect": self.database.dialect,
            "table_info": self.database.table_simple_info(self.db_connect),
-            "supported_chat_type": ""  # TODO
+            "supported_chat_type": self.dashboard_template['supported_chart_type']
            # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
        }

        return input_values

    def do_action(self, prompt_response):
        ### TODO record the overall result; track successful and failed charts separately
-        report_data: ReportData = ReportData()
        chart_datas: List[ChartData] = []
        for chart_item in prompt_response:
            try:
                datas = self.database.run(self.db_connect, chart_item.sql)
-                chart_data: ChartData = ChartData()
-                chart_data.chart_sql = chart_item["sql"]
-                chart_data.chart_type = chart_item["showcase"]
-                chart_data.chart_name = chart_item["title"]
-                chart_data.chart_desc = chart_item["thoughts"]
-                chart_data.column_name = datas[0]
-                chart_data.values = datas
+                chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()),
+                                             chart_name=chart_item.title,
+                                             chart_type=chart_item.showcase,
+                                             chart_desc=chart_item.thoughts,
+                                             chart_sql=chart_item.sql,
+                                             column_name=datas[0],
+                                             values=datas))
            except Exception as e:
                # TODO repair flow for failed charts
                print(str(e))
-            chart_datas.append(chart_data)

-        report_data.conv_uid = self.chat_session_id
-        report_data.template_name = self.report_name
-        report_data.template_introduce = None
-        report_data.charts = chart_datas
-
-        return report_data
+        return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None,
+                          charts=chart_datas)
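One note on the mapping above: column_name=datas[0] and values=datas imply that database.run() returns the header row followed by the data rows. A toy illustration under that assumption:

# Toy illustration only; assumes database.run() yields the header row first.
datas = [
    ("city", "total_amount"),   # header row
    ("Beijing", 1200),
    ("Hangzhou", 950),
]

column_name = datas[0]   # ("city", "total_amount")
values = datas           # full result set, header row included, as committed here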
@@ -1,6 +1,7 @@
import json
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any

from dataclasses import dataclass, asdict


class ChartData(BaseModel):
    chart_uid: str
@@ -10,11 +11,30 @@ class ChartData(BaseModel):
    chart_sql: str
    column_name: List
    values: List
-    style: Any
+    style: Any = None

    def dict(self, *args, **kwargs):
        return {
            "chart_uid": self.chart_uid,
            "chart_name": self.chart_name,
            "chart_type": self.chart_type,
            "chart_desc": self.chart_desc,
            "chart_sql": self.chart_sql,
            "column_name": [str(item) for item in self.column_name],
            "values": [[str(item) for item in sublist] for sublist in self.values],
            "style": self.style
        }


class ReportData(BaseModel):
    conv_uid: str
    template_name: str
-    template_introduce: str
+    template_introduce: str = None
    charts: List[ChartData]

    def prepare_dict(self):
        return {
            "conv_uid": self.conv_uid,
            "template_name": self.template_name,
            "template_introduce": self.template_introduce,
            "charts": [chart.dict() for chart in self.charts]
        }
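As a quick sketch of the serialization path added above, assuming the models are imported from the path used elsewhere in this commit and with invented field values:

import json
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData

# Invented example data; prepare_dict()/dict() stringify column names and cell
# values so the whole report can be dumped as JSON for the view layer.
chart = ChartData(
    chart_uid="c-001",
    chart_name="Sales by city",
    chart_type="BarChart",
    chart_desc="Compare total sales across cities",
    chart_sql="SELECT city, SUM(amount) FROM sales GROUP BY city",
    column_name=["city", "total"],
    values=[["Beijing", 1200], ["Hangzhou", 950]],
)
report = ReportData(
    conv_uid="conv-123",
    template_name="sales_report",
    template_introduce=None,
    charts=[chart],
)
print(json.dumps(report.prepare_dict(), indent=2))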
@@ -1,12 +1,13 @@
import json
import re
from dataclasses import dataclass, asdict
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import pandas as pd

from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR

+from pilot.scene.base import ChatScene


class ChartItem(NamedTuple):
    sql: str
@@ -26,7 +27,7 @@ class ChatDashboardOutputParser(BaseOutputParser):
        clean_str = super().parse_prompt_response(model_out_text)
        print("clean prompt response:", clean_str)
        response = json.loads(clean_str)
-        chart_items = List[ChartItem]
+        chart_items: List[ChartItem] = []
        for item in response:
            chart_items.append(
                ChartItem(
@@ -36,10 +37,8 @@ class ChatDashboardOutputParser(BaseOutputParser):
        return chart_items

    def parse_view_response(self, speak, data) -> str:
        ### TODO
-
-        return data
+        return json.dumps(data.prepare_dict())

    @property
    def _type(self) -> str:
-        return "chat_dashboard"
+        return ChatScene.ChatDashboard.value
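The parser above expects the cleaned model reply to be a JSON array whose items carry the sql/title/showcase/thoughts fields consumed in ChatDashboard.do_action; the full ChartItem definition is only partially visible in this diff, so the sketch below redeclares it from those usages and invents the sample reply:

import json
from typing import List, NamedTuple

# Field names inferred from how chart_item is consumed in ChatDashboard.do_action.
class ChartItem(NamedTuple):
    sql: str
    title: str
    showcase: str
    thoughts: str

clean_str = json.dumps([
    {
        "sql": "SELECT city, SUM(amount) AS total FROM sales GROUP BY city",
        "title": "Sales by city",
        "showcase": "BarChart",
        "thoughts": "Compare total sales across cities",
    }
])

response = json.loads(clean_str)
chart_items: List[ChartItem] = [
    ChartItem(item["sql"], item["title"], item["showcase"], item["thoughts"])
    for item in response
]
print(chart_items[0].title)  # Sales by city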
@@ -3,24 +3,26 @@ import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
-from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
+from pilot.scene.chat_dashboard.out_parser import ChatDashboardOutputParser, ChartItem
from pilot.common.schema import SeparatorStyle

CFG = Config()

-PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations"""
+PROMPT_SCENE_DEFINE = None

_DEFAULT_TEMPLATE = """
According to the structure definition in the following tables:
{table_info}
-Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions.
+Provide professional data analysis, use as few dimensions as possible, but no less than three, and no more than eight dimensions.
Used to support goal: {input}

-Use the chart display method in the following range:
+Pay attention to the length of the output content of the analysis result, do not exceed 4000 tokens.
+According to the characteristics of the analyzed data, choose the best one from the charts provided below to display, chart types:
{supported_chat_type}
-give {dialect} data analysis SQL, analysis title, display method and analytical thinking, respond in the following json format:

+Give {dialect} data analysis SQL, analysis title, display method and analytical thinking, respond in the following json format:
{response}
Do not use unprovided fields and do not use unprovided data in the where condition of sql.
Ensure the response is correct json and can be parsed by Python json.loads
"""
@@ -38,13 +40,13 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False

prompt = PromptTemplate(
-    template_scene=ChatScene.ChatWithDbExecute.value,
+    template_scene=ChatScene.ChatDashboard.value,
    input_variables=["input", "table_info", "dialect", "supported_chat_type"],
    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
-    output_parser=DbChatOutputParser(
+    output_parser=ChatDashboardOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
)
@@ -6,6 +6,7 @@ from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
+from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
@@ -168,7 +168,7 @@ async def api_generate_stream(request: Request):


@app.post("/generate")
-def generate(prompt_request: PromptRequest):
+def generate(prompt_request: PromptRequest) -> str:
    params = {
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
@@ -690,9 +690,6 @@ if __name__ == "__main__":
    parser.add_argument(
        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
    )
-    parser.add_argument(
-        "-new", "--new", action="store_true", help="enable new http mode"
-    )

    # old version server config
    parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -704,20 +701,14 @@ if __name__ == "__main__":
    args = parser.parse_args()
    server_init(args)
    dbs = CFG.local_db.get_database_list()
-    if args.new:
-        import uvicorn
-
-        uvicorn.run(app, host="0.0.0.0", port=5000)
-    else:
-        ### Compatibility mode starts the old version server by default
-        demo = build_webdemo()
-        demo.queue(
-            concurrency_count=args.concurrency_count,
-            status_update_rate=10,
-            api_open=False,
-        ).launch(
-            server_name=args.host,
-            server_port=args.port,
-            share=args.share,
-            max_threads=200,
-        )
+    demo = build_webdemo()
+    demo.queue(
+        concurrency_count=args.concurrency_count,
+        status_update_rate=10,
+        api_open=False,
+    ).launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share,
+        max_threads=200,
+    )