diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py
index d59a9d33f..501227873 100644
--- a/pilot/common/sql_database.py
+++ b/pilot/common/sql_database.py
@@ -452,3 +452,5 @@ class Database:
         return [
             (table_comment[0], table_comment[1]) for table_comment in table_comments
         ]
+
+
diff --git a/pilot/connections/rdbms/py_study/pd_study.py b/pilot/connections/rdbms/py_study/pd_study.py
index 31b060ef1..411ce3935 100644
--- a/pilot/connections/rdbms/py_study/pd_study.py
+++ b/pilot/connections/rdbms/py_study/pd_study.py
@@ -77,10 +77,18 @@ CFG = Config()
 # print(__extract_json(ss))
 
 if __name__ == "__main__":
-    test1 = Test1()
-    test2 = Test2()
-    test1.write()
-    test1.test()
-    test2.write()
-    test1.test()
-    test2.test()
+    # test1 = Test1()
+    # test2 = Test2()
+    # test1.write()
+    # test1.test()
+    # test2.write()
+    # test1.test()
+    # test2.test()
+
+    # Define a list of tuples
+    data = [('key1', 'value1'), ('key2', 'value2'), ('key3', 'value3')]
+
+    # Use a dict comprehension to turn the list into a dict
+    result = {k: v for k, v in data}
+
+    print(result)
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 0d82c0e46..7f6f66953 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -228,6 +228,8 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         chat_param.update({"db_name": dialogue.select_param})
     elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
         chat_param.update({"db_name": dialogue.select_param})
+        ## DEFAULT
+        chat_param.update({"report_name": "sales_report"})
     elif ChatScene.ChatExecution.value == dialogue.chat_mode:
         chat_param.update({"plugin_selector": dialogue.select_param})
     elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 208bb148f..2b3dd9b5f 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -120,17 +120,45 @@ class BaseOutputParser(ABC):
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])
 
     def __extract_json(self, s):
-        i = s.index("{")
-        count = 1  # 当前所在嵌套深度,即还没闭合的'{'个数
-        for j, c in enumerate(s[i + 1 :], start=i + 1):
-            if c == "}":
-                count -= 1
-            elif c == "{":
-                count += 1
-            if count == 0:
-                break
-        assert count == 0  # 检查是否找到最后一个'}'
-        return s[i : j + 1]
+
+        temp_json = self.__json_interception(s, True)
+        if not temp_json:
+            temp_json = self.__json_interception(s)
+        try:
+            json.loads(temp_json)
+            return temp_json
+        except Exception as e:
+            raise ValueError("Failed to find a valid json response! "
+                             + temp_json)
+
+    def __json_interception(self, s, is_json_array: bool = False):
+        if is_json_array:
+            i = s.find("[")
+            if i < 0:
+                return None
+            count = 1
+            for j, c in enumerate(s[i + 1:], start=i + 1):
+                if c == "]":
+                    count -= 1
+                elif c == "[":
+                    count += 1
+                if count == 0:
+                    break
+            assert count == 0
+            return s[i: j + 1]
+        else:
+            i = s.find("{")
+            if i < 0:
+                return None
+            count = 1
+            for j, c in enumerate(s[i + 1:], start=i + 1):
+                if c == "}":
+                    count -= 1
+                elif c == "{":
+                    count += 1
+                if count == 0:
+                    break
+            assert count == 0
+            return s[i: j + 1]
 
     def parse_prompt_response(self, model_out_text) -> T:
         """
@@ -147,9 +175,9 @@ class BaseOutputParser(ABC):
             # if "```" in cleaned_output:
             #     cleaned_output, _ = cleaned_output.split("```")
             if cleaned_output.startswith("```json"):
-                cleaned_output = cleaned_output[len("```json") :]
+                cleaned_output = cleaned_output[len("```json"):]
             if cleaned_output.startswith("```"):
-                cleaned_output = cleaned_output[len("```") :]
+                cleaned_output = cleaned_output[len("```"):]
             if cleaned_output.endswith("```"):
                 cleaned_output = cleaned_output[: -len("```")]
             cleaned_output = cleaned_output.strip()
@@ -158,9 +186,9 @@ class BaseOutputParser(ABC):
             cleaned_output = self.__extract_json(cleaned_output)
         cleaned_output = (
             cleaned_output.strip()
-            .replace("\n", " ")
-            .replace("\\n", " ")
-            .replace("\\", " ")
+                .replace("\n", " ")
+                .replace("\\n", " ")
+                .replace("\\", " ")
         )
         return cleaned_output
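The reworked parser above prefers a JSON array over a single object: `__json_interception` finds the first opening bracket, walks the string while tracking nesting depth, and `__extract_json` then validates the candidate with `json.loads`. A minimal stand-alone sketch of that idea follows; the helper name, sample text, and fallback order are illustrative, not code from the patch.

```python
import json


def extract_json_block(text: str, open_ch: str = "{", close_ch: str = "}"):
    """Return the first balanced open_ch...close_ch block in text, or None."""
    start = text.find(open_ch)
    if start < 0:
        return None
    depth = 1
    for pos in range(start + 1, len(text)):
        if text[pos] == open_ch:
            depth += 1
        elif text[pos] == close_ch:
            depth -= 1
            if depth == 0:
                return text[start : pos + 1]
    return None  # brackets never balanced


model_output = 'Here is the analysis: [{"sql": "SELECT 1", "title": "demo"}] Hope it helps.'
# Prefer a JSON array (the dashboard prompt asks for a list of charts),
# fall back to a single object, then validate with json.loads.
candidate = extract_json_block(model_output, "[", "]") or extract_json_block(model_output)
print(json.loads(candidate))
```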
diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py
index 805f13eaf..84a87c5b6 100644
--- a/pilot/scene/chat_dashboard/chat.py
+++ b/pilot/scene/chat_dashboard/chat.py
@@ -1,4 +1,6 @@
 import json
+import os
+import uuid
 from typing import Dict, NamedTuple, List
 from pilot.scene.base_message import (
     HumanMessage,
@@ -11,7 +13,7 @@ from pilot.configs.config import Config
 from pilot.common.markdown_text import (
     generate_htm_table,
 )
-from pilot.scene.chat_db.auto_execute.prompt import prompt
+from pilot.scene.chat_dashboard.prompt import prompt
 from pilot.scene.chat_dashboard.data_preparation.report_schma import (
     ChartData,
     ReportData,
@@ -28,19 +30,31 @@ class ChatDashboard(BaseChat):
     def __init__(self, chat_session_id, db_name, user_input, report_name):
         """ """
         super().__init__(
-            chat_mode=ChatScene.ChatWithDbExecute,
+            chat_mode=ChatScene.ChatDashboard,
             chat_session_id=chat_session_id,
             current_user_input=user_input,
         )
         if not db_name:
             raise ValueError(
-                f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
+                f"{ChatScene.ChatDashboard.value} mode should choose db!"
             )
+        self.db_name = db_name
         self.report_name = report_name
         self.database = CFG.local_db
         # 准备DB信息(拿到指定库的链接)
         self.db_connect = self.database.get_session(self.db_name)
         self.top_k: int = 5
+        self.dashboard_template = self.__load_dashboard_template(report_name)
+
+    def __load_dashboard_template(self, template_name):
+
+        current_dir = os.getcwd()
+        print(current_dir)
+
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        with open(f"{current_dir}/template/{template_name}/dashboard.json", 'r') as f:
+            data = f.read()
+        return json.loads(data)
 
     def generate_input_values(self):
         try:
@@ -52,34 +66,28 @@ class ChatDashboard(BaseChat):
                 "input": self.current_user_input,
                 "dialect": self.database.dialect,
                 "table_info": self.database.table_simple_info(self.db_connect),
-                "supported_chat_type": ""  # TODO
+                "supported_chat_type": self.dashboard_template['supported_chart_type']
                 # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
             }
+
         return input_values
 
     def do_action(self, prompt_response):
         ### TODO 记录整体信息,处理成功的,和未成功的分开记录处理
-        report_data: ReportData = ReportData()
         chart_datas: List[ChartData] = []
         for chart_item in prompt_response:
             try:
                 datas = self.database.run(self.db_connect, chart_item.sql)
-                chart_data: ChartData = ChartData()
-                chart_data.chart_sql = chart_item["sql"]
-                chart_data.chart_type = chart_item["showcase"]
-                chart_data.chart_name = chart_item["title"]
-                chart_data.chart_desc = chart_item["thoughts"]
-                chart_data.column_name = datas[0]
-                chart_data.values = datas
+                chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()),
+                                             chart_name=chart_item.title,
+                                             chart_type=chart_item.showcase,
+                                             chart_desc=chart_item.thoughts,
+                                             chart_sql=chart_item.sql,
+                                             column_name=datas[0],
+                                             values=datas))
             except Exception as e:
                 # TODO 修复流程
                 print(str(e))
-            chart_datas.append(chart_data)
-
-        report_data.conv_uid = self.chat_session_id
-        report_data.template_name = self.report_name
-        report_data.template_introduce = None
-        report_data.charts = chart_datas
-
-        return report_data
+        return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None,
+                          charts=chart_datas)
 
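`ChatDashboard.__load_dashboard_template` above resolves `template/<report_name>/dashboard.json` next to the module, and `generate_input_values` later reads its `supported_chart_type` list. Below is a self-contained sketch of that lookup; the template contents and chart names are invented, only the `supported_chart_type` key comes from the patch.

```python
import json
import os
import tempfile

# A hypothetical dashboard.json; only "supported_chart_type" is read by the
# patched generate_input_values, and the chart names here are placeholders.
template = {"supported_chart_type": ["LineChart", "BarChart", "Table"]}

with tempfile.TemporaryDirectory() as base_dir:
    report_dir = os.path.join(base_dir, "template", "sales_report")
    os.makedirs(report_dir)
    with open(os.path.join(report_dir, "dashboard.json"), "w") as f:
        json.dump(template, f)

    # Mirrors __load_dashboard_template: resolve template/<name>/dashboard.json
    # relative to a base directory and parse it.
    with open(f"{base_dir}/template/sales_report/dashboard.json", "r") as f:
        loaded = json.loads(f.read())
    print(loaded["supported_chart_type"])
```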
diff --git a/pilot/scene/chat_dashboard/data_preparation/report_schma.py b/pilot/scene/chat_dashboard/data_preparation/report_schma.py
index 9327fc4d5..4b9ca9f58 100644
--- a/pilot/scene/chat_dashboard/data_preparation/report_schma.py
+++ b/pilot/scene/chat_dashboard/data_preparation/report_schma.py
@@ -1,6 +1,7 @@
+import json
 from pydantic import BaseModel, Field
 from typing import TypeVar, Union, List, Generic, Any
-
+from dataclasses import dataclass, asdict
 
 class ChartData(BaseModel):
     chart_uid: str
@@ -10,11 +11,30 @@ class ChartData(BaseModel):
     chart_sql: str
     column_name: List
     values: List
-    style: Any
+    style: Any = None
+    def dict(self, *args, **kwargs):
+        return {
+            "chart_uid": self.chart_uid,
+            "chart_name": self.chart_name,
+            "chart_type": self.chart_type,
+            "chart_desc": self.chart_desc,
+            "chart_sql": self.chart_sql,
+            "column_name": [str(item) for item in self.column_name],
+            "values": [[str(item) for item in sublist] for sublist in self.values],
+            "style": self.style
+        }
 
 
 class ReportData(BaseModel):
     conv_uid: str
     template_name: str
-    template_introduce: str
+    template_introduce: str = None
     charts: List[ChartData]
+
+    def prepare_dict(self):
+        return {
+            "conv_uid": self.conv_uid,
+            "template_name": self.template_name,
+            "template_introduce": self.template_introduce,
+            "charts": [chart.dict() for chart in self.charts]
+        }
\ No newline at end of file
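To make the serialization contract of the reworked models concrete, here is a small usage sketch; it assumes the package above is importable (run from the repository root), and the chart name, SQL, and rows are invented for illustration.

```python
import json
import uuid

from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData

chart = ChartData(
    chart_uid=str(uuid.uuid1()),
    chart_name="Orders per city",                   # example title
    chart_type="BarChart",                          # example display method
    chart_desc="Distribution of orders grouped by city",
    chart_sql="SELECT city, COUNT(*) FROM orders GROUP BY city",
    column_name=["city", "count"],                  # do_action takes the first row returned by Database.run
    values=[["city", "count"], ["Beijing", 42]],
)
report = ReportData(
    conv_uid="conv-123",
    template_name="sales_report",
    template_introduce=None,
    charts=[chart],
)
# prepare_dict stringifies column names and cell values so the whole report
# can be dumped with json.dumps in parse_view_response.
print(json.dumps(report.prepare_dict(), indent=2))
```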
diff --git a/pilot/scene/chat_dashboard/out_parser.py b/pilot/scene/chat_dashboard/out_parser.py
index 079b5f59a..56f6fe89e 100644
--- a/pilot/scene/chat_dashboard/out_parser.py
+++ b/pilot/scene/chat_dashboard/out_parser.py
@@ -1,12 +1,13 @@
 import json
 import re
+from dataclasses import dataclass, asdict
 from abc import ABC, abstractmethod
 from typing import Dict, NamedTuple, List
 import pandas as pd
 from pilot.utils import build_logger
 from pilot.out_parser.base import BaseOutputParser, T
 from pilot.configs.model_config import LOGDIR
-
+from pilot.scene.base import ChatScene
 
 class ChartItem(NamedTuple):
     sql: str
@@ -26,7 +27,7 @@ class ChatDashboardOutputParser(BaseOutputParser):
         clean_str = super().parse_prompt_response(model_out_text)
         print("clean prompt response:", clean_str)
         response = json.loads(clean_str)
-        chart_items = List[ChartItem]
+        chart_items: List[ChartItem] = []
         for item in response:
             chart_items.append(
                 ChartItem(
@@ -36,10 +37,8 @@ class ChatDashboardOutputParser(BaseOutputParser):
                 )
             )
         return chart_items
     def parse_view_response(self, speak, data) -> str:
-        ### TODO
-
-        return data
+        return json.dumps(data.prepare_dict())
     @property
     def _type(self) -> str:
-        return "chat_dashboard"
+        return ChatScene.ChatDashboard.value
diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py
index 481ecac22..0053e4a51 100644
--- a/pilot/scene/chat_dashboard/prompt.py
+++ b/pilot/scene/chat_dashboard/prompt.py
@@ -3,24 +3,26 @@ import importlib
 from pilot.prompts.prompt_new import PromptTemplate
 from pilot.configs.config import Config
 from pilot.scene.base import ChatScene
-from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
+from pilot.scene.chat_dashboard.out_parser import ChatDashboardOutputParser, ChartItem
 from pilot.common.schema import SeparatorStyle
 
 CFG = Config()
 
 PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations"""
-PROMPT_SCENE_DEFINE = None
 
 _DEFAULT_TEMPLATE = """
 According to the structure definition in the following tables:
 {table_info}
-Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions.
+Provide professional data analysis, use as few dimensions as possible, but no less than three, and no more than eight dimensions.
 Used to support goal: {input}
-Use the chart display method in the following range:
+Pay attention to the length of the output content of the analysis result, do not exceed 4000 tokens
+According to the characteristics of the analyzed data, choose the best one from the charts provided below to display, chart types:
 {supported_chat_type}
-give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format:
+
+Give {dialect} data analysis SQL, analysis title, display method and analytical thinking, respond in the following json format:
 {response}
+Do not use unprovided fields and do not use unprovided data in the where condition of sql.
 Ensure the response is correct json and can be parsed by Python json.loads
 """
 
 
@@ -38,13 +40,13 @@
 PROMPT_SEP = SeparatorStyle.SINGLE.value
 PROMPT_NEED_NEED_STREAM_OUT = False
 prompt = PromptTemplate(
-    template_scene=ChatScene.ChatWithDbExecute.value,
+    template_scene=ChatScene.ChatDashboard.value,
     input_variables=["input", "table_info", "dialect", "supported_chat_type"],
     response_format=json.dumps(RESPONSE_FORMAT, indent=4),
     template_define=PROMPT_SCENE_DEFINE,
     template=_DEFAULT_TEMPLATE,
     stream_out=PROMPT_NEED_NEED_STREAM_OUT,
-    output_parser=DbChatOutputParser(
+    output_parser=ChatDashboardOutputParser(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
     ),
 )
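For readers checking the prompt contract end to end, the following is a hypothetical model reply that `ChatDashboardOutputParser.parse_prompt_response` could map onto `ChartItem` entries. The field names (sql, title, showcase, thoughts) come from the diff; the SQL, titles, and chart types are made up.

```python
import json

# A hypothetical reply satisfying the JSON contract described in the prompt:
# a list whose elements carry the fields the out_parser reads.
model_reply = """
[
    {
        "sql": "SELECT city, SUM(amount) AS total FROM orders GROUP BY city",
        "title": "Sales amount by city",
        "showcase": "BarChart",
        "thoughts": "Compare total sales across cities to spot outliers"
    },
    {
        "sql": "SELECT status, COUNT(*) AS cnt FROM orders GROUP BY status",
        "title": "Orders by status",
        "showcase": "PieChart",
        "thoughts": "Show the share of each order status"
    }
]
"""

# BaseOutputParser.parse_prompt_response strips any code fences and extracts
# the JSON array before json.loads; with a clean reply the load is direct.
for item in json.loads(model_reply):
    print(item["title"], "->", item["sql"])
```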
diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py
index 436860d7f..42edf87da 100644
--- a/pilot/scene/chat_factory.py
+++ b/pilot/scene/chat_factory.py
@@ -6,6 +6,7 @@ from pilot.scene.chat_execution.chat import ChatWithPlugin
 from pilot.scene.chat_normal.chat import ChatNormal
 from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
 from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
+from pilot.scene.chat_dashboard.chat import ChatDashboard
 from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
 from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
 from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index d87540a8e..e7b5c877f 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -168,7 +168,7 @@ async def api_generate_stream(request: Request):
 
 
 @app.post("/generate")
-def generate(prompt_request: PromptRequest):
+def generate(prompt_request: PromptRequest) -> str:
     params = {
         "prompt": prompt_request.prompt,
         "temperature": prompt_request.temperature,
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index d39637d89..973093ebc 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -690,9 +690,6 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model_list_mode", type=str, default="once", choices=["once", "reload"]
     )
-    parser.add_argument(
-        "-new", "--new", action="store_true", help="enable new http mode"
-    )
 
     # old version server config
     parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -704,20 +701,14 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_init(args)
     dbs = CFG.local_db.get_database_list()
-    if args.new:
-        import uvicorn
-
-        uvicorn.run(app, host="0.0.0.0", port=5000)
-    else:
-        ### Compatibility mode starts the old version server by default
-        demo = build_webdemo()
-        demo.queue(
-            concurrency_count=args.concurrency_count,
-            status_update_rate=10,
-            api_open=False,
-        ).launch(
-            server_name=args.host,
-            server_port=args.port,
-            share=args.share,
-            max_threads=200,
-        )
+    demo = build_webdemo()
+    demo.queue(
+        concurrency_count=args.concurrency_count,
+        status_update_rate=10,
+        api_open=False,
+    ).launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share,
+        max_threads=200,
+    )