mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-15 06:53:12 +00:00
WEB API independent
This commit is contained in:
parent
9d3000fb26
commit
3f7cc02426
@ -1,4 +1,5 @@
|
|||||||
"""Utilities for formatting strings."""
|
"""Utilities for formatting strings."""
|
||||||
|
import json
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, List, Mapping, Sequence, Union
|
from typing import Any, List, Mapping, Sequence, Union
|
||||||
|
|
||||||
@ -36,3 +37,13 @@ class StrictFormatter(Formatter):
|
|||||||
|
|
||||||
|
|
||||||
formatter = StrictFormatter()
|
formatter = StrictFormatter()
|
||||||
|
|
||||||
|
|
||||||
|
class MyEncoder(json.JSONEncoder):
|
||||||
|
def default(self, obj):
|
||||||
|
if isinstance(obj, set):
|
||||||
|
return list(obj)
|
||||||
|
elif hasattr(obj, '__dict__'):
|
||||||
|
return obj.__dict__
|
||||||
|
else:
|
||||||
|
return json.JSONEncoder.default(self, obj)
|
@ -6,6 +6,8 @@ import numpy as np
|
|||||||
from matplotlib.font_manager import FontProperties
|
from matplotlib.font_manager import FontProperties
|
||||||
from pyecharts.charts import Bar
|
from pyecharts.charts import Bar
|
||||||
from pyecharts import options as opts
|
from pyecharts import options as opts
|
||||||
|
from test_cls_1 import TestBase,Test1
|
||||||
|
from test_cls_2 import Test2
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -56,20 +58,29 @@ CFG = Config()
|
|||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
|
||||||
|
# def __extract_json(s):
|
||||||
|
# i = s.index("{")
|
||||||
|
# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
||||||
|
# for j, c in enumerate(s[i + 1 :], start=i + 1):
|
||||||
|
# if c == "}":
|
||||||
|
# count -= 1
|
||||||
|
# elif c == "{":
|
||||||
|
# count += 1
|
||||||
|
# if count == 0:
|
||||||
|
# break
|
||||||
|
# assert count == 0 # 检查是否找到最后一个'}'
|
||||||
|
# return s[i : j + 1]
|
||||||
|
#
|
||||||
|
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
||||||
|
# print(__extract_json(ss))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
test1 = Test1()
|
||||||
def __extract_json(s):
|
test2 = Test2()
|
||||||
i = s.index("{")
|
test1.write()
|
||||||
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
test1.test()
|
||||||
for j, c in enumerate(s[i + 1 :], start=i + 1):
|
test2.write()
|
||||||
if c == "}":
|
test1.test()
|
||||||
count -= 1
|
test2.test()
|
||||||
elif c == "{":
|
|
||||||
count += 1
|
|
||||||
if count == 0:
|
|
||||||
break
|
|
||||||
assert count == 0 # 检查是否找到最后一个'}'
|
|
||||||
return s[i : j + 1]
|
|
||||||
|
|
||||||
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
|
||||||
print(__extract_json(ss))
|
|
@ -4,7 +4,7 @@ from test_cls_base import TestBase
|
|||||||
|
|
||||||
|
|
||||||
class Test1(TestBase):
|
class Test1(TestBase):
|
||||||
|
mode:str = "456"
|
||||||
def write(self):
|
def write(self):
|
||||||
self.test_values.append("x")
|
self.test_values.append("x")
|
||||||
self.test_values.append("y")
|
self.test_values.append("y")
|
||||||
|
@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
|||||||
|
|
||||||
class Test2(TestBase):
|
class Test2(TestBase):
|
||||||
test_2_values:List = []
|
test_2_values:List = []
|
||||||
|
mode:str = "789"
|
||||||
def write(self):
|
def write(self):
|
||||||
self.test_values.append(1)
|
self.test_values.append(1)
|
||||||
self.test_values.append(2)
|
self.test_values.append(2)
|
||||||
|
@ -5,8 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
|||||||
|
|
||||||
class TestBase(BaseModel, ABC):
|
class TestBase(BaseModel, ABC):
|
||||||
test_values: List = []
|
test_values: List = []
|
||||||
|
mode:str = "123"
|
||||||
|
|
||||||
def test(self):
|
def test(self):
|
||||||
print(self.__class__.__name__ + ":" )
|
print(self.__class__.__name__ + ":" )
|
||||||
print(self.test_values)
|
print(self.test_values)
|
||||||
|
print(self.mode)
|
13
pilot/connections/rdbms/py_study/test_duckdb.py
Normal file
13
pilot/connections/rdbms/py_study/test_duckdb.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import duckdb
|
||||||
|
|
||||||
|
default_db_path = os.path.join(os.getcwd(), "message")
|
||||||
|
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if os.path.isfile(duckdb_path):
|
||||||
|
cursor = duckdb.connect(duckdb_path).cursor()
|
||||||
|
cursor.execute("SELECT * FROM chat_history limit 20")
|
||||||
|
data = cursor.fetchall()
|
||||||
|
print(data)
|
@ -32,3 +32,9 @@ class BaseChatHistoryMemory(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear session memory from the local file"""
|
"""Clear session memory from the local file"""
|
||||||
|
|
||||||
|
|
||||||
|
def conv_list(self, user_name:str=None) -> None:
|
||||||
|
"""get user's conversation list"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
109
pilot/memory/chat_history/duckdb_history.py
Normal file
109
pilot/memory/chat_history/duckdb_history.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import duckdb
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
|
from pilot.scene.message import (
|
||||||
|
OnceConversation,
|
||||||
|
conversation_from_dict,
|
||||||
|
conversations_to_dict,
|
||||||
|
)
|
||||||
|
from pilot.common.formatting import MyEncoder
|
||||||
|
|
||||||
|
default_db_path = os.path.join(os.getcwd(), "message")
|
||||||
|
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
|
||||||
|
table_name = 'chat_history'
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||||
|
|
||||||
|
def __init__(self, chat_session_id: str):
|
||||||
|
self.chat_seesion_id = chat_session_id
|
||||||
|
os.makedirs(default_db_path, exist_ok=True)
|
||||||
|
self.connect = duckdb.connect(duckdb_path)
|
||||||
|
self.__init_chat_history_tables()
|
||||||
|
|
||||||
|
def __init_chat_history_tables(self):
|
||||||
|
|
||||||
|
# 检查表是否存在
|
||||||
|
result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||||
|
[table_name]).fetchall()
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
# 如果表不存在,则创建新表
|
||||||
|
self.connect.execute(
|
||||||
|
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")
|
||||||
|
|
||||||
|
def __get_messages_by_conv_uid(self, conv_uid: str):
|
||||||
|
cursor = self.connect.cursor()
|
||||||
|
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
|
||||||
|
return cursor.fetchone()
|
||||||
|
|
||||||
|
def messages(self) -> List[OnceConversation]:
|
||||||
|
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
|
||||||
|
if context:
|
||||||
|
conversations: List[OnceConversation] = json.loads(context[0])
|
||||||
|
return conversations
|
||||||
|
return []
|
||||||
|
|
||||||
|
def append(self, once_message: OnceConversation) -> None:
|
||||||
|
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
|
||||||
|
conversations: List[OnceConversation] = []
|
||||||
|
if context:
|
||||||
|
conversations = json.load(context)
|
||||||
|
conversations.append(once_message)
|
||||||
|
cursor = self.connect.cursor()
|
||||||
|
if context:
|
||||||
|
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
|
||||||
|
[json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
|
||||||
|
else:
|
||||||
|
cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
|
||||||
|
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
|
||||||
|
cursor.commit()
|
||||||
|
self.connect.commit()
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
cursor = self.connect.cursor()
|
||||||
|
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
|
||||||
|
cursor.commit()
|
||||||
|
self.connect.commit()
|
||||||
|
|
||||||
|
def delete(self) -> bool:
|
||||||
|
cursor = self.connect.cursor()
|
||||||
|
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
|
||||||
|
cursor.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def conv_list(cls, user_name: str = None) -> None:
|
||||||
|
if os.path.isfile(duckdb_path):
|
||||||
|
cursor = duckdb.connect(duckdb_path).cursor()
|
||||||
|
if user_name:
|
||||||
|
cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
|
||||||
|
else:
|
||||||
|
cursor.execute("SELECT * FROM chat_history limit 20")
|
||||||
|
# 获取查询结果字段名
|
||||||
|
fields = [field[0] for field in cursor.description]
|
||||||
|
data = []
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
row_dict = {}
|
||||||
|
for i, field in enumerate(fields):
|
||||||
|
row_dict[field] = row[i]
|
||||||
|
data.append(row_dict)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def get_messages(self)-> List[OnceConversation]:
|
||||||
|
cursor = self.connect.cursor()
|
||||||
|
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
|
||||||
|
context = cursor.fetchone()
|
||||||
|
if context:
|
||||||
|
return json.loads(context[0])
|
||||||
|
return None
|
@ -11,13 +11,14 @@ from pilot.scene.message import (
|
|||||||
conversation_from_dict,
|
conversation_from_dict,
|
||||||
conversations_to_dict,
|
conversations_to_dict,
|
||||||
)
|
)
|
||||||
|
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
class MemHistoryMemory(BaseChatHistoryMemory):
|
class MemHistoryMemory(BaseChatHistoryMemory):
|
||||||
histroies_map = {}
|
histroies_map = FixedSizeDict(100)
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, chat_session_id: str):
|
def __init__(self, chat_session_id: str):
|
||||||
self.chat_seesion_id = chat_session_id
|
self.chat_seesion_id = chat_session_id
|
||||||
|
BIN
pilot/mock_datas/chat_history.db
Normal file
BIN
pilot/mock_datas/chat_history.db
Normal file
Binary file not shown.
0
pilot/mock_datas/chat_history.db.wal
Normal file
0
pilot/mock_datas/chat_history.db.wal
Normal file
@ -1,12 +1,27 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
class Scene:
|
||||||
|
def __init__(self, code, describe, is_inner):
|
||||||
|
self.code = code
|
||||||
|
self.describe = describe
|
||||||
|
self.is_inner = is_inner
|
||||||
|
|
||||||
class ChatScene(Enum):
|
class ChatScene(Enum):
|
||||||
ChatWithDbExecute = "chat_with_db_execute"
|
ChatWithDbExecute = "chat_with_db_execute"
|
||||||
ChatWithDbQA = "chat_with_db_qa"
|
ChatWithDbQA = "chat_with_db_qa"
|
||||||
ChatExecution = "chat_execution"
|
ChatExecution = "chat_execution"
|
||||||
ChatKnowledge = "chat_default_knowledge"
|
ChatDefaultKnowledge = "chat_default_knowledge"
|
||||||
ChatNewKnowledge = "chat_new_knowledge"
|
ChatNewKnowledge = "chat_new_knowledge"
|
||||||
ChatUrlKnowledge = "chat_url_knowledge"
|
ChatUrlKnowledge = "chat_url_knowledge"
|
||||||
InnerChatDBSummary = "inner_chat_db_summary"
|
InnerChatDBSummary = "inner_chat_db_summary"
|
||||||
|
|
||||||
ChatNormal = "chat_normal"
|
ChatNormal = "chat_normal"
|
||||||
|
ChatDashboard = "chat_dashboard"
|
||||||
|
ChatKnowledge = "chat_knowledge"
|
||||||
|
ChatDb = "chat_db"
|
||||||
|
ChatData= "chat_data"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_valid_mode(mode):
|
||||||
|
return any(mode == item.value for item in ChatScene)
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ from pilot.prompts.prompt_new import PromptTemplate
|
|||||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
from pilot.memory.chat_history.file_history import FileHistoryMemory
|
from pilot.memory.chat_history.file_history import FileHistoryMemory
|
||||||
from pilot.memory.chat_history.mem_history import MemHistoryMemory
|
from pilot.memory.chat_history.mem_history import MemHistoryMemory
|
||||||
|
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||||
|
|
||||||
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
|
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
|
||||||
from pilot.utils import (
|
from pilot.utils import (
|
||||||
@ -59,8 +60,6 @@ class BaseChat(ABC):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
temperature,
|
|
||||||
max_new_tokens,
|
|
||||||
chat_mode,
|
chat_mode,
|
||||||
chat_session_id,
|
chat_session_id,
|
||||||
current_user_input,
|
current_user_input,
|
||||||
@ -70,17 +69,15 @@ class BaseChat(ABC):
|
|||||||
self.current_user_input: str = current_user_input
|
self.current_user_input: str = current_user_input
|
||||||
self.llm_model = CFG.LLM_MODEL
|
self.llm_model = CFG.LLM_MODEL
|
||||||
### can configurable storage methods
|
### can configurable storage methods
|
||||||
self.memory = MemHistoryMemory(chat_session_id)
|
self.memory = DuckdbHistoryMemory(chat_session_id)
|
||||||
|
|
||||||
### load prompt template
|
### load prompt template
|
||||||
self.prompt_template: PromptTemplate = CFG.prompt_templates[
|
self.prompt_template: PromptTemplate = CFG.prompt_templates[
|
||||||
self.chat_mode.value
|
self.chat_mode.value
|
||||||
]
|
]
|
||||||
self.history_message: List[OnceConversation] = []
|
self.history_message: List[OnceConversation] = []
|
||||||
self.current_message: OnceConversation = OnceConversation()
|
self.current_message: OnceConversation = OnceConversation(chat_mode.value)
|
||||||
self.current_tokens_used: int = 0
|
self.current_tokens_used: int = 0
|
||||||
self.temperature = temperature
|
|
||||||
self.max_new_tokens = max_new_tokens
|
|
||||||
### load chat_session_id's chat historys
|
### load chat_session_id's chat historys
|
||||||
self._load_history(self.chat_session_id)
|
self._load_history(self.chat_session_id)
|
||||||
|
|
||||||
@ -99,15 +96,15 @@ class BaseChat(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def do_with_prompt_response(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
pass
|
return prompt_response
|
||||||
|
|
||||||
def __call_base(self):
|
def __call_base(self):
|
||||||
input_values = self.generate_input_values()
|
input_values = self.generate_input_values()
|
||||||
### Chat sequence advance
|
### Chat sequence advance
|
||||||
self.current_message.chat_order = len(self.history_message) + 1
|
self.current_message.chat_order = len(self.history_message) + 1
|
||||||
self.current_message.add_user_message(self.current_user_input)
|
self.current_message.add_user_message(self.current_user_input)
|
||||||
self.current_message.start_date = datetime.datetime.now()
|
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
# TODO
|
# TODO
|
||||||
self.current_message.tokens = 0
|
self.current_message.tokens = 0
|
||||||
current_prompt = None
|
current_prompt = None
|
||||||
@ -203,13 +200,10 @@ class BaseChat(ABC):
|
|||||||
# }"""
|
# }"""
|
||||||
|
|
||||||
self.current_message.add_ai_message(ai_response_text)
|
self.current_message.add_ai_message(ai_response_text)
|
||||||
prompt_define_response = (
|
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
|
||||||
self.prompt_template.output_parser.parse_prompt_response(
|
|
||||||
ai_response_text
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = self.do_with_prompt_response(prompt_define_response)
|
|
||||||
|
result = self.do_action(prompt_define_response)
|
||||||
|
|
||||||
if hasattr(prompt_define_response, "thoughts"):
|
if hasattr(prompt_define_response, "thoughts"):
|
||||||
if isinstance(prompt_define_response.thoughts, dict):
|
if isinstance(prompt_define_response.thoughts, dict):
|
||||||
|
0
pilot/scene/chat_dashboard/__init__.py
Normal file
0
pilot/scene/chat_dashboard/__init__.py
Normal file
81
pilot/scene/chat_dashboard/chat.py
Normal file
81
pilot/scene/chat_dashboard/chat.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import json
|
||||||
|
from typing import Dict, NamedTuple, List
|
||||||
|
from pilot.scene.base_message import (
|
||||||
|
HumanMessage,
|
||||||
|
ViewMessage,
|
||||||
|
)
|
||||||
|
from pilot.scene.base_chat import BaseChat
|
||||||
|
from pilot.scene.base import ChatScene
|
||||||
|
from pilot.common.sql_database import Database
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.common.markdown_text import (
|
||||||
|
generate_htm_table,
|
||||||
|
)
|
||||||
|
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
||||||
|
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatDashboard(BaseChat):
|
||||||
|
chat_scene: str = ChatScene.ChatDashboard.value
|
||||||
|
report_name: str
|
||||||
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, chat_session_id, db_name, user_input, report_name
|
||||||
|
):
|
||||||
|
""" """
|
||||||
|
super().__init__(
|
||||||
|
chat_mode=ChatScene.ChatWithDbExecute,
|
||||||
|
chat_session_id=chat_session_id,
|
||||||
|
current_user_input=user_input,
|
||||||
|
)
|
||||||
|
if not db_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
|
||||||
|
)
|
||||||
|
self.report_name = report_name
|
||||||
|
self.database = CFG.local_db
|
||||||
|
# 准备DB信息(拿到指定库的链接)
|
||||||
|
self.db_connect = self.database.get_session(self.db_name)
|
||||||
|
self.top_k: int = 5
|
||||||
|
|
||||||
|
def generate_input_values(self):
|
||||||
|
try:
|
||||||
|
from pilot.summary.db_summary_client import DBSummaryClient
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError("Could not import DBSummaryClient. ")
|
||||||
|
client = DBSummaryClient()
|
||||||
|
input_values = {
|
||||||
|
"input": self.current_user_input,
|
||||||
|
"dialect": self.database.dialect,
|
||||||
|
"table_info": self.database.table_simple_info(self.db_connect),
|
||||||
|
"supported_chat_type": "" #TODO
|
||||||
|
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
|
||||||
|
}
|
||||||
|
return input_values
|
||||||
|
|
||||||
|
def do_action(self, prompt_response):
|
||||||
|
### TODO 记录整体信息,处理成功的,和未成功的分开记录处理
|
||||||
|
report_data: ReportData = ReportData()
|
||||||
|
chart_datas: List[ChartData] = []
|
||||||
|
for chart_item in prompt_response:
|
||||||
|
try:
|
||||||
|
datas = self.database.run(self.db_connect, chart_item.sql)
|
||||||
|
chart_data: ChartData = ChartData()
|
||||||
|
except Exception as e:
|
||||||
|
# TODO 修复流程
|
||||||
|
print(str(e))
|
||||||
|
|
||||||
|
|
||||||
|
chart_datas.append(chart_data)
|
||||||
|
|
||||||
|
report_data.conv_uid = self.chat_session_id
|
||||||
|
report_data.template_name = self.report_name
|
||||||
|
report_data.template_introduce = None
|
||||||
|
report_data.charts = chart_datas
|
||||||
|
|
||||||
|
return report_data
|
||||||
|
|
||||||
|
|
22
pilot/scene/chat_dashboard/data_preparation/report_schma.py
Normal file
22
pilot/scene/chat_dashboard/data_preparation/report_schma.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import TypeVar, Union, List, Generic, Any
|
||||||
|
|
||||||
|
|
||||||
|
class ChartData(BaseModel):
|
||||||
|
chart_uid: str
|
||||||
|
chart_type: str
|
||||||
|
chart_sql: str
|
||||||
|
column_name: List
|
||||||
|
values: List
|
||||||
|
style: Any
|
||||||
|
|
||||||
|
|
||||||
|
class ReportData(BaseModel):
|
||||||
|
conv_uid:str
|
||||||
|
template_name:str
|
||||||
|
template_introduce:str
|
||||||
|
charts: List[ChartData]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
41
pilot/scene/chat_dashboard/out_parser.py
Normal file
41
pilot/scene/chat_dashboard/out_parser.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, NamedTuple, List
|
||||||
|
import pandas as pd
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
|
||||||
|
|
||||||
|
class ChartItem(NamedTuple):
|
||||||
|
sql: str
|
||||||
|
title:str
|
||||||
|
thoughts: str
|
||||||
|
showcase:str
|
||||||
|
|
||||||
|
|
||||||
|
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatDashboardOutputParser(BaseOutputParser):
|
||||||
|
def __init__(self, sep: str, is_stream_out: bool):
|
||||||
|
super().__init__(sep=sep, is_stream_out=is_stream_out)
|
||||||
|
|
||||||
|
def parse_prompt_response(self, model_out_text):
|
||||||
|
clean_str = super().parse_prompt_response(model_out_text)
|
||||||
|
print("clean prompt response:", clean_str)
|
||||||
|
response = json.loads(clean_str)
|
||||||
|
chart_items = List[ChartItem]
|
||||||
|
for item in response:
|
||||||
|
chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"]))
|
||||||
|
return chart_items
|
||||||
|
|
||||||
|
def parse_view_response(self, speak, data) -> str:
|
||||||
|
### TODO
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "chat_dashboard"
|
49
pilot/scene/chat_dashboard/prompt.py
Normal file
49
pilot/scene/chat_dashboard/prompt.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import json
|
||||||
|
import importlib
|
||||||
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.scene.base import ChatScene
|
||||||
|
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
|
||||||
|
from pilot.common.schema import SeparatorStyle
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations"""
|
||||||
|
PROMPT_SCENE_DEFINE = None
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE = """
|
||||||
|
According to the structure definition in the following tables:
|
||||||
|
{table_info}
|
||||||
|
Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions.
|
||||||
|
Used to support goal: {input}
|
||||||
|
|
||||||
|
Use the chart display method in the following range:
|
||||||
|
{supported_chat_type}
|
||||||
|
give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format:
|
||||||
|
{response}
|
||||||
|
Ensure the response is correct json and can be parsed by Python json.loads
|
||||||
|
"""
|
||||||
|
|
||||||
|
RESPONSE_FORMAT = [{
|
||||||
|
"sql": "data analysis SQL",
|
||||||
|
"title": "Data Analysis Title",
|
||||||
|
"showcase": "What type of charts to show",
|
||||||
|
"thoughts": "Current thinking and value of data analysis"
|
||||||
|
}]
|
||||||
|
|
||||||
|
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||||
|
|
||||||
|
PROMPT_NEED_NEED_STREAM_OUT = False
|
||||||
|
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
template_scene=ChatScene.ChatWithDbExecute.value,
|
||||||
|
input_variables=["input", "table_info", "dialect", "supported_chat_type"],
|
||||||
|
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
|
||||||
|
template_define=PROMPT_SCENE_DEFINE,
|
||||||
|
template=_DEFAULT_TEMPLATE,
|
||||||
|
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||||
|
output_parser=DbChatOutputParser(
|
||||||
|
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||||
|
),
|
||||||
|
)
|
||||||
|
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"title": "Sales Report",
|
||||||
|
"name": "sale_report",
|
||||||
|
"introduce": "",
|
||||||
|
"layout": "TODO",
|
||||||
|
"supported_chart_type":["HeatMap","sheet", "LineChart", "PieChart", "BarChart"],
|
||||||
|
"key_metrics":[],
|
||||||
|
"trends": []
|
||||||
|
}
|
@ -22,12 +22,10 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, temperature, max_new_tokens, chat_session_id, db_name, user_input
|
self, chat_session_id, db_name, user_input
|
||||||
):
|
):
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
temperature=temperature,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
chat_mode=ChatScene.ChatWithDbExecute,
|
chat_mode=ChatScene.ChatWithDbExecute,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
@ -57,5 +55,5 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
return self.database.run(self.db_connect, prompt_response.sql)
|
return self.database.run(self.db_connect, prompt_response.sql)
|
||||||
|
@ -20,12 +20,10 @@ class ChatWithDbQA(BaseChat):
|
|||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, temperature, max_new_tokens, chat_session_id, db_name, user_input
|
self, chat_session_id, db_name, user_input
|
||||||
):
|
):
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
temperature=temperature,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
chat_mode=ChatScene.ChatWithDbQA,
|
chat_mode=ChatScene.ChatWithDbQA,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
@ -66,5 +64,4 @@ class ChatWithDbQA(BaseChat):
|
|||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
|
||||||
return prompt_response
|
|
||||||
|
@ -22,15 +22,11 @@ class ChatWithPlugin(BaseChat):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
temperature,
|
|
||||||
max_new_tokens,
|
|
||||||
chat_session_id,
|
chat_session_id,
|
||||||
user_input,
|
user_input,
|
||||||
plugin_selector: str = None,
|
plugin_selector: str = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
temperature=temperature,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
chat_mode=ChatScene.ChatExecution,
|
chat_mode=ChatScene.ChatExecution,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
@ -66,7 +62,7 @@ class ChatWithPlugin(BaseChat):
|
|||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
## plugin command run
|
## plugin command run
|
||||||
return execute_command(
|
return execute_command(
|
||||||
str(prompt_response.command.get("name")),
|
str(prompt_response.command.get("name")),
|
||||||
|
@ -30,12 +30,10 @@ class ChatNewKnowledge(BaseChat):
|
|||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name
|
self, chat_session_id, user_input, knowledge_name
|
||||||
):
|
):
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
temperature=temperature,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
chat_mode=ChatScene.ChatNewKnowledge,
|
chat_mode=ChatScene.ChatNewKnowledge,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
@ -67,8 +65,6 @@ class ChatNewKnowledge(BaseChat):
|
|||||||
|
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
|
||||||
return prompt_response
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
|
@ -25,16 +25,14 @@ CFG = Config()
|
|||||||
|
|
||||||
|
|
||||||
class ChatDefaultKnowledge(BaseChat):
|
class ChatDefaultKnowledge(BaseChat):
|
||||||
chat_scene: str = ChatScene.ChatKnowledge.value
|
chat_scene: str = ChatScene.ChatDefaultKnowledge.value
|
||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
|
def __init__(self, chat_session_id, user_input):
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
temperature=temperature,
|
chat_mode=ChatScene.ChatDefaultKnowledge,
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
chat_mode=ChatScene.ChatKnowledge,
|
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
)
|
)
|
||||||
@ -61,9 +59,8 @@ class ChatDefaultKnowledge(BaseChat):
|
|||||||
)
|
)
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
|
||||||
return prompt_response
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
return ChatScene.ChatKnowledge.value
|
return ChatScene.ChatDefaultKnowledge.value
|
||||||
|
@ -39,7 +39,7 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
|
|||||||
PROMPT_NEED_NEED_STREAM_OUT = True
|
PROMPT_NEED_NEED_STREAM_OUT = True
|
||||||
|
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template_scene=ChatScene.ChatKnowledge.value,
|
template_scene=ChatScene.ChatDefaultKnowledge.value,
|
||||||
input_variables=["context", "question"],
|
input_variables=["context", "question"],
|
||||||
response_format=None,
|
response_format=None,
|
||||||
template_define=PROMPT_SCENE_DEFINE,
|
template_define=PROMPT_SCENE_DEFINE,
|
||||||
|
@ -14,8 +14,6 @@ class InnerChatDBSummary(BaseChat):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
temperature,
|
|
||||||
max_new_tokens,
|
|
||||||
chat_session_id,
|
chat_session_id,
|
||||||
user_input,
|
user_input,
|
||||||
db_select,
|
db_select,
|
||||||
@ -23,8 +21,6 @@ class InnerChatDBSummary(BaseChat):
|
|||||||
):
|
):
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
temperature=temperature,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
chat_mode=ChatScene.InnerChatDBSummary,
|
chat_mode=ChatScene.InnerChatDBSummary,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
@ -40,8 +36,6 @@ class InnerChatDBSummary(BaseChat):
|
|||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
|
||||||
return prompt_response
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
|
@ -27,11 +27,9 @@ class ChatUrlKnowledge(BaseChat):
|
|||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url):
|
def __init__(self, chat_session_id, user_input, url):
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
temperature=temperature,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
chat_mode=ChatScene.ChatUrlKnowledge,
|
chat_mode=ChatScene.ChatUrlKnowledge,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
@ -62,8 +60,6 @@ class ChatUrlKnowledge(BaseChat):
|
|||||||
input_values = {"context": context, "question": self.current_user_input}
|
input_values = {"context": context, "question": self.current_user_input}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
|
||||||
return prompt_response
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
|
@ -18,11 +18,9 @@ class ChatNormal(BaseChat):
|
|||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
|
def __init__(self, chat_session_id, user_input):
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
temperature=temperature,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
chat_mode=ChatScene.ChatNormal,
|
chat_mode=ChatScene.ChatNormal,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
@ -32,7 +30,7 @@ class ChatNormal(BaseChat):
|
|||||||
input_values = {"input": self.current_user_input}
|
input_values = {"input": self.current_user_input}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
return prompt_response
|
return prompt_response
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -25,7 +25,8 @@ class OnceConversation:
|
|||||||
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
|
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, chat_mode):
|
||||||
|
self.chat_mode: str = chat_mode
|
||||||
self.messages: List[BaseMessage] = []
|
self.messages: List[BaseMessage] = []
|
||||||
self.start_date: str = ""
|
self.start_date: str = ""
|
||||||
self.chat_order: int = 0
|
self.chat_order: int = 0
|
||||||
@ -43,12 +44,28 @@ class OnceConversation:
|
|||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
def add_ai_message(self, message: str) -> None:
|
||||||
"""Add an AI message to the store"""
|
"""Add an AI message to the store"""
|
||||||
|
|
||||||
has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
|
has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
|
||||||
if has_message:
|
if has_message:
|
||||||
raise ValueError("Already Have Ai message")
|
self.__update_ai_message(message)
|
||||||
|
else:
|
||||||
self.messages.append(AIMessage(content=message))
|
self.messages.append(AIMessage(content=message))
|
||||||
""" """
|
""" """
|
||||||
|
|
||||||
|
def __update_ai_message(self, new_message: str) -> None:
|
||||||
|
"""
|
||||||
|
stream out message update
|
||||||
|
Args:
|
||||||
|
new_message:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
for item in self.messages:
|
||||||
|
if item.type == "ai":
|
||||||
|
item.content = new_message
|
||||||
|
|
||||||
def add_view_message(self, message: str) -> None:
|
def add_view_message(self, message: str) -> None:
|
||||||
"""Add an AI message to the store"""
|
"""Add an AI message to the store"""
|
||||||
|
|
||||||
@ -69,6 +86,13 @@ class OnceConversation:
|
|||||||
self.session_id = None
|
self.session_id = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_message(self):
|
||||||
|
for once in self.messages:
|
||||||
|
if isinstance(once, HumanMessage):
|
||||||
|
return once.content
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _conversation_to_dic(once: OnceConversation) -> dict:
|
def _conversation_to_dic(once: OnceConversation) -> dict:
|
||||||
start_str: str = ""
|
start_str: str = ""
|
||||||
if once.start_date:
|
if once.start_date:
|
||||||
@ -78,6 +102,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
|
|||||||
start_str = once.start_date
|
start_str = once.start_date
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"chat_mode": once.chat_mode,
|
||||||
"chat_order": once.chat_order,
|
"chat_order": once.chat_order,
|
||||||
"start_date": start_str,
|
"start_date": start_str,
|
||||||
"cost": once.cost if once.cost else 0,
|
"cost": once.cost if once.cost else 0,
|
||||||
@ -93,6 +118,7 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
|
|||||||
def conversation_from_dict(once: dict) -> OnceConversation:
|
def conversation_from_dict(once: dict) -> OnceConversation:
|
||||||
conversation = OnceConversation()
|
conversation = OnceConversation()
|
||||||
conversation.cost = once.get("cost", 0)
|
conversation.cost = once.get("cost", 0)
|
||||||
|
conversation.chat_mode = once.get("chat_mode", "chat_normal")
|
||||||
conversation.tokens = once.get("tokens", 0)
|
conversation.tokens = once.get("tokens", 0)
|
||||||
conversation.start_date = once.get("start_date", "")
|
conversation.start_date = once.get("start_date", "")
|
||||||
conversation.chat_order = int(once.get("chat_order"))
|
conversation.chat_order = int(once.get("chat_order"))
|
||||||
|
@ -1,15 +1,18 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
import json
|
||||||
from fastapi import APIRouter, Request, Body, status
|
import asyncio
|
||||||
|
import time
|
||||||
|
from fastapi import APIRouter, Request, Body, status, HTTPException, Response
|
||||||
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
|
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
@ -17,6 +20,8 @@ from pilot.scene.chat_factory import ChatFactory
|
|||||||
from pilot.configs.model_config import (LOGDIR)
|
from pilot.configs.model_config import (LOGDIR)
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.scene.base_message import (BaseMessage)
|
from pilot.scene.base_message import (BaseMessage)
|
||||||
|
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||||
|
from pilot.scene.message import OnceConversation
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -28,32 +33,117 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
|||||||
message = ""
|
message = ""
|
||||||
for error in exc.errors():
|
for error in exc.errors():
|
||||||
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
|
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
|
||||||
return Result.faild(message)
|
return Result.faild(msg=message)
|
||||||
|
|
||||||
|
|
||||||
@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]])
|
def __get_conv_user_message(conversations: dict):
|
||||||
async def dialogue_list(user_id: str):
|
messages = conversations['messages']
|
||||||
#### TODO
|
for item in messages:
|
||||||
|
if item['type'] == "human":
|
||||||
conversations = [ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
|
return item['data']['content']
|
||||||
ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]")]
|
return ""
|
||||||
|
|
||||||
return Result[ConversationVo].succ(conversations)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/chat/dialogue/new', response_model=Result[str])
|
@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
|
||||||
async def dialogue_new(user_id: str):
|
async def dialogue_list(response: Response, user_id: str = None):
|
||||||
|
# 设置CORS头部信息
|
||||||
|
response.headers['Access-Control-Allow-Origin'] = '*'
|
||||||
|
response.headers['Access-Control-Allow-Methods'] = 'GET'
|
||||||
|
response.headers['Access-Control-Request-Headers'] = 'content-type'
|
||||||
|
|
||||||
|
dialogues: List = []
|
||||||
|
datas = DuckdbHistoryMemory.conv_list(user_id)
|
||||||
|
|
||||||
|
for item in datas:
|
||||||
|
conv_uid = item.get("conv_uid")
|
||||||
|
messages = item.get("messages")
|
||||||
|
conversations = json.loads(messages)
|
||||||
|
|
||||||
|
first_conv: OnceConversation = conversations[0]
|
||||||
|
conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv),
|
||||||
|
chat_mode=first_conv['chat_mode'])
|
||||||
|
dialogues.append(conv_vo)
|
||||||
|
|
||||||
|
return Result[ConversationVo].succ(dialogues)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
|
||||||
|
async def dialogue_scenes():
|
||||||
|
scene_vos: List[ChatSceneVo] = []
|
||||||
|
new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
|
||||||
|
for scene in new_modes:
|
||||||
|
if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
|
||||||
|
scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
|
||||||
|
scene_vos.append(scene_vo)
|
||||||
|
return Result.succ(scene_vos)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo])
|
||||||
|
async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None):
|
||||||
unique_id = uuid.uuid1()
|
unique_id = uuid.uuid1()
|
||||||
return Result.succ(unique_id)
|
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/chat/dialogue/delete')
|
def get_db_list():
|
||||||
async def dialogue_delete(con_uid: str, user_id: str):
|
db = CFG.local_db
|
||||||
#### TODO
|
dbs = db.get_database_list()
|
||||||
|
params:dict = {}
|
||||||
|
for name in dbs:
|
||||||
|
params.update({name: name})
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def plugins_select_info():
|
||||||
|
plugins_infos: dict = {}
|
||||||
|
for plugin in CFG.plugins:
|
||||||
|
plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
|
||||||
|
return plugins_infos
|
||||||
|
|
||||||
|
|
||||||
|
def knowledge_list():
|
||||||
|
knowledge: dict = {}
|
||||||
|
### TODO
|
||||||
|
return knowledge
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
|
||||||
|
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
|
||||||
|
if ChatScene.ChatDb.value == chat_mode:
|
||||||
|
return Result.succ(get_db_list())
|
||||||
|
elif ChatScene.ChatData.value == chat_mode:
|
||||||
|
return Result.succ(get_db_list())
|
||||||
|
elif ChatScene.ChatDashboard.value == chat_mode:
|
||||||
|
return Result.succ(get_db_list())
|
||||||
|
elif ChatScene.ChatExecution.value == chat_mode:
|
||||||
|
return Result.succ(plugins_select_info())
|
||||||
|
elif ChatScene.ChatKnowledge.value == chat_mode:
|
||||||
|
return Result.succ(knowledge_list())
|
||||||
|
else:
|
||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/chat/completions', response_model=Result[MessageVo])
|
@router.post('/v1/chat/dialogue/delete')
|
||||||
|
async def dialogue_delete(con_uid: str):
|
||||||
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
|
history_mem.delete()
|
||||||
|
return Result.succ(None)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo])
|
||||||
|
async def dialogue_history_messages(con_uid: str):
|
||||||
|
print(f"dialogue_history_messages:{con_uid}")
|
||||||
|
message_vos: List[MessageVo] = []
|
||||||
|
|
||||||
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
|
if history_messages:
|
||||||
|
for once in history_messages:
|
||||||
|
once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']]
|
||||||
|
message_vos.extend(once_message_vos)
|
||||||
|
return Result.succ(message_vos)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/v1/chat/completions')
|
||||||
async def chat_completions(dialogue: ConversationVo = Body()):
|
async def chat_completions(dialogue: ConversationVo = Body()):
|
||||||
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
|
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
|
||||||
|
|
||||||
@ -65,22 +155,31 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
"user_input": dialogue.user_input,
|
"user_input": dialogue.user_input,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ChatScene.ChatWithDbExecute == dialogue.chat_mode:
|
if ChatScene.ChatDb == dialogue.chat_mode:
|
||||||
chat_param.update("db_name", dialogue.select_param)
|
chat_param.update("db_name", dialogue.select_param)
|
||||||
elif ChatScene.ChatWithDbQA == dialogue.chat_mode:
|
elif ChatScene.ChatData == dialogue.chat_mode:
|
||||||
|
chat_param.update("db_name", dialogue.select_param)
|
||||||
|
elif ChatScene.ChatDashboard == dialogue.chat_mode:
|
||||||
chat_param.update("db_name", dialogue.select_param)
|
chat_param.update("db_name", dialogue.select_param)
|
||||||
elif ChatScene.ChatExecution == dialogue.chat_mode:
|
elif ChatScene.ChatExecution == dialogue.chat_mode:
|
||||||
chat_param.update("plugin_selector", dialogue.select_param)
|
chat_param.update("plugin_selector", dialogue.select_param)
|
||||||
elif ChatScene.ChatNewKnowledge == dialogue.chat_mode:
|
elif ChatScene.ChatKnowledge == dialogue.chat_mode:
|
||||||
chat_param.update("knowledge_name", dialogue.select_param)
|
chat_param.update("knowledge_name", dialogue.select_param)
|
||||||
elif ChatScene.ChatUrlKnowledge == dialogue.chat_mode:
|
|
||||||
chat_param.update("url", dialogue.select_param)
|
|
||||||
|
|
||||||
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
|
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
|
||||||
if not chat.prompt_template.stream_out:
|
if not chat.prompt_template.stream_out:
|
||||||
return non_stream_response(chat)
|
return non_stream_response(chat)
|
||||||
else:
|
else:
|
||||||
return stream_response(chat)
|
# generator = stream_generator(chat)
|
||||||
|
# result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
|
||||||
|
# return result
|
||||||
|
return StreamingResponse(stream_generator(chat), media_type="text/plain")
|
||||||
|
|
||||||
|
|
||||||
|
def stream_test():
|
||||||
|
for message in ["Hello", "world", "how", "are", "you"]:
|
||||||
|
yield message
|
||||||
|
# yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def stream_generator(chat):
|
def stream_generator(chat):
|
||||||
@ -89,24 +188,28 @@ def stream_generator(chat):
|
|||||||
if chunk:
|
if chunk:
|
||||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
||||||
chat.current_message.add_ai_message(msg)
|
chat.current_message.add_ai_message(msg)
|
||||||
messageVos = [message2Vo(element) for element in chat.current_message.messages]
|
yield msg
|
||||||
yield Result.succ(messageVos)
|
# chat.current_message.add_ai_message(msg)
|
||||||
def stream_response(chat):
|
# vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
|
||||||
logger.info("stream out start!")
|
# json_text = json.dumps(vo.__dict__)
|
||||||
api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
|
# yield json_text.encode('utf-8')
|
||||||
return api_response
|
chat.memory.append(chat.current_message)
|
||||||
|
|
||||||
|
|
||||||
|
# def stream_response(chat):
|
||||||
|
# logger.info("stream out start!")
|
||||||
|
# api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
|
||||||
|
# return api_response
|
||||||
|
|
||||||
|
|
||||||
|
def message2Vo(message: dict, order) -> MessageVo:
|
||||||
|
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
|
||||||
|
return MessageVo(role=message['type'], context=message['data']['content'], order=order)
|
||||||
|
|
||||||
def message2Vo(message:BaseMessage)->MessageVo:
|
|
||||||
vo:MessageVo = MessageVo()
|
|
||||||
vo.role = message.type
|
|
||||||
vo.role = message.content
|
|
||||||
vo.time_stamp = message.additional_kwargs.time_stamp if message.additional_kwargs["time_stamp"] else 0
|
|
||||||
|
|
||||||
def non_stream_response(chat):
|
def non_stream_response(chat):
|
||||||
logger.info("not stream out, wait model response!")
|
logger.info("not stream out, wait model response!")
|
||||||
chat.nostream_call()
|
return chat.nostream_call()
|
||||||
messageVos = [message2Vo(element) for element in chat.current_message.messages]
|
|
||||||
return Result.succ(messageVos)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get('/v1/db/types', response_model=Result[str])
|
@router.get('/v1/db/types', response_model=Result[str])
|
||||||
|
@ -1,28 +1,33 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import TypeVar, Union, List, Generic
|
from typing import TypeVar, Union, List, Generic, Any
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
class Result(Generic[T], BaseModel):
|
class Result(Generic[T], BaseModel):
|
||||||
success: bool
|
success: bool
|
||||||
err_code: str
|
err_code: str = None
|
||||||
err_msg: str
|
err_msg: str = None
|
||||||
data: List[T]
|
data: T = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def succ(cls, data: List[T]):
|
def succ(cls, data: T):
|
||||||
return Result(True, None, None, data)
|
return Result(success=True, err_code=None, err_msg=None, data=data)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def faild(cls, msg):
|
def faild(cls, msg):
|
||||||
return Result(True, "E000X", msg, None)
|
return Result(success=False, err_code="E000X", err_msg=msg, data=None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def faild(cls, code, msg):
|
def faild(cls, code, msg):
|
||||||
return Result(True, code, msg, None)
|
return Result(success=False, err_code=code, err_msg=msg, data=None)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatSceneVo(BaseModel):
|
||||||
|
chat_scene: str = Field(..., description="chat_scene")
|
||||||
|
scene_name: str = Field(..., description="chat_scene name show for user")
|
||||||
|
param_title: str = Field(..., description="chat_scene required parameter title")
|
||||||
|
|
||||||
class ConversationVo(BaseModel):
|
class ConversationVo(BaseModel):
|
||||||
"""
|
"""
|
||||||
dialogue_uid
|
dialogue_uid
|
||||||
@ -31,15 +36,21 @@ class ConversationVo(BaseModel):
|
|||||||
"""
|
"""
|
||||||
user input
|
user input
|
||||||
"""
|
"""
|
||||||
user_input: str
|
user_input: str = ""
|
||||||
|
"""
|
||||||
|
user
|
||||||
|
"""
|
||||||
|
user_name: str = ""
|
||||||
"""
|
"""
|
||||||
the scene of chat
|
the scene of chat
|
||||||
"""
|
"""
|
||||||
chat_mode: str = Field(..., description="the scene of chat ")
|
chat_mode: str = Field(..., description="the scene of chat ")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
chat scene select param
|
chat scene select param
|
||||||
"""
|
"""
|
||||||
select_param: str
|
select_param: str = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MessageVo(BaseModel):
|
class MessageVo(BaseModel):
|
||||||
@ -51,7 +62,12 @@ class MessageVo(BaseModel):
|
|||||||
current message
|
current message
|
||||||
"""
|
"""
|
||||||
context: str
|
context: str
|
||||||
|
|
||||||
|
""" message postion order """
|
||||||
|
order: int
|
||||||
|
|
||||||
"""
|
"""
|
||||||
time the current message was sent
|
time the current message was sent
|
||||||
"""
|
"""
|
||||||
time_stamp: float
|
time_stamp: Any = None
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import signal
|
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
import argparse
|
import argparse
|
||||||
@ -12,12 +11,10 @@ import uuid
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(ROOT_PATH)
|
sys.path.append(ROOT_PATH)
|
||||||
|
|
||||||
from pilot.summary.db_summary_client import DBSummaryClient
|
from pilot.summary.db_summary_client import DBSummaryClient
|
||||||
from pilot.commands.command_mange import CommandRegistry
|
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
|
|
||||||
@ -25,8 +22,8 @@ from pilot.configs.config import Config
|
|||||||
from pilot.configs.model_config import (
|
from pilot.configs.model_config import (
|
||||||
DATASETS_DIR,
|
DATASETS_DIR,
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
LOGDIR,
|
|
||||||
LLM_MODEL_CONFIG,
|
LLM_MODEL_CONFIG,
|
||||||
|
LOGDIR,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.conversation import (
|
from pilot.conversation import (
|
||||||
@ -35,11 +32,10 @@ from pilot.conversation import (
|
|||||||
chat_mode_title,
|
chat_mode_title,
|
||||||
default_conversation,
|
default_conversation,
|
||||||
)
|
)
|
||||||
from pilot.common.plugins import scan_plugins, load_native_plugins
|
|
||||||
|
|
||||||
from pilot.server.gradio_css import code_highlight_css
|
from pilot.server.gradio_css import code_highlight_css
|
||||||
from pilot.server.gradio_patch import Chatbot as grChatbot
|
from pilot.server.gradio_patch import Chatbot as grChatbot
|
||||||
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.vector_store.extract_tovec import (
|
from pilot.vector_store.extract_tovec import (
|
||||||
get_vector_storelist,
|
get_vector_storelist,
|
||||||
@ -49,6 +45,20 @@ from pilot.vector_store.extract_tovec import (
|
|||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.scene.chat_factory import ChatFactory
|
from pilot.scene.chat_factory import ChatFactory
|
||||||
from pilot.language.translation_handler import get_lang_text
|
from pilot.language.translation_handler import get_lang_text
|
||||||
|
from pilot.server.webserver_base import server_init
|
||||||
|
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import BackgroundTasks, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from fastapi import FastAPI, applications
|
||||||
|
from fastapi.openapi.docs import get_swagger_ui_html
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
|
||||||
|
|
||||||
# 加载插件
|
# 加载插件
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -95,6 +105,30 @@ knowledge_qa_type_list = [
|
|||||||
add_knowledge_base_dialogue,
|
add_knowledge_base_dialogue,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def swagger_monkey_patch(*args, **kwargs):
|
||||||
|
return get_swagger_ui_html(
|
||||||
|
*args, **kwargs,
|
||||||
|
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
|
||||||
|
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
|
||||||
|
)
|
||||||
|
applications.get_swagger_ui_html = swagger_monkey_patch
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
origins = ["*"]
|
||||||
|
|
||||||
|
# 添加跨域中间件
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# app.mount("static", StaticFiles(directory="static"), name="static")
|
||||||
|
app.include_router(api_v1)
|
||||||
|
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||||
|
|
||||||
|
|
||||||
def get_simlar(q):
|
def get_simlar(q):
|
||||||
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
|
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
|
||||||
@ -216,7 +250,7 @@ def get_chat_mode(selected, param=None) -> ChatScene:
|
|||||||
else:
|
else:
|
||||||
mode = param
|
mode = param
|
||||||
if mode == conversation_types["default_knownledge"]:
|
if mode == conversation_types["default_knownledge"]:
|
||||||
return ChatScene.ChatKnowledge
|
return ChatScene.ChatDefaultKnowledge
|
||||||
elif mode == conversation_types["custome"]:
|
elif mode == conversation_types["custome"]:
|
||||||
return ChatScene.ChatNewKnowledge
|
return ChatScene.ChatNewKnowledge
|
||||||
elif mode == conversation_types["url"]:
|
elif mode == conversation_types["url"]:
|
||||||
@ -286,7 +320,7 @@ def http_bot(
|
|||||||
"chat_session_id": state.conv_id,
|
"chat_session_id": state.conv_id,
|
||||||
"user_input": state.last_user_input,
|
"user_input": state.last_user_input,
|
||||||
}
|
}
|
||||||
elif ChatScene.ChatKnowledge == scene:
|
elif ChatScene.ChatDefaultKnowledge == scene:
|
||||||
chat_param = {
|
chat_param = {
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_new_tokens": max_new_tokens,
|
"max_new_tokens": max_new_tokens,
|
||||||
@ -324,15 +358,14 @@ def http_bot(
|
|||||||
response = chat.stream_call()
|
response = chat.stream_call()
|
||||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||||
if chunk:
|
if chunk:
|
||||||
state.messages[-1][
|
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
||||||
-1
|
state.messages[-1][-1] =msg
|
||||||
] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
chat.current_message.add_ai_message(msg)
|
||||||
chunk, chat.skip_echo_len
|
|
||||||
)
|
|
||||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||||
|
chat.memory.append(chat.current_message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
state.messages[-1][-1] = "Error:" + str(e)
|
state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
|
||||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||||
|
|
||||||
|
|
||||||
@ -632,7 +665,7 @@ def knowledge_embedding_store(vs_id, files):
|
|||||||
)
|
)
|
||||||
knowledge_embedding_client = KnowledgeEmbedding(
|
knowledge_embedding_client = KnowledgeEmbedding(
|
||||||
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
|
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
|
||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||||
vector_store_config={
|
vector_store_config={
|
||||||
"vector_store_name": vector_store_name["vs_name"],
|
"vector_store_name": vector_store_name["vs_name"],
|
||||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
@ -657,42 +690,25 @@ def signal_handler(sig, frame):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
|
||||||
|
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
|
||||||
|
|
||||||
|
# old version server config
|
||||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||||
parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
|
parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
|
||||||
parser.add_argument("--concurrency-count", type=int, default=10)
|
parser.add_argument("--concurrency-count", type=int, default=10)
|
||||||
parser.add_argument(
|
|
||||||
"--model-list-mode", type=str, default="once", choices=["once", "reload"]
|
|
||||||
)
|
|
||||||
parser.add_argument("--share", default=False, action="store_true")
|
parser.add_argument("--share", default=False, action="store_true")
|
||||||
|
|
||||||
|
|
||||||
|
# init server config
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info(f"args: {args}")
|
server_init(args)
|
||||||
|
|
||||||
# init config
|
if args.new:
|
||||||
cfg = Config()
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=5000)
|
||||||
load_native_plugins(cfg)
|
else:
|
||||||
dbs = cfg.local_db.get_database_list()
|
### Compatibility mode starts the old version server by default
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
async_db_summery()
|
|
||||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
|
||||||
|
|
||||||
# Loader plugins and commands
|
|
||||||
command_categories = [
|
|
||||||
"pilot.commands.built_in.audio_text",
|
|
||||||
"pilot.commands.built_in.image_gen",
|
|
||||||
]
|
|
||||||
# exclude commands
|
|
||||||
command_categories = [
|
|
||||||
x for x in command_categories if x not in cfg.disabled_command_categories
|
|
||||||
]
|
|
||||||
command_registry = CommandRegistry()
|
|
||||||
for command_category in command_categories:
|
|
||||||
command_registry.import_commands(command_category)
|
|
||||||
|
|
||||||
cfg.command_registry = command_registry
|
|
||||||
|
|
||||||
logger.info(args)
|
|
||||||
demo = build_webdemo()
|
demo = build_webdemo()
|
||||||
demo.queue(
|
demo.queue(
|
||||||
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
||||||
@ -702,3 +718,8 @@ if __name__ == "__main__":
|
|||||||
share=args.share,
|
share=args.share,
|
||||||
max_threads=200,
|
max_threads=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user