From a95d292bf96fce6756e1966e408b0d181de49e17 Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Thu, 10 Aug 2023 15:01:00 +0800 Subject: [PATCH] feat(editor): editor api devlop editor api devlop part 1 --- pilot/connections/rdbms/base.py | 32 +-- pilot/memory/chat_history/duckdb_history.py | 28 +++ pilot/openapi/api_v1/editor/api_editor_v1.py | 184 ++++++++++++++---- pilot/openapi/api_v1/editor/sql_editor.py | 26 +++ pilot/openapi/editor_view_model.py | 6 +- pilot/scene/base_chat.py | 89 +++++---- pilot/scene/base_message.py | 3 +- pilot/scene/chat_dashboard/chat.py | 51 +---- pilot/scene/chat_dashboard/data_loader.py | 57 ++++++ pilot/scene/chat_db/auto_execute/chat.py | 6 +- .../scene/chat_db/auto_execute/data_loader.py | 8 + pilot/scene/chat_db/professional_qa/chat.py | 1 + pilot/scene/chat_execution/chat.py | 1 + pilot/scene/chat_knowledge/custom/chat.py | 1 + .../chat_knowledge/inner_db_summary/chat.py | 1 + pilot/scene/message.py | 8 + pilot/utils.py | 28 +-- 17 files changed, 373 insertions(+), 157 deletions(-) create mode 100644 pilot/openapi/api_v1/editor/sql_editor.py create mode 100644 pilot/scene/chat_dashboard/data_loader.py create mode 100644 pilot/scene/chat_db/auto_execute/data_loader.py diff --git a/pilot/connections/rdbms/base.py b/pilot/connections/rdbms/base.py index d2a853589..c5b5d69cf 100644 --- a/pilot/connections/rdbms/base.py +++ b/pilot/connections/rdbms/base.py @@ -250,17 +250,17 @@ class RDBMSDatabase(BaseConnect): """Format the error message""" return f"Error: {e}" - def __write(self, session, write_sql): + def __write(self, write_sql): print(f"Write[{write_sql}]") db_cache = self._engine.url.database - result = session.execute(text(write_sql)) - session.commit() + result = self.session.execute(text(write_sql)) + self.session.commit() # TODO Subsequent optimization of dynamically specified database submission loss target problem - session.execute(text(f"use `{db_cache}`")) + self.session.execute(text(f"use `{db_cache}`")) print(f"SQL[{write_sql}], result:{result.rowcount}") return result.rowcount - def __query(self, session, query, fetch: str = "all"): + def __query(self,query, fetch: str = "all"): """ only for query Args: @@ -274,7 +274,7 @@ class RDBMSDatabase(BaseConnect): print(f"Query[{query}]") if not query: return [] - cursor = session.execute(text(query)) + cursor = self.session.execute(text(query)) if cursor.returns_rows: if fetch == "all": result = cursor.fetchall() @@ -288,7 +288,7 @@ class RDBMSDatabase(BaseConnect): result.insert(0, field_names) return result - def query_ex(self, session, query, fetch: str = "all"): + def query_ex(self, query, fetch: str = "all"): """ only for query Args: @@ -300,7 +300,7 @@ class RDBMSDatabase(BaseConnect): print(f"Query[{query}]") if not query: return [] - cursor = session.execute(text(query)) + cursor = self.session.execute(text(query)) if cursor.returns_rows: if fetch == "all": result = cursor.fetchall() @@ -313,7 +313,7 @@ class RDBMSDatabase(BaseConnect): result = list(result) return field_names, result - def run(self, session, command: str, fetch: str = "all") -> List: + def run(self, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results.""" print("SQL:" + command) if not command: @@ -321,17 +321,17 @@ class RDBMSDatabase(BaseConnect): parsed, ttype, sql_type, table_name = self.__sql_parse(command) if ttype == sqlparse.tokens.DML: if sql_type == "SELECT": - return self.__query(session, command, fetch) + return self.__query( command, fetch) else: - self.__write(session, command) + self.__write( command) select_sql = self.convert_sql_write_to_select(command) print(f"write result query:{select_sql}") - return self.__query(session, select_sql) + return self.__query( select_sql) else: print(f"DDL execution determines whether to enable through configuration ") - cursor = session.execute(text(command)) - session.commit() + cursor = self.session.execute(text(command)) + self.session.commit() if cursor.returns_rows: result = cursor.fetchall() field_names = tuple(i[0:] for i in cursor.keys()) @@ -339,10 +339,10 @@ class RDBMSDatabase(BaseConnect): result.insert(0, field_names) print("DDL Result:" + str(result)) if not result: - return self.__query(session, f"SHOW COLUMNS FROM {table_name}") + return self.__query( f"SHOW COLUMNS FROM {table_name}") return result else: - return self.__query(session, f"SHOW COLUMNS FROM {table_name}") + return self.__query( f"SHOW COLUMNS FROM {table_name}") def run_no_throw(self, session, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results. diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index 2de34976f..827107515 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -94,6 +94,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): cursor.commit() self.connect.commit() + + def update(self, messages:List[OnceConversation]) -> None: + cursor = self.connect.cursor() + cursor.execute( + "UPDATE chat_history set messages=? where conv_uid=?", + [json.dumps(messages, ensure_ascii=False), self.chat_seesion_id], + ) + cursor.commit() + self.connect.commit() + def clear(self) -> None: cursor = self.connect.cursor() cursor.execute( @@ -134,6 +144,24 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): return [] + def conv_info(self, conv_uid: str = None) -> None: + cursor = self.connect.cursor() + cursor.execute( + "SELECT * FROM chat_history where conv_uid=? ", + [conv_uid], + ) + # 获取查询结果字段名 + fields = [field[0] for field in cursor.description] + + for row in cursor.fetchone(): + row_dict = {} + for i, field in enumerate(fields): + row_dict[field] = row[i] + return row_dict + + return {} + + def get_messages(self) -> List[OnceConversation]: cursor = self.connect.cursor() cursor.execute( diff --git a/pilot/openapi/api_v1/editor/api_editor_v1.py b/pilot/openapi/api_v1/editor/api_editor_v1.py index fbb92353d..cbcd29abc 100644 --- a/pilot/openapi/api_v1/editor/api_editor_v1.py +++ b/pilot/openapi/api_v1/editor/api_editor_v1.py @@ -1,17 +1,12 @@ -import os - +import json from fastapi import ( APIRouter, - Request, Body, - BackgroundTasks, ) - from typing import List from pilot.configs.config import Config -from pilot.server.knowledge.service import KnowledgeService from pilot.scene.chat_factory import ChatFactory from pilot.configs.model_config import LOGDIR @@ -19,9 +14,6 @@ from pilot.utils import build_logger from pilot.openapi.api_view_model import ( Result, - ConversationVo, - MessageVo, - ChatSceneVo, ) from pilot.openapi.editor_view_model import ( ChatDbRounds, @@ -31,9 +23,11 @@ from pilot.openapi.editor_view_model import ( DbTable ) -from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData +from pilot.openapi.api_v1.editor.sql_editor import DataNode,ChartRunData,SqlRunData +from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory +from pilot.scene.message import OnceConversation +from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader -from pilot.scene.chat_db.auto_execute.out_parser import SqlAction router = APIRouter() CFG = Config() @@ -42,47 +36,165 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log") @router.get("/v1/editor/db/tables", response_model=Result[DbTable]) -async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str= ""): - return Result.succ(None) - +async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""): + logger.info("get_editor_tables:{},{},{},{}", db_name, page_index, page_size, search_str) + db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) + tables = db_conn.get_table_names() + db_node: DataNode = DataNode(title=db_name, key=db_name, type="db") + for table in tables: + table_node: DataNode = DataNode(title=table, key=table, type="table") + db_node.children.append(table_node) + fields = db_conn.get_fields("transaction_order") + for field in fields: + table_node.children.append( + DataNode(title=field[0], key=field[0], type=field[1], default_value=field[2], can_null=field[3], + comment=field[-1])) + return Result.succ(db_node) @router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds]) async def get_editor_sql_rounds(con_uid: str): - return Result.succ(None) + logger.info("get_editor_sql_rounds:{}", con_uid) + history_mem = DuckdbHistoryMemory(con_uid) + history_messages: List[OnceConversation] = history_mem.get_messages() + if history_messages: + result: List = [] + for once in history_messages: + round_name: str = "" + for element in once["messages"]: + if element["type"] == "human": + round_name = element["data"]["content"] + if once.get("param_value"): + round: ChatDbRounds = ChatDbRounds(round=once["chat_order"], db_name=once["param_value"], + round_name=round_name) + result.append(round) + return Result.succ(result) -@router.get("/v1/editor/sql", response_model=Result[SqlAction]) +@router.get("/v1/editor/sql", response_model=Result[dict]) async def get_editor_sql(con_uid: str, round: int): - return Result.succ(None) - - -@router.get("/v1/editor/chart/details", response_model=Result[ChartDetail]) -async def get_editor_sql_rounds(con_uid: str): - return Result.succ(None) - - -@router.get("/v1/editor/chart", response_model=Result[ChartDetail]) -async def get_editor_chart(con_uid: str, chart_uid: str): - return Result.succ(None) + logger.info("get_editor_sql:{},{}", con_uid, round) + history_mem = DuckdbHistoryMemory(con_uid) + history_messages: List[OnceConversation] = history_mem.get_messages() + if history_messages: + for once in history_messages: + if int(once["chat_order"]) == round: + for element in once["messages"]: + if element["type"] == "ai": + return Result.succ(json.loads(element["data"]["content"])) + return Result.faild("没有获取到可用的SQL返回结构") @router.post("/v1/editor/sql/run", response_model=Result[List[dict]]) -async def get_editor_chart(db_name: str, sql: str): - return Result.succ(None) +async def editor_sql_run(db_name: str, sql: str): + logger.info("get_editor_sql_run:{},{}", db_name, sql) + conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) + return Result.succ(conn.run(sql)) -@router.post("/v1/editor/chart/run", response_model=Result[ChartData]) -async def get_editor_chart(db_name: str, sql: str): - return Result.succ(None) +@router.post("/v1/sql/editor/submit", response_model=Result) +async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()): + logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}") + history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid) + history_messages: List[OnceConversation] = history_mem.get_messages() + if history_messages: + edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0] + if edit_round: + for element in edit_round["messages"]: + if element["type"] == "ai": + element["data"]["content"]="" + if element["type"] == "view": + element["data"]["content"]="" + history_mem.update(history_messages) + return Result.succ(None) + return Result.faild("Edit Faild!") + + +@router.get("/v1/editor/chart/list", response_model=Result[ChartDetail]) +async def get_editor_chart_list(con_uid: str): + logger.info("get_editor_sql_rounds:{}", con_uid) + history_mem = DuckdbHistoryMemory(con_uid) + history_messages: List[OnceConversation] = history_mem.get_messages() + if history_messages: + last_round = max(history_messages, key=lambda x: x['chat_order']) + for element in last_round["messages"]: + if element["type"] == "ai": + return Result.succ(json.loads(element["data"]["content"])) + + return Result.faild("没有获取到可用的SQL返回结构") + + +@router.get("/v1/editor/chart/info", response_model=Result[ChartDetail]) +async def get_editor_chart_info(con_uid: str, chart_uid: str): + logger.info(f"get_editor_sql_rounds:{con_uid}") + logger.info("get_editor_sql_rounds:{}", con_uid) + history_mem = DuckdbHistoryMemory(con_uid) + history_messages: List[OnceConversation] = history_mem.get_messages() + if history_messages: + last_round = max(history_messages, key=lambda x: x['chat_order']) + db_name = last_round["param_value"] + if not db_name: + logger.error("this dashboard dialogue version too old, can't support editor!") + return Result.faild("this dashboard dialogue version too old, can't support editor!") + for element in last_round["messages"]: + if element["type"] == "view": + view_data: dict = json.loads(element["data"]["content"]); + charts: List = view_data.get("charts") + find_chart = list(filter(lambda x: x['chart_name'] == chart_uid, charts))[0] + + conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) + detail: ChartDetail = ChartDetail(chart_uid=find_chart['chart_uid'], + chart_type=find_chart['chart_type'], + chart_desc=find_chart['chart_desc'], + chart_sql=find_chart['chart_sql'], + db_name=db_name, + chart_name=find_chart['chart_name'], + chart_value=find_chart['values'], + table_value=conn.run(find_chart['chart_sql']) + ) + + return Result.succ(detail) + return Result.faild("Can't Find Chart Detail Info!") + + +@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData]) +async def editor_chart_run(db_name: str, sql: str): + logger.info(f"editor_chart_run:{db_name},{sql}") + dashboard_data_loader:DashboardDataLoader = DashboardDataLoader() + db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) + + field_names,chart_values = dashboard_data_loader.get_chart_values_by_db(db_conn, sql) + + sql_run_data:SqlRunData = SqlRunData(result_info="", + run_cost="", + colunms= field_names, + values= db_conn.query_ex(sql) + ) + return Result.succ(ChartRunData(sql_data=sql_run_data,chart_values=chart_values)) @router.post("/v1/chart/editor/submit", response_model=Result[bool]) async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()): - return Result.succ(None) + logger.info(f"sql_editor_submit:{chart_edit_context.__dict__}") + history_mem = DuckdbHistoryMemory(chart_edit_context.conv_uid) + history_messages: List[OnceConversation] = history_mem.get_messages() + if history_messages: + edit_round = list(filter(lambda x: x['chat_order'] == chart_edit_context.conv_round, history_messages))[0] + if edit_round: + for element in edit_round["messages"]: + if element["type"] == "ai": + view_data: dict = json.loads(element["data"]["content"]); + charts: List = view_data.get("charts") + find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0] -@router.post("/v1/sql/editor/submit", response_model=Result[bool]) -async def chart_editor_submit(sql_edit_context: ChatSqlEditContext = Body()): - return Result.succ(None) + if element["type"] == "view": + view_data: dict = json.loads(element["data"]["content"]); + charts: List = view_data.get("charts") + find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0] + + + history_mem.update(history_messages) + return Result.succ(None) + return Result.faild("Edit Faild!") diff --git a/pilot/openapi/api_v1/editor/sql_editor.py b/pilot/openapi/api_v1/editor/sql_editor.py new file mode 100644 index 000000000..61262d60c --- /dev/null +++ b/pilot/openapi/api_v1/editor/sql_editor.py @@ -0,0 +1,26 @@ +from typing import List +from pydantic import BaseModel, Field, root_validator, validator, Extra +from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem + + +class DataNode(BaseModel): + title: str + key: str + + type: str = "" + default_value: str = None + can_null: str = 'YES' + comment: str = None + children: List = [] + + +class SqlRunData(BaseModel): + result_info: str + run_cost: str + colunms: List[str] + values: List + + +class ChartRunData(BaseModel): + sql_data: SqlRunData + chart_values: List[ValueItem] diff --git a/pilot/openapi/editor_view_model.py b/pilot/openapi/editor_view_model.py index 242affede..c44b49041 100644 --- a/pilot/openapi/editor_view_model.py +++ b/pilot/openapi/editor_view_model.py @@ -24,10 +24,12 @@ class ChatDbRounds(BaseModel): class ChartDetail(BaseModel): chart_uid: str chart_type: str + chart_desc: str + chart_sql: str db_name: str chart_name: str - chart_value: str - chat_round: int # defualt last round + chart_value: Any + table_value: Any class ChatChartEditContext(BaseModel): diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 13fbf2708..d342e76db 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -60,10 +60,11 @@ class BaseChat(ABC): arbitrary_types_allowed = True def __init__( - self, - chat_mode, - chat_session_id, - current_user_input, + self, + chat_mode, + chat_session_id, + current_user_input, + select_param: Any = None ): self.chat_session_id = chat_session_id self.chat_mode = chat_mode @@ -87,6 +88,9 @@ class BaseChat(ABC): ) self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation(chat_mode.value()) + if select_param: + self.current_message.param_type = chat_mode.param_types()[0] + self.current_message.param_value = select_param self.current_tokens_used: int = 0 class Config: @@ -111,6 +115,24 @@ class BaseChat(ABC): def do_action(self, prompt_response): return prompt_response + def get_llm_speak(self, prompt_define_response): + if hasattr(prompt_define_response, "thoughts"): + if isinstance(prompt_define_response.thoughts, dict): + if "speak" in prompt_define_response.thoughts: + speak_to_user = prompt_define_response.thoughts.get("speak") + else: + speak_to_user = str(prompt_define_response.thoughts) + else: + if hasattr(prompt_define_response.thoughts, "speak"): + speak_to_user = prompt_define_response.thoughts.get("speak") + elif hasattr(prompt_define_response.thoughts, "reasoning"): + speak_to_user = prompt_define_response.thoughts.get("reasoning") + else: + speak_to_user = prompt_define_response.thoughts + else: + speak_to_user = prompt_define_response + return speak_to_user + def __call_base(self): input_values = self.generate_input_values() ### Chat sequence advance @@ -209,26 +231,13 @@ class BaseChat(ABC): ai_response_text ) ) + ### sql run result = self.do_action(prompt_define_response) - if hasattr(prompt_define_response, "thoughts"): - if isinstance(prompt_define_response.thoughts, dict): - if "speak" in prompt_define_response.thoughts: - speak_to_user = prompt_define_response.thoughts.get("speak") - else: - speak_to_user = str(prompt_define_response.thoughts) - else: - if hasattr(prompt_define_response.thoughts, "speak"): - speak_to_user = prompt_define_response.thoughts.get("speak") - elif hasattr(prompt_define_response.thoughts, "reasoning"): - speak_to_user = prompt_define_response.thoughts.get("reasoning") - else: - speak_to_user = prompt_define_response.thoughts - else: - speak_to_user = prompt_define_response - view_message = self.prompt_template.output_parser.parse_view_response( - speak_to_user, result - ) + ### llm speaker + speak_to_user = self.get_llm_speak(prompt_define_response) + + view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result) self.current_message.add_view_message(view_message) except Exception as e: print(traceback.format_exc()) @@ -297,7 +306,7 @@ class BaseChat(ABC): system_messages = [] for system_conv in system_convs: system_text += ( - system_conv.type + ":" + system_conv.content + self.prompt_template.sep + system_conv.type + ":" + system_conv.content + self.prompt_template.sep ) system_messages.append( ModelMessage(role=system_conv.type, content=system_conv.content) @@ -309,7 +318,7 @@ class BaseChat(ABC): user_messages = [] if user_conv: user_text = ( - user_conv.type + ":" + user_conv.content + self.prompt_template.sep + user_conv.type + ":" + user_conv.content + self.prompt_template.sep ) user_messages.append( ModelMessage(role=user_conv.type, content=user_conv.content) @@ -331,10 +340,10 @@ class BaseChat(ABC): message_type = round_message["type"] message_content = round_message["data"]["content"] example_text += ( - message_type - + ":" - + message_content - + self.prompt_template.sep + message_type + + ":" + + message_content + + self.prompt_template.sep ) example_messages.append( ModelMessage(role=message_type, content=message_content) @@ -358,10 +367,10 @@ class BaseChat(ABC): message_type = first_message["type"] message_content = first_message["data"]["content"] history_text += ( - message_type - + ":" - + message_content - + self.prompt_template.sep + message_type + + ":" + + message_content + + self.prompt_template.sep ) history_messages.append( ModelMessage(role=message_type, content=message_content) @@ -377,10 +386,10 @@ class BaseChat(ABC): message_type = round_message["type"] message_content = round_message["data"]["content"] history_text += ( - message_type - + ":" - + message_content - + self.prompt_template.sep + message_type + + ":" + + message_content + + self.prompt_template.sep ) history_messages.append( ModelMessage(role=message_type, content=message_content) @@ -398,10 +407,10 @@ class BaseChat(ABC): message_type = message["type"] message_content = message["data"]["content"] history_text += ( - message_type - + ":" - + message_content - + self.prompt_template.sep + message_type + + ":" + + message_content + + self.prompt_template.sep ) history_messages.append( ModelMessage(role=message_type, content=message_content) diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py index 09ea9695d..75c0aad97 100644 --- a/pilot/scene/base_message.py +++ b/pilot/scene/base_message.py @@ -95,6 +95,7 @@ class ModelMessageRoleType: SYSTEM = "system" HUMAN = "human" AI = "ai" + VIEW = "view" class Generation(BaseModel): @@ -166,7 +167,7 @@ def messages_from_dict(messages: List[dict]) -> List[BaseMessage]: def _parse_model_messages( - messages: List[ModelMessage], + messages: List[ModelMessage], ) -> Tuple[str, List[str], List[List[str, str]]]: """ " Parameters: diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 79728493c..ed1c22754 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -1,26 +1,16 @@ import json import os import uuid -from typing import Dict, NamedTuple, List -from decimal import Decimal +from typing import List -from pilot.scene.base_message import ( - HumanMessage, - ViewMessage, -) from pilot.scene.base_chat import BaseChat from pilot.scene.base import ChatScene -from pilot.common.sql_database import Database from pilot.configs.config import Config -from pilot.common.markdown_text import ( - generate_htm_table, -) -from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.data_preparation.report_schma import ( ChartData, ReportData, - ValueItem, ) +from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader CFG = Config() @@ -85,42 +75,10 @@ class ChatDashboard(BaseChat): def do_action(self, prompt_response): ### TODO 记录整体信息,处理成功的,和未成功的分开记录处理 chart_datas: List[ChartData] = [] + dashboard_data_loader = DashboardDataLoader() for chart_item in prompt_response: try: - field_names, datas = self.database.query_ex( - self.db_connect, chart_item.sql - ) - values: List[ValueItem] = [] - data_map = {} - field_map = {} - index = 0 - for field_name in field_names: - data_map.update({f"{field_name}": [row[index] for row in datas]}) - index += 1 - if not data_map[field_name]: - field_map.update({f"{field_name}": False}) - else: - field_map.update( - { - f"{field_name}": all( - isinstance(item, (int, float, Decimal)) - for item in data_map[field_name] - ) - } - ) - - for field_name in field_names[1:]: - if not field_map[field_name]: - print("more than 2 non-numeric column") - else: - for data in datas: - value_item = ValueItem( - name=data[0], - type=field_name, - value=data[field_names.index(field_name)], - ) - values.append(value_item) - + field_names, values = dashboard_data_loader.get_chart_values_by_conn(self.db_connect, chart_item.sql) chart_datas.append( ChartData( chart_uid=str(uuid.uuid1()), @@ -135,7 +93,6 @@ class ChatDashboard(BaseChat): except Exception as e: # TODO 修复流程 print(str(e)) - return ReportData( conv_uid=self.chat_session_id, template_name=self.report_name, diff --git a/pilot/scene/chat_dashboard/data_loader.py b/pilot/scene/chat_dashboard/data_loader.py new file mode 100644 index 000000000..efb09a6ac --- /dev/null +++ b/pilot/scene/chat_dashboard/data_loader.py @@ -0,0 +1,57 @@ +from typing import List +from decimal import Decimal + +from pilot.configs.config import Config +from pilot.configs.model_config import LOGDIR +from pilot.utils import build_logger +from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem + +CFG = Config() +logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log") + + +class DashboardDataLoader: + + def get_chart_values_by_conn(self, db_conn, chart_sql: str) : + logger.info(f"get_chart_values_by_conn:{chart_sql}") + try: + field_names, datas = db_conn.query_ex(chart_sql) + values: List[ValueItem] = [] + data_map = {} + field_map = {} + index = 0 + for field_name in field_names: + data_map.update({f"{field_name}": [row[index] for row in datas]}) + index += 1 + if not data_map[field_name]: + field_map.update({f"{field_name}": False}) + else: + field_map.update( + { + f"{field_name}": all( + isinstance(item, (int, float, Decimal)) + for item in data_map[field_name] + ) + } + ) + + for field_name in field_names[1:]: + if not field_map[field_name]: + logger.info("More than 2 non-numeric column:" + field_name) + else: + for data in datas: + value_item = ValueItem( + name=data[0], + type=field_name, + value=data[field_names.index(field_name)], + ) + values.append(value_item) + return field_names, values + except Exception as e: + logger.debug("Prepare Chart Data Faild!" + str(e)) + raise ValueError("Prepare Chart Data Faild!") + + def get_chart_values_by_db(self, db_name: str, chart_sql: str) : + logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}") + db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) + return self.get_chart_values_by_conn(db_conn, chart_sql) diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 57867532b..ca0604f66 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -22,11 +22,13 @@ class ChatWithDbAutoExecute(BaseChat): """Number of results to return from the query""" def __init__(self, chat_session_id, db_name, user_input): + chat_mode = ChatScene.ChatWithDbExecute """ """ super().__init__( - chat_mode=ChatScene.ChatWithDbExecute, + chat_mode=chat_mode, chat_session_id=chat_session_id, current_user_input=user_input, + select_param=db_name, ) if not db_name: raise ValueError( @@ -35,7 +37,7 @@ class ChatWithDbAutoExecute(BaseChat): self.db_name = db_name self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name) self.db_connect = self.database.session - self.top_k: int = 5 + self.top_k: int = 200 def generate_input_values(self): try: diff --git a/pilot/scene/chat_db/auto_execute/data_loader.py b/pilot/scene/chat_db/auto_execute/data_loader.py new file mode 100644 index 000000000..e36289d9f --- /dev/null +++ b/pilot/scene/chat_db/auto_execute/data_loader.py @@ -0,0 +1,8 @@ +import json + + +class DbDataLoader: + + + def get_table_view_by_conn(self, db_conn, chart_sql: str): + pass \ No newline at end of file diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index 706cbee40..285629913 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -25,6 +25,7 @@ class ChatWithDbQA(BaseChat): chat_mode=ChatScene.ChatWithDbQA, chat_session_id=chat_session_id, current_user_input=user_input, + select_param=db_name, ) self.db_name = db_name if db_name: diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index da370bf18..78774bd48 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -30,6 +30,7 @@ class ChatWithPlugin(BaseChat): chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input, + select_param=plugin_selector, ) self.plugins_prompt_generator = PluginPromptGenerator() self.plugins_prompt_generator.command_registry = CFG.command_registry diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index bc121a7c1..c1ea61f92 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -33,6 +33,7 @@ class ChatNewKnowledge(BaseChat): chat_mode=ChatScene.ChatNewKnowledge, chat_session_id=chat_session_id, current_user_input=user_input, + select_param=knowledge_name, ) self.knowledge_name = knowledge_name vector_store_config = { diff --git a/pilot/scene/chat_knowledge/inner_db_summary/chat.py b/pilot/scene/chat_knowledge/inner_db_summary/chat.py index 4a952e6cc..34c8260e3 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/chat.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/chat.py @@ -24,6 +24,7 @@ class InnerChatDBSummary(BaseChat): chat_mode=ChatScene.InnerChatDBSummary, chat_session_id=chat_session_id, current_user_input=user_input, + select_param=db_select, ) self.db_input = db_select diff --git a/pilot/scene/message.py b/pilot/scene/message.py index 51ec2643e..ea74bae2f 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -30,6 +30,8 @@ class OnceConversation: self.messages: List[BaseMessage] = [] self.start_date: str = "" self.chat_order: int = 0 + self.param_type: str = "" + self.param_value: str = "" self.cost: int = 0 self.tokens: int = 0 @@ -114,9 +116,13 @@ def _conversation_to_dic(once: OnceConversation) -> dict: "cost": once.cost if once.cost else 0, "tokens": once.tokens if once.tokens else 0, "messages": messages_to_dict(once.messages), + "param_type": once.param_type, + "param_value": once.param_value } + + def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]: return [_conversation_to_dic(m) for m in conversations] @@ -128,6 +134,8 @@ def conversation_from_dict(once: dict) -> OnceConversation: conversation.tokens = once.get("tokens", 0) conversation.start_date = once.get("start_date", "") conversation.chat_order = int(once.get("chat_order")) + conversation.param_type = once.get("param_type", "") + conversation.param_value = once.get("param_value", "") print(once.get("messages")) conversation.messages = messages_from_dict(once.get("messages", [])) return conversation diff --git a/pilot/utils.py b/pilot/utils.py index b15d3af21..c44a4ea2d 100644 --- a/pilot/utils.py +++ b/pilot/utils.py @@ -51,19 +51,15 @@ def build_logger(logger_name, logger_filename): logging.getLogger().handlers[0].setFormatter(formatter) # Redirect stdout and stderr to loggers - stdout_logger = logging.getLogger("stdout") - stdout_logger.setLevel(logging.INFO) - sl = StreamToLogger(stdout_logger, logging.INFO) - sys.stdout = sl - - stderr_logger = logging.getLogger("stderr") - stderr_logger.setLevel(logging.ERROR) - sl = StreamToLogger(stderr_logger, logging.ERROR) - sys.stderr = sl - - # Get logger - logger = logging.getLogger(logger_name) - logger.setLevel(logging.INFO) + # stdout_logger = logging.getLogger("stdout") + # stdout_logger.setLevel(logging.INFO) + # sl_1 = StreamToLogger(stdout_logger, logging.INFO) + # sys.stdout = sl_1 + # + # stderr_logger = logging.getLogger("stderr") + # stderr_logger.setLevel(logging.ERROR) + # sl = StreamToLogger(stderr_logger, logging.ERROR) + # sys.stderr = sl # Add a file handler for all loggers if handler is None: @@ -78,6 +74,12 @@ def build_logger(logger_name, logger_filename): if isinstance(item, logging.Logger): item.addHandler(handler) logging.basicConfig(level=logging.INFO) + + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + return logger