From 30c7803451a25b3ed13a4ebe56ee48f720e1dd4e Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Mon, 3 Jul 2023 16:17:32 +0800 Subject: [PATCH 1/2] WEB API independent --- pilot/common/sql_database.py | 2 + pilot/connections/rdbms/py_study/pd_study.py | 22 ++++--- pilot/openapi/api_v1/api_v1.py | 2 + pilot/out_parser/base.py | 60 ++++++++++++++----- pilot/scene/chat_dashboard/chat.py | 48 ++++++++------- .../data_preparation/report_schma.py | 26 +++++++- pilot/scene/chat_dashboard/out_parser.py | 11 ++-- pilot/scene/chat_dashboard/prompt.py | 9 ++- pilot/scene/chat_factory.py | 1 + 9 files changed, 124 insertions(+), 57 deletions(-) diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py index d59a9d33f..501227873 100644 --- a/pilot/common/sql_database.py +++ b/pilot/common/sql_database.py @@ -452,3 +452,5 @@ class Database: return [ (table_comment[0], table_comment[1]) for table_comment in table_comments ] + + diff --git a/pilot/connections/rdbms/py_study/pd_study.py b/pilot/connections/rdbms/py_study/pd_study.py index 31b060ef1..411ce3935 100644 --- a/pilot/connections/rdbms/py_study/pd_study.py +++ b/pilot/connections/rdbms/py_study/pd_study.py @@ -77,10 +77,18 @@ CFG = Config() # print(__extract_json(ss)) if __name__ == "__main__": - test1 = Test1() - test2 = Test2() - test1.write() - test1.test() - test2.write() - test1.test() - test2.test() + # test1 = Test1() + # test2 = Test2() + # test1.write() + # test1.test() + # test2.write() + # test1.test() + # test2.test() + + # 定义包含元组的列表 + data = [('key1', 'value1'), ('key2', 'value2'), ('key3', 'value3')] + + # 使用字典解析将列表转换为字典 + result = {k: v for k, v in data} + + print(result) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 0d82c0e46..7f6f66953 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -228,6 +228,8 @@ async def chat_completions(dialogue: ConversationVo = Body()): chat_param.update({"db_name": dialogue.select_param}) elif ChatScene.ChatDashboard.value == dialogue.chat_mode: chat_param.update({"db_name": dialogue.select_param}) + ## DEFAULT + chat_param.update({"report_name": "sales_report"}) elif ChatScene.ChatExecution.value == dialogue.chat_mode: chat_param.update({"plugin_selector": dialogue.select_param}) elif ChatScene.ChatKnowledge.value == dialogue.chat_mode: diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 208bb148f..2b3dd9b5f 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -120,17 +120,45 @@ class BaseOutputParser(ABC): raise ValueError("Model server error!code=" + respObj_ex["error_code"]) def __extract_json(self, s): - i = s.index("{") - count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1 :], start=i + 1): - if c == "}": - count -= 1 - elif c == "{": - count += 1 - if count == 0: - break - assert count == 0 # 检查是否找到最后一个'}' - return s[i : j + 1] + + temp_json = self.__json_interception(s, True) + if not temp_json: + temp_json = self.__json_interception(s) + try: + json.loads(temp_json) + return temp_json + except Exception as e: + raise ValueError("Failed to find a valid json response!" + temp_json) + + def __json_interception(self, s, is_json_array: bool = False): + if is_json_array: + i = s.index("[") + if i <0: + return None + count = 1 + for j, c in enumerate(s[i + 1:], start=i + 1): + if c == "]": + count -= 1 + elif c == "[": + count += 1 + if count == 0: + break + assert count == 0 + return s[i: j + 1] + else: + i = s.index("{") + if i <0: + return None + count = 1 + for j, c in enumerate(s[i + 1:], start=i + 1): + if c == "}": + count -= 1 + elif c == "{": + count += 1 + if count == 0: + break + assert count == 0 + return s[i: j + 1] def parse_prompt_response(self, model_out_text) -> T: """ @@ -147,9 +175,9 @@ class BaseOutputParser(ABC): # if "```" in cleaned_output: # cleaned_output, _ = cleaned_output.split("```") if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json") :] + cleaned_output = cleaned_output[len("```json"):] if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```") :] + cleaned_output = cleaned_output[len("```"):] if cleaned_output.endswith("```"): cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output.strip() @@ -158,9 +186,9 @@ class BaseOutputParser(ABC): cleaned_output = self.__extract_json(cleaned_output) cleaned_output = ( cleaned_output.strip() - .replace("\n", " ") - .replace("\\n", " ") - .replace("\\", " ") + .replace("\n", " ") + .replace("\\n", " ") + .replace("\\", " ") ) return cleaned_output diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 805f13eaf..84a87c5b6 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -1,4 +1,6 @@ import json +import os +import uuid from typing import Dict, NamedTuple, List from pilot.scene.base_message import ( HumanMessage, @@ -11,7 +13,7 @@ from pilot.configs.config import Config from pilot.common.markdown_text import ( generate_htm_table, ) -from pilot.scene.chat_db.auto_execute.prompt import prompt +from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.data_preparation.report_schma import ( ChartData, ReportData, @@ -28,19 +30,31 @@ class ChatDashboard(BaseChat): def __init__(self, chat_session_id, db_name, user_input, report_name): """ """ super().__init__( - chat_mode=ChatScene.ChatWithDbExecute, + chat_mode=ChatScene.ChatDashboard, chat_session_id=chat_session_id, current_user_input=user_input, ) if not db_name: raise ValueError( - f"{ChatScene.ChatWithDbExecute.value} mode should chose db!" + f"{ChatScene.ChatDashboard.value} mode should chose db!" ) + self.db_name = db_name self.report_name = report_name self.database = CFG.local_db # 准备DB信息(拿到指定库的链接) self.db_connect = self.database.get_session(self.db_name) self.top_k: int = 5 + self.dashboard_template = self.__load_dashboard_template(report_name) + + def __load_dashboard_template(self, template_name): + + current_dir = os.getcwd() + print(current_dir) + + current_dir = os.path.dirname(os.path.abspath(__file__)) + with open(f"{current_dir}/template/{template_name}/dashboard.json", 'r') as f: + data = f.read() + return json.loads(data) def generate_input_values(self): try: @@ -52,34 +66,28 @@ class ChatDashboard(BaseChat): "input": self.current_user_input, "dialect": self.database.dialect, "table_info": self.database.table_simple_info(self.db_connect), - "supported_chat_type": "" # TODO + "supported_chat_type": self.dashboard_template['supported_chart_type'] # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) } + return input_values def do_action(self, prompt_response): ### TODO 记录整体信息,处理成功的,和未成功的分开记录处理 - report_data: ReportData = ReportData() chart_datas: List[ChartData] = [] for chart_item in prompt_response: try: datas = self.database.run(self.db_connect, chart_item.sql) - chart_data: ChartData = ChartData() - chart_data.chart_sql = chart_item["sql"] - chart_data.chart_type = chart_item["showcase"] - chart_data.chart_name = chart_item["title"] - chart_data.chart_desc = chart_item["thoughts"] - chart_data.column_name = datas[0] - chart_data.values = datas + chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()), + chart_name=chart_item.title, + chart_type=chart_item.showcase, + chart_desc=chart_item.thoughts, + chart_sql=chart_item.sql, + column_name=datas[0], + values=datas)) except Exception as e: # TODO 修复流程 print(str(e)) - chart_datas.append(chart_data) - - report_data.conv_uid = self.chat_session_id - report_data.template_name = self.report_name - report_data.template_introduce = None - report_data.charts = chart_datas - - return report_data + return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None, + charts=chart_datas) diff --git a/pilot/scene/chat_dashboard/data_preparation/report_schma.py b/pilot/scene/chat_dashboard/data_preparation/report_schma.py index 9327fc4d5..4b9ca9f58 100644 --- a/pilot/scene/chat_dashboard/data_preparation/report_schma.py +++ b/pilot/scene/chat_dashboard/data_preparation/report_schma.py @@ -1,6 +1,7 @@ +import json from pydantic import BaseModel, Field from typing import TypeVar, Union, List, Generic, Any - +from dataclasses import dataclass, asdict class ChartData(BaseModel): chart_uid: str @@ -10,11 +11,30 @@ class ChartData(BaseModel): chart_sql: str column_name: List values: List - style: Any + style: Any = None + def dict(self, *args, **kwargs): + return { + "chart_uid": self.chart_uid, + "chart_name": self.chart_name, + "chart_type": self.chart_type, + "chart_desc": self.chart_desc, + "chart_sql": self.chart_sql, + "column_name": [str(item) for item in self.column_name], + "values": [[str(item) for item in sublist] for sublist in self.values], + "style": self.style + } class ReportData(BaseModel): conv_uid: str template_name: str - template_introduce: str + template_introduce: str = None charts: List[ChartData] + + def prepare_dict(self): + return { + "conv_uid": self.conv_uid, + "template_name": self.template_name, + "template_introduce": self.template_introduce, + "charts": [chart.dict() for chart in self.charts] + } \ No newline at end of file diff --git a/pilot/scene/chat_dashboard/out_parser.py b/pilot/scene/chat_dashboard/out_parser.py index 079b5f59a..56f6fe89e 100644 --- a/pilot/scene/chat_dashboard/out_parser.py +++ b/pilot/scene/chat_dashboard/out_parser.py @@ -1,12 +1,13 @@ import json import re +from dataclasses import dataclass, asdict from abc import ABC, abstractmethod from typing import Dict, NamedTuple, List import pandas as pd from pilot.utils import build_logger from pilot.out_parser.base import BaseOutputParser, T from pilot.configs.model_config import LOGDIR - +from pilot.scene.base import ChatScene class ChartItem(NamedTuple): sql: str @@ -26,7 +27,7 @@ class ChatDashboardOutputParser(BaseOutputParser): clean_str = super().parse_prompt_response(model_out_text) print("clean prompt response:", clean_str) response = json.loads(clean_str) - chart_items = List[ChartItem] + chart_items: List[ChartItem] = [] for item in response: chart_items.append( ChartItem( @@ -36,10 +37,8 @@ class ChatDashboardOutputParser(BaseOutputParser): return chart_items def parse_view_response(self, speak, data) -> str: - ### TODO - - return data + return json.dumps(data.prepare_dict()) @property def _type(self) -> str: - return "chat_dashboard" + return ChatScene.ChatDashboard.value diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index 481ecac22..f0b231116 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -3,18 +3,17 @@ import importlib from pilot.prompts.prompt_new import PromptTemplate from pilot.configs.config import Config from pilot.scene.base import ChatScene -from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction +from pilot.scene.chat_dashboard.out_parser import ChatDashboardOutputParser, ChartItem from pilot.common.schema import SeparatorStyle CFG = Config() PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations""" -PROMPT_SCENE_DEFINE = None _DEFAULT_TEMPLATE = """ According to the structure definition in the following tables: {table_info} -Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions. +Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 5 dimensions. Used to support goal: {input} Use the chart display method in the following range: @@ -38,13 +37,13 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = False prompt = PromptTemplate( - template_scene=ChatScene.ChatWithDbExecute.value, + template_scene=ChatScene.ChatDashboard.value, input_variables=["input", "table_info", "dialect", "supported_chat_type"], response_format=json.dumps(RESPONSE_FORMAT, indent=4), template_define=PROMPT_SCENE_DEFINE, template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_NEED_STREAM_OUT, - output_parser=DbChatOutputParser( + output_parser=ChatDashboardOutputParser( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT ), ) diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py index 436860d7f..42edf87da 100644 --- a/pilot/scene/chat_factory.py +++ b/pilot/scene/chat_factory.py @@ -6,6 +6,7 @@ from pilot.scene.chat_execution.chat import ChatWithPlugin from pilot.scene.chat_normal.chat import ChatNormal from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute +from pilot.scene.chat_dashboard.chat import ChatDashboard from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge From b3dde34ec4d425d544c110f28cb387316893e5f3 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Mon, 3 Jul 2023 18:37:25 +0800 Subject: [PATCH 2/2] WEB API independent --- pilot/scene/chat_dashboard/prompt.py | 9 +++++--- pilot/server/llmserver.py | 2 +- pilot/server/webserver.py | 31 ++++++++++------------------ 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index f0b231116..0053e4a51 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -13,13 +13,16 @@ PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provid _DEFAULT_TEMPLATE = """ According to the structure definition in the following tables: {table_info} -Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 5 dimensions. +Provide professional data analysis, use as few dimensions as possible, but no less than three, and no more than eight dimensions. Used to support goal: {input} -Use the chart display method in the following range: +Pay attention to the length of the output content of the analysis result, do not exceed 4000tokens +According to the characteristics of the analyzed data, choose the best one from the charts provided below to display,chart types: {supported_chat_type} -give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format: + +Give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format: {response} +Do not use unprovided fields and do not use unprovided data in the where condition of sql. Ensure the response is correct json and can be parsed by Python json.loads """ diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index d87540a8e..e7b5c877f 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -168,7 +168,7 @@ async def api_generate_stream(request: Request): @app.post("/generate") -def generate(prompt_request: PromptRequest): +def generate(prompt_request: PromptRequest)->str: params = { "prompt": prompt_request.prompt, "temperature": prompt_request.temperature, diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index b6c1e2cc3..bc8632579 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -690,9 +690,6 @@ if __name__ == "__main__": parser.add_argument( "--model_list_mode", type=str, default="once", choices=["once", "reload"] ) - parser.add_argument( - "-new", "--new", action="store_true", help="enable new http mode" - ) # old version server config parser.add_argument("--host", type=str, default="0.0.0.0") @@ -704,20 +701,14 @@ if __name__ == "__main__": args = parser.parse_args() server_init(args) - if args.new: - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=5000) - else: - ### Compatibility mode starts the old version server by default - demo = build_webdemo() - demo.queue( - concurrency_count=args.concurrency_count, - status_update_rate=10, - api_open=False, - ).launch( - server_name=args.host, - server_port=args.port, - share=args.share, - max_threads=200, - ) + demo = build_webdemo() + demo.queue( + concurrency_count=args.concurrency_count, + status_update_rate=10, + api_open=False, + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + )