From 3f7cc024268a64ef1306a9701e66e4db3f7fdc58 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Tue, 27 Jun 2023 15:35:18 +0800 Subject: [PATCH] WEB API independent --- pilot/common/formatting.py | 11 ++ pilot/connections/rdbms/py_study/pd_study.py | 43 +++-- .../connections/rdbms/py_study/test_cls_1.py | 2 +- .../connections/rdbms/py_study/test_cls_2.py | 2 +- .../rdbms/py_study/test_cls_base.py | 5 +- .../connections/rdbms/py_study/test_duckdb.py | 13 ++ pilot/memory/chat_history/base.py | 6 + pilot/memory/chat_history/duckdb_history.py | 109 +++++++++++ pilot/memory/chat_history/mem_history.py | 5 +- pilot/mock_datas/chat_history.db | Bin 0 -> 12288 bytes pilot/mock_datas/chat_history.db.wal | 0 pilot/scene/base.py | 17 +- pilot/scene/base_chat.py | 24 +-- pilot/scene/chat_dashboard/__init__.py | 0 pilot/scene/chat_dashboard/chat.py | 81 ++++++++ .../data_preparation/__init__.py | 0 .../data_preparation/report_schma.py | 22 +++ pilot/scene/chat_dashboard/out_parser.py | 41 ++++ pilot/scene/chat_dashboard/prompt.py | 49 +++++ .../template/sales_report/dashboard.json | 9 + pilot/scene/chat_db/auto_execute/chat.py | 6 +- pilot/scene/chat_db/professional_qa/chat.py | 7 +- pilot/scene/chat_execution/chat.py | 6 +- pilot/scene/chat_knowledge/custom/chat.py | 6 +- pilot/scene/chat_knowledge/default/chat.py | 13 +- pilot/scene/chat_knowledge/default/prompt.py | 2 +- .../chat_knowledge/inner_db_summary/chat.py | 6 - pilot/scene/chat_knowledge/url/chat.py | 6 +- pilot/scene/chat_normal/chat.py | 6 +- pilot/scene/message.py | 32 +++- pilot/server/api_v1/api_v1.py | 179 ++++++++++++++---- pilot/server/api_v1/api_view_model.py | 38 ++-- pilot/server/webserver.py | 119 +++++++----- 33 files changed, 683 insertions(+), 182 deletions(-) create mode 100644 pilot/connections/rdbms/py_study/test_duckdb.py create mode 100644 pilot/memory/chat_history/duckdb_history.py create mode 100644 pilot/mock_datas/chat_history.db create mode 100644 pilot/mock_datas/chat_history.db.wal create mode 100644 pilot/scene/chat_dashboard/__init__.py create mode 100644 pilot/scene/chat_dashboard/chat.py create mode 100644 pilot/scene/chat_dashboard/data_preparation/__init__.py create mode 100644 pilot/scene/chat_dashboard/data_preparation/report_schma.py create mode 100644 pilot/scene/chat_dashboard/out_parser.py create mode 100644 pilot/scene/chat_dashboard/prompt.py create mode 100644 pilot/scene/chat_dashboard/template/sales_report/dashboard.json diff --git a/pilot/common/formatting.py b/pilot/common/formatting.py index 3b3b597b0..c2db4126a 100644 --- a/pilot/common/formatting.py +++ b/pilot/common/formatting.py @@ -1,4 +1,5 @@ """Utilities for formatting strings.""" +import json from string import Formatter from typing import Any, List, Mapping, Sequence, Union @@ -36,3 +37,13 @@ class StrictFormatter(Formatter): formatter = StrictFormatter() + + +class MyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, set): + return list(obj) + elif hasattr(obj, '__dict__'): + return obj.__dict__ + else: + return json.JSONEncoder.default(self, obj) \ No newline at end of file diff --git a/pilot/connections/rdbms/py_study/pd_study.py b/pilot/connections/rdbms/py_study/pd_study.py index 5a2b3edae..5ad5be08f 100644 --- a/pilot/connections/rdbms/py_study/pd_study.py +++ b/pilot/connections/rdbms/py_study/pd_study.py @@ -6,6 +6,8 @@ import numpy as np from matplotlib.font_manager import FontProperties from pyecharts.charts import Bar from pyecharts import options as opts +from test_cls_1 import TestBase,Test1 +from test_cls_2 import Test2 CFG = Config() @@ -56,20 +58,29 @@ CFG = Config() # +# if __name__ == "__main__": + + # def __extract_json(s): + # i = s.index("{") + # count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 + # for j, c in enumerate(s[i + 1 :], start=i + 1): + # if c == "}": + # count -= 1 + # elif c == "{": + # count += 1 + # if count == 0: + # break + # assert count == 0 # 检查是否找到最后一个'}' + # return s[i : j + 1] + # + # ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" + # print(__extract_json(ss)) + if __name__ == "__main__": - - def __extract_json(s): - i = s.index("{") - count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1 :], start=i + 1): - if c == "}": - count -= 1 - elif c == "{": - count += 1 - if count == 0: - break - assert count == 0 # 检查是否找到最后一个'}' - return s[i : j + 1] - - ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" - print(__extract_json(ss)) + test1 = Test1() + test2 = Test2() + test1.write() + test1.test() + test2.write() + test1.test() + test2.test() \ No newline at end of file diff --git a/pilot/connections/rdbms/py_study/test_cls_1.py b/pilot/connections/rdbms/py_study/test_cls_1.py index 66c07de78..c7d26a674 100644 --- a/pilot/connections/rdbms/py_study/test_cls_1.py +++ b/pilot/connections/rdbms/py_study/test_cls_1.py @@ -4,7 +4,7 @@ from test_cls_base import TestBase class Test1(TestBase): - + mode:str = "456" def write(self): self.test_values.append("x") self.test_values.append("y") diff --git a/pilot/connections/rdbms/py_study/test_cls_2.py b/pilot/connections/rdbms/py_study/test_cls_2.py index c0fdbb305..e911f0542 100644 --- a/pilot/connections/rdbms/py_study/test_cls_2.py +++ b/pilot/connections/rdbms/py_study/test_cls_2.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union class Test2(TestBase): test_2_values:List = [] - + mode:str = "789" def write(self): self.test_values.append(1) self.test_values.append(2) diff --git a/pilot/connections/rdbms/py_study/test_cls_base.py b/pilot/connections/rdbms/py_study/test_cls_base.py index 9a04a48b3..b8377c73d 100644 --- a/pilot/connections/rdbms/py_study/test_cls_base.py +++ b/pilot/connections/rdbms/py_study/test_cls_base.py @@ -5,8 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union class TestBase(BaseModel, ABC): test_values: List = [] - + mode:str = "123" def test(self): print(self.__class__.__name__ + ":" ) - print(self.test_values) \ No newline at end of file + print(self.test_values) + print(self.mode) \ No newline at end of file diff --git a/pilot/connections/rdbms/py_study/test_duckdb.py b/pilot/connections/rdbms/py_study/test_duckdb.py new file mode 100644 index 000000000..dbcf2ecb7 --- /dev/null +++ b/pilot/connections/rdbms/py_study/test_duckdb.py @@ -0,0 +1,13 @@ +import json +import os +import duckdb + +default_db_path = os.path.join(os.getcwd(), "message") +duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") + +if __name__ == "__main__": + if os.path.isfile(duckdb_path): + cursor = duckdb.connect(duckdb_path).cursor() + cursor.execute("SELECT * FROM chat_history limit 20") + data = cursor.fetchall() + print(data) diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py index 8d60eafe7..649afa4ab 100644 --- a/pilot/memory/chat_history/base.py +++ b/pilot/memory/chat_history/base.py @@ -32,3 +32,9 @@ class BaseChatHistoryMemory(ABC): @abstractmethod def clear(self) -> None: """Clear session memory from the local file""" + + + def conv_list(self, user_name:str=None) -> None: + """get user's conversation list""" + pass + diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py new file mode 100644 index 000000000..b24546d19 --- /dev/null +++ b/pilot/memory/chat_history/duckdb_history.py @@ -0,0 +1,109 @@ +import json +import os +import duckdb +from typing import List + +from pilot.configs.config import Config +from pilot.memory.chat_history.base import BaseChatHistoryMemory +from pilot.scene.message import ( + OnceConversation, + conversation_from_dict, + conversations_to_dict, +) +from pilot.common.formatting import MyEncoder + +default_db_path = os.path.join(os.getcwd(), "message") +duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") +table_name = 'chat_history' + +CFG = Config() + + +class DuckdbHistoryMemory(BaseChatHistoryMemory): + + def __init__(self, chat_session_id: str): + self.chat_seesion_id = chat_session_id + os.makedirs(default_db_path, exist_ok=True) + self.connect = duckdb.connect(duckdb_path) + self.__init_chat_history_tables() + + def __init_chat_history_tables(self): + + # 检查表是否存在 + result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", + [table_name]).fetchall() + + if not result: + # 如果表不存在,则创建新表 + self.connect.execute( + "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)") + + def __get_messages_by_conv_uid(self, conv_uid: str): + cursor = self.connect.cursor() + cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid]) + return cursor.fetchone() + + def messages(self) -> List[OnceConversation]: + context = self.__get_messages_by_conv_uid(self.chat_seesion_id) + if context: + conversations: List[OnceConversation] = json.loads(context[0]) + return conversations + return [] + + def append(self, once_message: OnceConversation) -> None: + context = self.__get_messages_by_conv_uid(self.chat_seesion_id) + conversations: List[OnceConversation] = [] + if context: + conversations = json.load(context) + conversations.append(once_message) + cursor = self.connect.cursor() + if context: + cursor.execute("UPDATE chat_history set messages=? where conv_uid=?", + [json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id]) + else: + cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", + [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)]) + cursor.commit() + self.connect.commit() + + def clear(self) -> None: + cursor = self.connect.cursor() + cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.commit() + self.connect.commit() + + def delete(self) -> bool: + cursor = self.connect.cursor() + cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.commit() + return True + + @staticmethod + def conv_list(cls, user_name: str = None) -> None: + if os.path.isfile(duckdb_path): + cursor = duckdb.connect(duckdb_path).cursor() + if user_name: + cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name]) + else: + cursor.execute("SELECT * FROM chat_history limit 20") + # 获取查询结果字段名 + fields = [field[0] for field in cursor.description] + data = [] + for row in cursor.fetchall(): + row_dict = {} + for i, field in enumerate(fields): + row_dict[field] = row[i] + data.append(row_dict) + + return data + + return [] + + + def get_messages(self)-> List[OnceConversation]: + cursor = self.connect.cursor() + cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + context = cursor.fetchone() + if context: + return json.loads(context[0]) + return None diff --git a/pilot/memory/chat_history/mem_history.py b/pilot/memory/chat_history/mem_history.py index 8ff0d08eb..a46428e75 100644 --- a/pilot/memory/chat_history/mem_history.py +++ b/pilot/memory/chat_history/mem_history.py @@ -11,13 +11,14 @@ from pilot.scene.message import ( conversation_from_dict, conversations_to_dict, ) - +from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList CFG = Config() class MemHistoryMemory(BaseChatHistoryMemory): - histroies_map = {} + histroies_map = FixedSizeDict(100) + def __init__(self, chat_session_id: str): self.chat_seesion_id = chat_session_id diff --git a/pilot/mock_datas/chat_history.db b/pilot/mock_datas/chat_history.db new file mode 100644 index 0000000000000000000000000000000000000000..929805035e1c69ca980b5e0750d0d4b7cc3c7f3e GIT binary patch literal 12288 zcmeI#p$&jA5Cu?3V1#;hU<7bD5+-0_W+1i903=|FLtgUE{JYfjURCRDU-L1iaT%t* zdaAjjdwW5E009C72oNAZfB*pk1PH_zXj8ev`Kj{MM1TMR0t5&UAV7cs0RjXFL=^D< ZkN9ftOn?9Z0t5&UAV7cs0Rja630#2eCE)-7 literal 0 HcmV?d00001 diff --git a/pilot/mock_datas/chat_history.db.wal b/pilot/mock_datas/chat_history.db.wal new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/base.py b/pilot/scene/base.py index e301a14de..cec443beb 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -1,12 +1,27 @@ from enum import Enum +class Scene: + def __init__(self, code, describe, is_inner): + self.code = code + self.describe = describe + self.is_inner = is_inner class ChatScene(Enum): ChatWithDbExecute = "chat_with_db_execute" ChatWithDbQA = "chat_with_db_qa" ChatExecution = "chat_execution" - ChatKnowledge = "chat_default_knowledge" + ChatDefaultKnowledge = "chat_default_knowledge" ChatNewKnowledge = "chat_new_knowledge" ChatUrlKnowledge = "chat_url_knowledge" InnerChatDBSummary = "inner_chat_db_summary" + ChatNormal = "chat_normal" + ChatDashboard = "chat_dashboard" + ChatKnowledge = "chat_knowledge" + ChatDb = "chat_db" + ChatData= "chat_data" + + @staticmethod + def is_valid_mode(mode): + return any(mode == item.value for item in ChatScene) + diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 0120b9e86..d6628a19a 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -24,6 +24,7 @@ from pilot.prompts.prompt_new import PromptTemplate from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory from pilot.memory.chat_history.mem_history import MemHistoryMemory +from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.configs.model_config import LOGDIR, DATASETS_DIR from pilot.utils import ( @@ -59,8 +60,6 @@ class BaseChat(ABC): def __init__( self, - temperature, - max_new_tokens, chat_mode, chat_session_id, current_user_input, @@ -70,17 +69,15 @@ class BaseChat(ABC): self.current_user_input: str = current_user_input self.llm_model = CFG.LLM_MODEL ### can configurable storage methods - self.memory = MemHistoryMemory(chat_session_id) + self.memory = DuckdbHistoryMemory(chat_session_id) ### load prompt template self.prompt_template: PromptTemplate = CFG.prompt_templates[ self.chat_mode.value ] self.history_message: List[OnceConversation] = [] - self.current_message: OnceConversation = OnceConversation() + self.current_message: OnceConversation = OnceConversation(chat_mode.value) self.current_tokens_used: int = 0 - self.temperature = temperature - self.max_new_tokens = max_new_tokens ### load chat_session_id's chat historys self._load_history(self.chat_session_id) @@ -99,15 +96,15 @@ class BaseChat(ABC): pass @abstractmethod - def do_with_prompt_response(self, prompt_response): - pass + def do_action(self, prompt_response): + return prompt_response def __call_base(self): input_values = self.generate_input_values() ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) - self.current_message.start_date = datetime.datetime.now() + self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") # TODO self.current_message.tokens = 0 current_prompt = None @@ -203,13 +200,10 @@ class BaseChat(ABC): # }""" self.current_message.add_ai_message(ai_response_text) - prompt_define_response = ( - self.prompt_template.output_parser.parse_prompt_response( - ai_response_text - ) - ) + prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) - result = self.do_with_prompt_response(prompt_define_response) + + result = self.do_action(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): if isinstance(prompt_define_response.thoughts, dict): diff --git a/pilot/scene/chat_dashboard/__init__.py b/pilot/scene/chat_dashboard/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py new file mode 100644 index 000000000..19a59c39a --- /dev/null +++ b/pilot/scene/chat_dashboard/chat.py @@ -0,0 +1,81 @@ +import json +from typing import Dict, NamedTuple, List +from pilot.scene.base_message import ( + HumanMessage, + ViewMessage, +) +from pilot.scene.base_chat import BaseChat +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config +from pilot.common.markdown_text import ( + generate_htm_table, +) +from pilot.scene.chat_db.auto_execute.prompt import prompt +from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData + +CFG = Config() + + +class ChatDashboard(BaseChat): + chat_scene: str = ChatScene.ChatDashboard.value + report_name: str + """Number of results to return from the query""" + + def __init__( + self, chat_session_id, db_name, user_input, report_name + ): + """ """ + super().__init__( + chat_mode=ChatScene.ChatWithDbExecute, + chat_session_id=chat_session_id, + current_user_input=user_input, + ) + if not db_name: + raise ValueError( + f"{ChatScene.ChatWithDbExecute.value} mode should chose db!" + ) + self.report_name = report_name + self.database = CFG.local_db + # 准备DB信息(拿到指定库的链接) + self.db_connect = self.database.get_session(self.db_name) + self.top_k: int = 5 + + def generate_input_values(self): + try: + from pilot.summary.db_summary_client import DBSummaryClient + except ImportError: + raise ValueError("Could not import DBSummaryClient. ") + client = DBSummaryClient() + input_values = { + "input": self.current_user_input, + "dialect": self.database.dialect, + "table_info": self.database.table_simple_info(self.db_connect), + "supported_chat_type": "" #TODO + # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) + } + return input_values + + def do_action(self, prompt_response): + ### TODO 记录整体信息,处理成功的,和未成功的分开记录处理 + report_data: ReportData = ReportData() + chart_datas: List[ChartData] = [] + for chart_item in prompt_response: + try: + datas = self.database.run(self.db_connect, chart_item.sql) + chart_data: ChartData = ChartData() + except Exception as e: + # TODO 修复流程 + print(str(e)) + + + chart_datas.append(chart_data) + + report_data.conv_uid = self.chat_session_id + report_data.template_name = self.report_name + report_data.template_introduce = None + report_data.charts = chart_datas + + return report_data + + diff --git a/pilot/scene/chat_dashboard/data_preparation/__init__.py b/pilot/scene/chat_dashboard/data_preparation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_dashboard/data_preparation/report_schma.py b/pilot/scene/chat_dashboard/data_preparation/report_schma.py new file mode 100644 index 000000000..9323aacc3 --- /dev/null +++ b/pilot/scene/chat_dashboard/data_preparation/report_schma.py @@ -0,0 +1,22 @@ +from pydantic import BaseModel, Field +from typing import TypeVar, Union, List, Generic, Any + + +class ChartData(BaseModel): + chart_uid: str + chart_type: str + chart_sql: str + column_name: List + values: List + style: Any + + +class ReportData(BaseModel): + conv_uid:str + template_name:str + template_introduce:str + charts: List[ChartData] + + + + diff --git a/pilot/scene/chat_dashboard/out_parser.py b/pilot/scene/chat_dashboard/out_parser.py new file mode 100644 index 000000000..975196978 --- /dev/null +++ b/pilot/scene/chat_dashboard/out_parser.py @@ -0,0 +1,41 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple, List +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +class ChartItem(NamedTuple): + sql: str + title:str + thoughts: str + showcase:str + + +logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log") + + +class ChatDashboardOutputParser(BaseOutputParser): + def __init__(self, sep: str, is_stream_out: bool): + super().__init__(sep=sep, is_stream_out=is_stream_out) + + def parse_prompt_response(self, model_out_text): + clean_str = super().parse_prompt_response(model_out_text) + print("clean prompt response:", clean_str) + response = json.loads(clean_str) + chart_items = List[ChartItem] + for item in response: + chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"])) + return chart_items + + def parse_view_response(self, speak, data) -> str: + ### TODO + + return data + + @property + def _type(self) -> str: + return "chat_dashboard" diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py new file mode 100644 index 000000000..e44144d4d --- /dev/null +++ b/pilot/scene/chat_dashboard/prompt.py @@ -0,0 +1,49 @@ +import json +import importlib +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction +from pilot.common.schema import SeparatorStyle + +CFG = Config() + +PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations""" +PROMPT_SCENE_DEFINE = None + +_DEFAULT_TEMPLATE = """ +According to the structure definition in the following tables: +{table_info} +Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions. +Used to support goal: {input} + +Use the chart display method in the following range: +{supported_chat_type} +give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format: +{response} +Ensure the response is correct json and can be parsed by Python json.loads +""" + +RESPONSE_FORMAT = [{ + "sql": "data analysis SQL", + "title": "Data Analysis Title", + "showcase": "What type of charts to show", + "thoughts": "Current thinking and value of data analysis" +}] + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = False + +prompt = PromptTemplate( + template_scene=ChatScene.ChatWithDbExecute.value, + input_variables=["input", "table_info", "dialect", "supported_chat_type"], + response_format=json.dumps(RESPONSE_FORMAT, indent=4), + template_define=PROMPT_SCENE_DEFINE, + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=DbChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) +CFG.prompt_templates.update({prompt.template_scene: prompt}) diff --git a/pilot/scene/chat_dashboard/template/sales_report/dashboard.json b/pilot/scene/chat_dashboard/template/sales_report/dashboard.json new file mode 100644 index 000000000..f08142122 --- /dev/null +++ b/pilot/scene/chat_dashboard/template/sales_report/dashboard.json @@ -0,0 +1,9 @@ +{ + "title": "Sales Report", + "name": "sale_report", + "introduce": "", + "layout": "TODO", + "supported_chart_type":["HeatMap","sheet", "LineChart", "PieChart", "BarChart"], + "key_metrics":[], + "trends": [] +} \ No newline at end of file diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 73c732713..0a87f51f6 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -22,12 +22,10 @@ class ChatWithDbAutoExecute(BaseChat): """Number of results to return from the query""" def __init__( - self, temperature, max_new_tokens, chat_session_id, db_name, user_input + self, chat_session_id, db_name, user_input ): """ """ super().__init__( - temperature=temperature, - max_new_tokens=max_new_tokens, chat_mode=ChatScene.ChatWithDbExecute, chat_session_id=chat_session_id, current_user_input=user_input, @@ -57,5 +55,5 @@ class ChatWithDbAutoExecute(BaseChat): } return input_values - def do_with_prompt_response(self, prompt_response): + def do_action(self, prompt_response): return self.database.run(self.db_connect, prompt_response.sql) diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index e956bdc8b..3d6cd0db4 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -20,12 +20,10 @@ class ChatWithDbQA(BaseChat): """Number of results to return from the query""" def __init__( - self, temperature, max_new_tokens, chat_session_id, db_name, user_input + self, chat_session_id, db_name, user_input ): """ """ super().__init__( - temperature=temperature, - max_new_tokens=max_new_tokens, chat_mode=ChatScene.ChatWithDbQA, chat_session_id=chat_session_id, current_user_input=user_input, @@ -66,5 +64,4 @@ class ChatWithDbQA(BaseChat): } return input_values - def do_with_prompt_response(self, prompt_response): - return prompt_response + diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index 1dcb4c6ed..97646c299 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -22,15 +22,11 @@ class ChatWithPlugin(BaseChat): def __init__( self, - temperature, - max_new_tokens, chat_session_id, user_input, plugin_selector: str = None, ): super().__init__( - temperature=temperature, - max_new_tokens=max_new_tokens, chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input, @@ -66,7 +62,7 @@ class ChatWithPlugin(BaseChat): } return input_values - def do_with_prompt_response(self, prompt_response): + def do_action(self, prompt_response): ## plugin command run return execute_command( str(prompt_response.command.get("name")), diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index 85d48a657..9283f277b 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -30,12 +30,10 @@ class ChatNewKnowledge(BaseChat): """Number of results to return from the query""" def __init__( - self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name + self, chat_session_id, user_input, knowledge_name ): """ """ super().__init__( - temperature=temperature, - max_new_tokens=max_new_tokens, chat_mode=ChatScene.ChatNewKnowledge, chat_session_id=chat_session_id, current_user_input=user_input, @@ -67,8 +65,6 @@ class ChatNewKnowledge(BaseChat): return input_values - def do_with_prompt_response(self, prompt_response): - return prompt_response @property def chat_type(self) -> str: diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py index 838ff834c..8052de910 100644 --- a/pilot/scene/chat_knowledge/default/chat.py +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -25,16 +25,14 @@ CFG = Config() class ChatDefaultKnowledge(BaseChat): - chat_scene: str = ChatScene.ChatKnowledge.value + chat_scene: str = ChatScene.ChatDefaultKnowledge.value """Number of results to return from the query""" - def __init__(self, temperature, max_new_tokens, chat_session_id, user_input): + def __init__(self, chat_session_id, user_input): """ """ super().__init__( - temperature=temperature, - max_new_tokens=max_new_tokens, - chat_mode=ChatScene.ChatKnowledge, + chat_mode=ChatScene.ChatDefaultKnowledge, chat_session_id=chat_session_id, current_user_input=user_input, ) @@ -61,9 +59,8 @@ class ChatDefaultKnowledge(BaseChat): ) return input_values - def do_with_prompt_response(self, prompt_response): - return prompt_response + @property def chat_type(self) -> str: - return ChatScene.ChatKnowledge.value + return ChatScene.ChatDefaultKnowledge.value diff --git a/pilot/scene/chat_knowledge/default/prompt.py b/pilot/scene/chat_knowledge/default/prompt.py index 0fd9f9ff3..760686366 100644 --- a/pilot/scene/chat_knowledge/default/prompt.py +++ b/pilot/scene/chat_knowledge/default/prompt.py @@ -39,7 +39,7 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = True prompt = PromptTemplate( - template_scene=ChatScene.ChatKnowledge.value, + template_scene=ChatScene.ChatDefaultKnowledge.value, input_variables=["context", "question"], response_format=None, template_define=PROMPT_SCENE_DEFINE, diff --git a/pilot/scene/chat_knowledge/inner_db_summary/chat.py b/pilot/scene/chat_knowledge/inner_db_summary/chat.py index e149f4a1b..b4dcc536f 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/chat.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/chat.py @@ -14,8 +14,6 @@ class InnerChatDBSummary(BaseChat): def __init__( self, - temperature, - max_new_tokens, chat_session_id, user_input, db_select, @@ -23,8 +21,6 @@ class InnerChatDBSummary(BaseChat): ): """ """ super().__init__( - temperature=temperature, - max_new_tokens=max_new_tokens, chat_mode=ChatScene.InnerChatDBSummary, chat_session_id=chat_session_id, current_user_input=user_input, @@ -40,8 +36,6 @@ class InnerChatDBSummary(BaseChat): } return input_values - def do_with_prompt_response(self, prompt_response): - return prompt_response @property def chat_type(self) -> str: diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index 57fb8b618..433a64bd8 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -27,11 +27,9 @@ class ChatUrlKnowledge(BaseChat): """Number of results to return from the query""" - def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url): + def __init__(self, chat_session_id, user_input, url): """ """ super().__init__( - temperature=temperature, - max_new_tokens=max_new_tokens, chat_mode=ChatScene.ChatUrlKnowledge, chat_session_id=chat_session_id, current_user_input=user_input, @@ -62,8 +60,6 @@ class ChatUrlKnowledge(BaseChat): input_values = {"context": context, "question": self.current_user_input} return input_values - def do_with_prompt_response(self, prompt_response): - return prompt_response @property def chat_type(self) -> str: diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py index f4ba94320..a183020c5 100644 --- a/pilot/scene/chat_normal/chat.py +++ b/pilot/scene/chat_normal/chat.py @@ -18,11 +18,9 @@ class ChatNormal(BaseChat): """Number of results to return from the query""" - def __init__(self, temperature, max_new_tokens, chat_session_id, user_input): + def __init__(self, chat_session_id, user_input): """ """ super().__init__( - temperature=temperature, - max_new_tokens=max_new_tokens, chat_mode=ChatScene.ChatNormal, chat_session_id=chat_session_id, current_user_input=user_input, @@ -32,7 +30,7 @@ class ChatNormal(BaseChat): input_values = {"input": self.current_user_input} return input_values - def do_with_prompt_response(self, prompt_response): + def do_action(self, prompt_response): return prompt_response @property diff --git a/pilot/scene/message.py b/pilot/scene/message.py index 0203ec68c..a2a894fe8 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -25,7 +25,8 @@ class OnceConversation: All the information of a conversation, the current single service in memory, can expand cache and database support distributed services """ - def __init__(self): + def __init__(self, chat_mode): + self.chat_mode: str = chat_mode self.messages: List[BaseMessage] = [] self.start_date: str = "" self.chat_order: int = 0 @@ -43,12 +44,28 @@ class OnceConversation: def add_ai_message(self, message: str) -> None: """Add an AI message to the store""" + has_message = any(isinstance(instance, AIMessage) for instance in self.messages) if has_message: - raise ValueError("Already Have Ai message") - self.messages.append(AIMessage(content=message)) + self.__update_ai_message(message) + else: + self.messages.append(AIMessage(content=message)) """ """ + def __update_ai_message(self, new_message: str) -> None: + """ + stream out message update + Args: + new_message: + + Returns: + + """ + + for item in self.messages: + if item.type == "ai": + item.content = new_message + def add_view_message(self, message: str) -> None: """Add an AI message to the store""" @@ -69,6 +86,13 @@ class OnceConversation: self.session_id = None + def get_user_message(self): + for once in self.messages: + if isinstance(once, HumanMessage): + return once.content + return "" + + def _conversation_to_dic(once: OnceConversation) -> dict: start_str: str = "" if once.start_date: @@ -78,6 +102,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict: start_str = once.start_date return { + "chat_mode": once.chat_mode, "chat_order": once.chat_order, "start_date": start_str, "cost": once.cost if once.cost else 0, @@ -93,6 +118,7 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]: def conversation_from_dict(once: dict) -> OnceConversation: conversation = OnceConversation() conversation.cost = once.get("cost", 0) + conversation.chat_mode = once.get("chat_mode", "chat_normal") conversation.tokens = once.get("tokens", 0) conversation.start_date = once.get("start_date", "") conversation.chat_order = int(once.get("chat_order")) diff --git a/pilot/server/api_v1/api_v1.py b/pilot/server/api_v1/api_v1.py index 19f4e765c..b75179eaf 100644 --- a/pilot/server/api_v1/api_v1.py +++ b/pilot/server/api_v1/api_v1.py @@ -1,15 +1,18 @@ import uuid - -from fastapi import APIRouter, Request, Body, status +import json +import asyncio +import time +from fastapi import APIRouter, Request, Body, status, HTTPException, Response from fastapi.responses import JSONResponse from fastapi.responses import StreamingResponse from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse +from sse_starlette.sse import EventSourceResponse from typing import List -from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo +from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo from pilot.configs.config import Config from pilot.scene.base_chat import BaseChat from pilot.scene.base import ChatScene @@ -17,6 +20,8 @@ from pilot.scene.chat_factory import ChatFactory from pilot.configs.model_config import (LOGDIR) from pilot.utils import build_logger from pilot.scene.base_message import (BaseMessage) +from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory +from pilot.scene.message import OnceConversation router = APIRouter() CFG = Config() @@ -28,32 +33,117 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE message = "" for error in exc.errors(): message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";" - return Result.faild(message) + return Result.faild(msg=message) -@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]]) -async def dialogue_list(user_id: str): - #### TODO - - conversations = [ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"), - ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]")] - - return Result[ConversationVo].succ(conversations) +def __get_conv_user_message(conversations: dict): + messages = conversations['messages'] + for item in messages: + if item['type'] == "human": + return item['data']['content'] + return "" -@router.post('/v1/chat/dialogue/new', response_model=Result[str]) -async def dialogue_new(user_id: str): +@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo]) +async def dialogue_list(response: Response, user_id: str = None): + # 设置CORS头部信息 + response.headers['Access-Control-Allow-Origin'] = '*' + response.headers['Access-Control-Allow-Methods'] = 'GET' + response.headers['Access-Control-Request-Headers'] = 'content-type' + + dialogues: List = [] + datas = DuckdbHistoryMemory.conv_list(user_id) + + for item in datas: + conv_uid = item.get("conv_uid") + messages = item.get("messages") + conversations = json.loads(messages) + + first_conv: OnceConversation = conversations[0] + conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv), + chat_mode=first_conv['chat_mode']) + dialogues.append(conv_vo) + + return Result[ConversationVo].succ(dialogues) + + +@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]]) +async def dialogue_scenes(): + scene_vos: List[ChatSceneVo] = [] + new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution] + for scene in new_modes: + if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]: + scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param") + scene_vos.append(scene_vo) + return Result.succ(scene_vos) + + +@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo]) +async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None): unique_id = uuid.uuid1() - return Result.succ(unique_id) + return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)) + + +def get_db_list(): + db = CFG.local_db + dbs = db.get_database_list() + params:dict = {} + for name in dbs: + params.update({name: name}) + return params + + +def plugins_select_info(): + plugins_infos: dict = {} + for plugin in CFG.plugins: + plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name}) + return plugins_infos + + +def knowledge_list(): + knowledge: dict = {} + ### TODO + return knowledge + + +@router.post('/v1/chat/mode/params/list', response_model=Result[dict]) +async def params_list(chat_mode: str = ChatScene.ChatNormal.value): + if ChatScene.ChatDb.value == chat_mode: + return Result.succ(get_db_list()) + elif ChatScene.ChatData.value == chat_mode: + return Result.succ(get_db_list()) + elif ChatScene.ChatDashboard.value == chat_mode: + return Result.succ(get_db_list()) + elif ChatScene.ChatExecution.value == chat_mode: + return Result.succ(plugins_select_info()) + elif ChatScene.ChatKnowledge.value == chat_mode: + return Result.succ(knowledge_list()) + else: + return Result.succ(None) @router.post('/v1/chat/dialogue/delete') -async def dialogue_delete(con_uid: str, user_id: str): - #### TODO +async def dialogue_delete(con_uid: str): + history_mem = DuckdbHistoryMemory(con_uid) + history_mem.delete() return Result.succ(None) -@router.post('/v1/chat/completions', response_model=Result[MessageVo]) +@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo]) +async def dialogue_history_messages(con_uid: str): + print(f"dialogue_history_messages:{con_uid}") + message_vos: List[MessageVo] = [] + + history_mem = DuckdbHistoryMemory(con_uid) + history_messages: List[OnceConversation] = history_mem.get_messages() + if history_messages: + for once in history_messages: + once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']] + message_vos.extend(once_message_vos) + return Result.succ(message_vos) + + +@router.post('/v1/chat/completions') async def chat_completions(dialogue: ConversationVo = Body()): print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}") @@ -65,22 +155,31 @@ async def chat_completions(dialogue: ConversationVo = Body()): "user_input": dialogue.user_input, } - if ChatScene.ChatWithDbExecute == dialogue.chat_mode: + if ChatScene.ChatDb == dialogue.chat_mode: chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatWithDbQA == dialogue.chat_mode: + elif ChatScene.ChatData == dialogue.chat_mode: + chat_param.update("db_name", dialogue.select_param) + elif ChatScene.ChatDashboard == dialogue.chat_mode: chat_param.update("db_name", dialogue.select_param) elif ChatScene.ChatExecution == dialogue.chat_mode: chat_param.update("plugin_selector", dialogue.select_param) - elif ChatScene.ChatNewKnowledge == dialogue.chat_mode: + elif ChatScene.ChatKnowledge == dialogue.chat_mode: chat_param.update("knowledge_name", dialogue.select_param) - elif ChatScene.ChatUrlKnowledge == dialogue.chat_mode: - chat_param.update("url", dialogue.select_param) chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) if not chat.prompt_template.stream_out: return non_stream_response(chat) else: - return stream_response(chat) + # generator = stream_generator(chat) + # result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain')) + # return result + return StreamingResponse(stream_generator(chat), media_type="text/plain") + + +def stream_test(): + for message in ["Hello", "world", "how", "are", "you"]: + yield message + # yield json.dumps(Result.succ(message).__dict__).encode("utf-8") def stream_generator(chat): @@ -89,24 +188,28 @@ def stream_generator(chat): if chunk: msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) chat.current_message.add_ai_message(msg) - messageVos = [message2Vo(element) for element in chat.current_message.messages] - yield Result.succ(messageVos) -def stream_response(chat): - logger.info("stream out start!") - api_response = StreamingResponse(stream_generator(chat), media_type="application/json") - return api_response + yield msg + # chat.current_message.add_ai_message(msg) + # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order) + # json_text = json.dumps(vo.__dict__) + # yield json_text.encode('utf-8') + chat.memory.append(chat.current_message) + + +# def stream_response(chat): +# logger.info("stream out start!") +# api_response = StreamingResponse(stream_generator(chat), media_type="application/json") +# return api_response + + +def message2Vo(message: dict, order) -> MessageVo: + # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0 + return MessageVo(role=message['type'], context=message['data']['content'], order=order) -def message2Vo(message:BaseMessage)->MessageVo: - vo:MessageVo = MessageVo() - vo.role = message.type - vo.role = message.content - vo.time_stamp = message.additional_kwargs.time_stamp if message.additional_kwargs["time_stamp"] else 0 def non_stream_response(chat): logger.info("not stream out, wait model response!") - chat.nostream_call() - messageVos = [message2Vo(element) for element in chat.current_message.messages] - return Result.succ(messageVos) + return chat.nostream_call() @router.get('/v1/db/types', response_model=Result[str]) diff --git a/pilot/server/api_v1/api_view_model.py b/pilot/server/api_v1/api_view_model.py index 938ce22ec..f2d58599b 100644 --- a/pilot/server/api_v1/api_view_model.py +++ b/pilot/server/api_v1/api_view_model.py @@ -1,28 +1,33 @@ from pydantic import BaseModel, Field -from typing import TypeVar, Union, List, Generic +from typing import TypeVar, Union, List, Generic, Any T = TypeVar('T') class Result(Generic[T], BaseModel): success: bool - err_code: str - err_msg: str - data: List[T] + err_code: str = None + err_msg: str = None + data: T = None @classmethod - def succ(cls, data: List[T]): - return Result(True, None, None, data) + def succ(cls, data: T): + return Result(success=True, err_code=None, err_msg=None, data=data) @classmethod def faild(cls, msg): - return Result(True, "E000X", msg, None) + return Result(success=False, err_code="E000X", err_msg=msg, data=None) @classmethod def faild(cls, code, msg): - return Result(True, code, msg, None) + return Result(success=False, err_code=code, err_msg=msg, data=None) +class ChatSceneVo(BaseModel): + chat_scene: str = Field(..., description="chat_scene") + scene_name: str = Field(..., description="chat_scene name show for user") + param_title: str = Field(..., description="chat_scene required parameter title") + class ConversationVo(BaseModel): """ dialogue_uid @@ -31,15 +36,21 @@ class ConversationVo(BaseModel): """ user input """ - user_input: str + user_input: str = "" + """ + user + """ + user_name: str = "" """ the scene of chat """ chat_mode: str = Field(..., description="the scene of chat ") + """ chat scene select param """ - select_param: str + select_param: str = None + class MessageVo(BaseModel): @@ -51,7 +62,12 @@ class MessageVo(BaseModel): current message """ context: str + + """ message postion order """ + order: int + """ time the current message was sent """ - time_stamp: float + time_stamp: Any = None + diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 7cc32bbad..c2cd9d434 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import signal import threading import traceback import argparse @@ -12,12 +11,10 @@ import uuid import gradio as gr - ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) from pilot.summary.db_summary_client import DBSummaryClient -from pilot.commands.command_mange import CommandRegistry from pilot.scene.base_chat import BaseChat @@ -25,8 +22,8 @@ from pilot.configs.config import Config from pilot.configs.model_config import ( DATASETS_DIR, KNOWLEDGE_UPLOAD_ROOT_PATH, - LOGDIR, LLM_MODEL_CONFIG, + LOGDIR, ) from pilot.conversation import ( @@ -35,11 +32,10 @@ from pilot.conversation import ( chat_mode_title, default_conversation, ) -from pilot.common.plugins import scan_plugins, load_native_plugins from pilot.server.gradio_css import code_highlight_css from pilot.server.gradio_patch import Chatbot as grChatbot -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding from pilot.utils import build_logger from pilot.vector_store.extract_tovec import ( get_vector_storelist, @@ -49,6 +45,20 @@ from pilot.vector_store.extract_tovec import ( from pilot.scene.base import ChatScene from pilot.scene.chat_factory import ChatFactory from pilot.language.translation_handler import get_lang_text +from pilot.server.webserver_base import server_init + + +import uvicorn +from fastapi import BackgroundTasks, Request +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from fastapi import FastAPI, applications +from fastapi.openapi.docs import get_swagger_ui_html +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + +from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler # 加载插件 CFG = Config() @@ -95,6 +105,30 @@ knowledge_qa_type_list = [ add_knowledge_base_dialogue, ] +def swagger_monkey_patch(*args, **kwargs): + return get_swagger_ui_html( + *args, **kwargs, + swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js', + swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css' + ) +applications.get_swagger_ui_html = swagger_monkey_patch + +app = FastAPI() +origins = ["*"] + +# 添加跨域中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + +# app.mount("static", StaticFiles(directory="static"), name="static") +app.include_router(api_v1) +app.add_exception_handler(RequestValidationError, validation_exception_handler) + def get_simlar(q): docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md")) @@ -216,7 +250,7 @@ def get_chat_mode(selected, param=None) -> ChatScene: else: mode = param if mode == conversation_types["default_knownledge"]: - return ChatScene.ChatKnowledge + return ChatScene.ChatDefaultKnowledge elif mode == conversation_types["custome"]: return ChatScene.ChatNewKnowledge elif mode == conversation_types["url"]: @@ -286,7 +320,7 @@ def http_bot( "chat_session_id": state.conv_id, "user_input": state.last_user_input, } - elif ChatScene.ChatKnowledge == scene: + elif ChatScene.ChatDefaultKnowledge == scene: chat_param = { "temperature": temperature, "max_new_tokens": max_new_tokens, @@ -324,15 +358,14 @@ def http_bot( response = chat.stream_call() for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: - state.messages[-1][ - -1 - ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex( - chunk, chat.skip_echo_len - ) + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + state.messages[-1][-1] =msg + chat.current_message.add_ai_message(msg) yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + chat.memory.append(chat.current_message) except Exception as e: print(traceback.format_exc()) - state.messages[-1][-1] = "Error:" + str(e) + state.messages[-1][-1] = f"""ERROR!{str(e)} """ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 @@ -632,7 +665,7 @@ def knowledge_embedding_store(vs_id, files): ) knowledge_embedding_client = KnowledgeEmbedding( file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), - model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config={ "vector_store_name": vector_store_name["vs_name"], "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, @@ -657,48 +690,36 @@ def signal_handler(sig, frame): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) + parser.add_argument('-new', '--new', action='store_true', help='enable new http mode') + + # old version server config parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) parser.add_argument("--concurrency-count", type=int, default=10) - parser.add_argument( - "--model-list-mode", type=str, default="once", choices=["once", "reload"] - ) parser.add_argument("--share", default=False, action="store_true") + + # init server config args = parser.parse_args() - logger.info(f"args: {args}") + server_init(args) + + if args.new: + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=5000) + else: + ### Compatibility mode starts the old version server by default + demo = build_webdemo() + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + ) - # init config - cfg = Config() - load_native_plugins(cfg) - dbs = cfg.local_db.get_database_list() - signal.signal(signal.SIGINT, signal_handler) - async_db_summery() - cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) - # Loader plugins and commands - command_categories = [ - "pilot.commands.built_in.audio_text", - "pilot.commands.built_in.image_gen", - ] - # exclude commands - command_categories = [ - x for x in command_categories if x not in cfg.disabled_command_categories - ] - command_registry = CommandRegistry() - for command_category in command_categories: - command_registry.import_commands(command_category) - cfg.command_registry = command_registry - logger.info(args) - demo = build_webdemo() - demo.queue( - concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False - ).launch( - server_name=args.host, - server_port=args.port, - share=args.share, - max_threads=200, - )