WEB API independent
This commit is contained in:
parent 9d3000fb26
commit 3f7cc02426
@@ -1,4 +1,5 @@
"""Utilities for formatting strings."""
import json
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union

@@ -36,3 +37,13 @@ class StrictFormatter(Formatter):


formatter = StrictFormatter()


class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        elif hasattr(obj, '__dict__'):
            return obj.__dict__
        else:
            return json.JSONEncoder.default(self, obj)
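A minimal usage sketch for MyEncoder (illustrative only, not part of this commit): sets become lists, and objects with a __dict__ fall back to their attributes. Note that set ordering in the output is not guaranteed.

# Illustrative usage of MyEncoder:
import json

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

print(json.dumps({"tags": {"a", "b"}, "origin": Point(0, 0)}, cls=MyEncoder))
# -> {"tags": ["a", "b"], "origin": {"x": 0, "y": 0}}  (tag order may vary)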
@@ -6,6 +6,8 @@ import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
from test_cls_1 import TestBase, Test1
from test_cls_2 import Test2

CFG = Config()

@@ -56,20 +58,29 @@ CFG = Config()
#


# if __name__ == "__main__":

# def __extract_json(s):
#     i = s.index("{")
#     count = 1  # current nesting depth, i.e. the number of '{' not yet closed
#     for j, c in enumerate(s[i + 1 :], start=i + 1):
#         if c == "}":
#             count -= 1
#         elif c == "{":
#             count += 1
#         if count == 0:
#             break
#     assert count == 0  # check that the final '}' was found
#     return s[i : j + 1]
#
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
# print(__extract_json(ss))

if __name__ == "__main__":

    def __extract_json(s):
        i = s.index("{")
        count = 1  # current nesting depth, i.e. the number of '{' not yet closed
        for j, c in enumerate(s[i + 1 :], start=i + 1):
            if c == "}":
                count -= 1
            elif c == "{":
                count += 1
            if count == 0:
                break
        assert count == 0  # check that the final '}' was found
        return s[i : j + 1]

    ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
    print(__extract_json(ss))
    test1 = Test1()
    test2 = Test2()
    test1.write()
    test1.test()
    test2.write()
    test1.test()
    test2.test()
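For clarity, a tiny illustrative check of the brace-matching extraction above (not part of the commit): __extract_json returns the first balanced {...} block it finds.

# Illustrative: the function ignores surrounding text and handles nesting.
sample = 'noise before {"a": {"b": 1}} noise after'
assert __extract_json(sample) == '{"a": {"b": 1}}'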
@@ -4,7 +4,7 @@ from test_cls_base import TestBase


class Test1(TestBase):

    mode: str = "456"

    def write(self):
        self.test_values.append("x")
        self.test_values.append("y")
@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

class Test2(TestBase):
    test_2_values: List = []

    mode: str = "789"

    def write(self):
        self.test_values.append(1)
        self.test_values.append(2)
@@ -5,8 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

class TestBase(BaseModel, ABC):
    test_values: List = []

    mode: str = "123"

    def test(self):
        print(self.__class__.__name__ + ":")
        print(self.test_values)
        print(self.mode)
13  pilot/connections/rdbms/py_study/test_duckdb.py  Normal file
@@ -0,0 +1,13 @@
import json
import os
import duckdb

default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")

if __name__ == "__main__":
    if os.path.isfile(duckdb_path):
        cursor = duckdb.connect(duckdb_path).cursor()
        cursor.execute("SELECT * FROM chat_history limit 20")
        data = cursor.fetchall()
        print(data)
@@ -32,3 +32,9 @@ class BaseChatHistoryMemory(ABC):
    @abstractmethod
    def clear(self) -> None:
        """Clear session memory from the local file"""


    def conv_list(self, user_name: str = None) -> None:
        """get user's conversation list"""
        pass
109  pilot/memory/chat_history/duckdb_history.py  Normal file
@@ -0,0 +1,109 @@
import json
import os
import duckdb
from typing import List

from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.scene.message import (
    OnceConversation,
    conversation_from_dict,
    conversations_to_dict,
)
from pilot.common.formatting import MyEncoder

default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
table_name = 'chat_history'

CFG = Config()


class DuckdbHistoryMemory(BaseChatHistoryMemory):

    def __init__(self, chat_session_id: str):
        self.chat_seesion_id = chat_session_id
        os.makedirs(default_db_path, exist_ok=True)
        self.connect = duckdb.connect(duckdb_path)
        self.__init_chat_history_tables()

    def __init_chat_history_tables(self):
        # check whether the table already exists
        result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                                      [table_name]).fetchall()

        if not result:
            # create the table if it does not exist
            self.connect.execute(
                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")

    def __get_messages_by_conv_uid(self, conv_uid: str):
        cursor = self.connect.cursor()
        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
        return cursor.fetchone()

    def messages(self) -> List[OnceConversation]:
        context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
        if context:
            conversations: List[OnceConversation] = json.loads(context[0])
            return conversations
        return []

    def append(self, once_message: OnceConversation) -> None:
        context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
        conversations: List[OnceConversation] = []
        if context:
            conversations = json.loads(context[0])
        conversations.append(once_message)
        cursor = self.connect.cursor()
        if context:
            cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
                           [json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
        else:
            cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages) VALUES(?,?,?)",
                           [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
        cursor.commit()
        self.connect.commit()

    def clear(self) -> None:
        cursor = self.connect.cursor()
        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
        cursor.commit()
        self.connect.commit()

    def delete(self) -> bool:
        cursor = self.connect.cursor()
        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
        cursor.commit()
        return True

    @staticmethod
    def conv_list(user_name: str = None) -> None:
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            if user_name:
                cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
            else:
                cursor.execute("SELECT * FROM chat_history limit 20")
            # get the column names of the result set
            fields = [field[0] for field in cursor.description]
            data = []
            for row in cursor.fetchall():
                row_dict = {}
                for i, field in enumerate(fields):
                    row_dict[field] = row[i]
                data.append(row_dict)

            return data

        return []

    def get_messages(self) -> List[OnceConversation]:
        cursor = self.connect.cursor()
        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
        context = cursor.fetchone()
        if context:
            return json.loads(context[0])
        return None
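A minimal usage sketch for the new DuckDB-backed history (illustrative only, not part of the commit; assumes a writable ./message directory and that the remaining OnceConversation fields are initialized in the full file):

# Illustrative usage of DuckdbHistoryMemory.
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation

memory = DuckdbHistoryMemory("conv-demo-1")   # opens (or creates) chat_history.db
conv = OnceConversation("chat_normal")
conv.add_user_message("hello")
memory.append(conv)                           # first append inserts the row
print(memory.get_messages())                  # JSON-decoded conversation list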
@@ -11,13 +11,14 @@ from pilot.scene.message import (
    conversation_from_dict,
    conversations_to_dict,
)

from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList

CFG = Config()


class MemHistoryMemory(BaseChatHistoryMemory):
    histroies_map = {}
    histroies_map = FixedSizeDict(100)


    def __init__(self, chat_session_id: str):
        self.chat_seesion_id = chat_session_id
BIN  pilot/mock_datas/chat_history.db  Normal file
Binary file not shown.

0  pilot/mock_datas/chat_history.db.wal  Normal file
@ -1,12 +1,27 @@
|
||||
from enum import Enum
|
||||
|
||||
class Scene:
|
||||
def __init__(self, code, describe, is_inner):
|
||||
self.code = code
|
||||
self.describe = describe
|
||||
self.is_inner = is_inner
|
||||
|
||||
class ChatScene(Enum):
|
||||
ChatWithDbExecute = "chat_with_db_execute"
|
||||
ChatWithDbQA = "chat_with_db_qa"
|
||||
ChatExecution = "chat_execution"
|
||||
ChatKnowledge = "chat_default_knowledge"
|
||||
ChatDefaultKnowledge = "chat_default_knowledge"
|
||||
ChatNewKnowledge = "chat_new_knowledge"
|
||||
ChatUrlKnowledge = "chat_url_knowledge"
|
||||
InnerChatDBSummary = "inner_chat_db_summary"
|
||||
|
||||
ChatNormal = "chat_normal"
|
||||
ChatDashboard = "chat_dashboard"
|
||||
ChatKnowledge = "chat_knowledge"
|
||||
ChatDb = "chat_db"
|
||||
ChatData= "chat_data"
|
||||
|
||||
@staticmethod
|
||||
def is_valid_mode(mode):
|
||||
return any(mode == item.value for item in ChatScene)
|
||||
|
||||
|
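An illustrative check of the new scene values (not part of the commit; assumes the final enum members shown above, i.e. only one ChatKnowledge assignment survives in the committed file):

# Illustrative: validating a scene string against the enum.
from pilot.scene.base import ChatScene

assert ChatScene.is_valid_mode("chat_dashboard") is True
assert ChatScene.is_valid_mode("no_such_mode") is False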
@@ -24,6 +24,7 @@ from pilot.prompts.prompt_new import PromptTemplate
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory

from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import (

@@ -59,8 +60,6 @@ class BaseChat(ABC):

    def __init__(
        self,
        temperature,
        max_new_tokens,
        chat_mode,
        chat_session_id,
        current_user_input,

@@ -70,17 +69,15 @@ class BaseChat(ABC):
        self.current_user_input: str = current_user_input
        self.llm_model = CFG.LLM_MODEL
        ### can configurable storage methods
        self.memory = MemHistoryMemory(chat_session_id)
        self.memory = DuckdbHistoryMemory(chat_session_id)

        ### load prompt template
        self.prompt_template: PromptTemplate = CFG.prompt_templates[
            self.chat_mode.value
        ]
        self.history_message: List[OnceConversation] = []
        self.current_message: OnceConversation = OnceConversation()
        self.current_message: OnceConversation = OnceConversation(chat_mode.value)
        self.current_tokens_used: int = 0
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        ### load chat_session_id's chat histories
        self._load_history(self.chat_session_id)

@@ -99,15 +96,15 @@ class BaseChat(ABC):
        pass

    @abstractmethod
    def do_with_prompt_response(self, prompt_response):
        pass
    def do_action(self, prompt_response):
        return prompt_response

    def __call_base(self):
        input_values = self.generate_input_values()
        ### Chat sequence advance
        self.current_message.chat_order = len(self.history_message) + 1
        self.current_message.add_user_message(self.current_user_input)
        self.current_message.start_date = datetime.datetime.now()
        self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # TODO
        self.current_message.tokens = 0
        current_prompt = None

@@ -203,13 +200,10 @@ class BaseChat(ABC):
        # }"""

        self.current_message.add_ai_message(ai_response_text)
        prompt_define_response = (
            self.prompt_template.output_parser.parse_prompt_response(
                ai_response_text
            )
        )
        prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)

        result = self.do_with_prompt_response(prompt_define_response)

        result = self.do_action(prompt_define_response)

        if hasattr(prompt_define_response, "thoughts"):
            if isinstance(prompt_define_response.thoughts, dict):
0  pilot/scene/chat_dashboard/__init__.py  Normal file

81  pilot/scene/chat_dashboard/chat.py  Normal file
@@ -0,0 +1,81 @@
import json
from typing import Dict, NamedTuple, List
from pilot.scene.base_message import (
    HumanMessage,
    ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
    generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData

CFG = Config()


class ChatDashboard(BaseChat):
    chat_scene: str = ChatScene.ChatDashboard.value
    report_name: str
    """Number of results to return from the query"""

    def __init__(
        self, chat_session_id, db_name, user_input, report_name
    ):
        """ """
        super().__init__(
            chat_mode=ChatScene.ChatWithDbExecute,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
        )
        if not db_name:
            raise ValueError(
                f"{ChatScene.ChatWithDbExecute.value} mode should choose a db!"
            )
        self.report_name = report_name
        self.database = CFG.local_db
        # prepare DB info (obtain a connection to the selected database)
        self.db_connect = self.database.get_session(db_name)
        self.top_k: int = 5

    def generate_input_values(self):
        try:
            from pilot.summary.db_summary_client import DBSummaryClient
        except ImportError:
            raise ValueError("Could not import DBSummaryClient.")
        client = DBSummaryClient()
        input_values = {
            "input": self.current_user_input,
            "dialect": self.database.dialect,
            "table_info": self.database.table_simple_info(self.db_connect),
            "supported_chat_type": ""  # TODO
            # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
        }
        return input_values

    def do_action(self, prompt_response):
        ### TODO: record overall results; track successful and failed charts separately
        report_data: ReportData = ReportData()
        chart_datas: List[ChartData] = []
        for chart_item in prompt_response:
            try:
                datas = self.database.run(self.db_connect, chart_item.sql)
                chart_data: ChartData = ChartData()
            except Exception as e:
                # TODO: repair flow for failed SQL
                print(str(e))


            chart_datas.append(chart_data)

        report_data.conv_uid = self.chat_session_id
        report_data.template_name = self.report_name
        report_data.template_introduce = None
        report_data.charts = chart_datas

        return report_data
22  pilot/scene/chat_dashboard/data_preparation/report_schma.py  Normal file
@@ -0,0 +1,22 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any


class ChartData(BaseModel):
    chart_uid: str
    chart_type: str
    chart_sql: str
    column_name: List
    values: List
    style: Any


class ReportData(BaseModel):
    conv_uid: str
    template_name: str
    template_introduce: str
    charts: List[ChartData]
41  pilot/scene/chat_dashboard/out_parser.py  Normal file
@@ -0,0 +1,41 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR


class ChartItem(NamedTuple):
    sql: str
    title: str
    thoughts: str
    showcase: str


logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")


class ChatDashboardOutputParser(BaseOutputParser):
    def __init__(self, sep: str, is_stream_out: bool):
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, model_out_text):
        clean_str = super().parse_prompt_response(model_out_text)
        print("clean prompt response:", clean_str)
        response = json.loads(clean_str)
        chart_items: List[ChartItem] = []
        for item in response:
            chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"]))
        return chart_items

    def parse_view_response(self, speak, data) -> str:
        ### TODO

        return data

    @property
    def _type(self) -> str:
        return "chat_dashboard"
49  pilot/scene/chat_dashboard/prompt.py  Normal file
@@ -0,0 +1,49 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle

CFG = Config()

PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations"""
PROMPT_SCENE_DEFINE = None

_DEFAULT_TEMPLATE = """
According to the structure definition in the following tables:
{table_info}
Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions.
Used to support goal: {input}

Use the chart display method in the following range:
{supported_chat_type}
Give {dialect} data analysis SQL, analysis title, display method and analytical thinking, and respond in the following json format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""

RESPONSE_FORMAT = [{
    "sql": "data analysis SQL",
    "title": "Data Analysis Title",
    "showcase": "What type of charts to show",
    "thoughts": "Current thinking and value of data analysis"
}]

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ChatWithDbExecute.value,
    input_variables=["input", "table_info", "dialect", "supported_chat_type"],
    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=DbChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})
@@ -0,0 +1,9 @@
{
    "title": "Sales Report",
    "name": "sale_report",
    "introduce": "",
    "layout": "TODO",
    "supported_chart_type": ["HeatMap", "sheet", "LineChart", "PieChart", "BarChart"],
    "key_metrics": [],
    "trends": []
}
@@ -22,12 +22,10 @@ class ChatWithDbAutoExecute(BaseChat):
    """Number of results to return from the query"""

    def __init__(
        self, temperature, max_new_tokens, chat_session_id, db_name, user_input
        self, chat_session_id, db_name, user_input
    ):
        """ """
        super().__init__(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=ChatScene.ChatWithDbExecute,
            chat_session_id=chat_session_id,
            current_user_input=user_input,

@@ -57,5 +55,5 @@ class ChatWithDbAutoExecute(BaseChat):
        }
        return input_values

    def do_with_prompt_response(self, prompt_response):
    def do_action(self, prompt_response):
        return self.database.run(self.db_connect, prompt_response.sql)
@@ -20,12 +20,10 @@ class ChatWithDbQA(BaseChat):
    """Number of results to return from the query"""

    def __init__(
        self, temperature, max_new_tokens, chat_session_id, db_name, user_input
        self, chat_session_id, db_name, user_input
    ):
        """ """
        super().__init__(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=ChatScene.ChatWithDbQA,
            chat_session_id=chat_session_id,
            current_user_input=user_input,

@@ -66,5 +64,4 @@ class ChatWithDbQA(BaseChat):
        }
        return input_values

    def do_with_prompt_response(self, prompt_response):
        return prompt_response
@@ -22,15 +22,11 @@ class ChatWithPlugin(BaseChat):

    def __init__(
        self,
        temperature,
        max_new_tokens,
        chat_session_id,
        user_input,
        plugin_selector: str = None,
    ):
        super().__init__(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=ChatScene.ChatExecution,
            chat_session_id=chat_session_id,
            current_user_input=user_input,

@@ -66,7 +62,7 @@ class ChatWithPlugin(BaseChat):
        }
        return input_values

    def do_with_prompt_response(self, prompt_response):
    def do_action(self, prompt_response):
        ## plugin command run
        return execute_command(
            str(prompt_response.command.get("name")),
@@ -30,12 +30,10 @@ class ChatNewKnowledge(BaseChat):
    """Number of results to return from the query"""

    def __init__(
        self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name
        self, chat_session_id, user_input, knowledge_name
    ):
        """ """
        super().__init__(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=ChatScene.ChatNewKnowledge,
            chat_session_id=chat_session_id,
            current_user_input=user_input,

@@ -67,8 +65,6 @@ class ChatNewKnowledge(BaseChat):

        return input_values

    def do_with_prompt_response(self, prompt_response):
        return prompt_response

    @property
    def chat_type(self) -> str:
@@ -25,16 +25,14 @@ CFG = Config()


class ChatDefaultKnowledge(BaseChat):
    chat_scene: str = ChatScene.ChatKnowledge.value
    chat_scene: str = ChatScene.ChatDefaultKnowledge.value

    """Number of results to return from the query"""

    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
    def __init__(self, chat_session_id, user_input):
        """ """
        super().__init__(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=ChatScene.ChatKnowledge,
            chat_mode=ChatScene.ChatDefaultKnowledge,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
        )

@@ -61,9 +59,8 @@ class ChatDefaultKnowledge(BaseChat):
        )
        return input_values

    def do_with_prompt_response(self, prompt_response):
        return prompt_response


    @property
    def chat_type(self) -> str:
        return ChatScene.ChatKnowledge.value
        return ChatScene.ChatDefaultKnowledge.value
@@ -39,7 +39,7 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True

prompt = PromptTemplate(
    template_scene=ChatScene.ChatKnowledge.value,
    template_scene=ChatScene.ChatDefaultKnowledge.value,
    input_variables=["context", "question"],
    response_format=None,
    template_define=PROMPT_SCENE_DEFINE,
@@ -14,8 +14,6 @@ class InnerChatDBSummary(BaseChat):

    def __init__(
        self,
        temperature,
        max_new_tokens,
        chat_session_id,
        user_input,
        db_select,

@@ -23,8 +21,6 @@ class InnerChatDBSummary(BaseChat):
    ):
        """ """
        super().__init__(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=ChatScene.InnerChatDBSummary,
            chat_session_id=chat_session_id,
            current_user_input=user_input,

@@ -40,8 +36,6 @@ class InnerChatDBSummary(BaseChat):
        }
        return input_values

    def do_with_prompt_response(self, prompt_response):
        return prompt_response

    @property
    def chat_type(self) -> str:
@@ -27,11 +27,9 @@ class ChatUrlKnowledge(BaseChat):

    """Number of results to return from the query"""

    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url):
    def __init__(self, chat_session_id, user_input, url):
        """ """
        super().__init__(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=ChatScene.ChatUrlKnowledge,
            chat_session_id=chat_session_id,
            current_user_input=user_input,

@@ -62,8 +60,6 @@ class ChatUrlKnowledge(BaseChat):
        input_values = {"context": context, "question": self.current_user_input}
        return input_values

    def do_with_prompt_response(self, prompt_response):
        return prompt_response

    @property
    def chat_type(self) -> str:
@@ -18,11 +18,9 @@ class ChatNormal(BaseChat):

    """Number of results to return from the query"""

    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
    def __init__(self, chat_session_id, user_input):
        """ """
        super().__init__(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=ChatScene.ChatNormal,
            chat_session_id=chat_session_id,
            current_user_input=user_input,

@@ -32,7 +30,7 @@ class ChatNormal(BaseChat):
        input_values = {"input": self.current_user_input}
        return input_values

    def do_with_prompt_response(self, prompt_response):
    def do_action(self, prompt_response):
        return prompt_response

    @property
@@ -25,7 +25,8 @@ class OnceConversation:
    All the information of a conversation; currently held in memory by a single service, and can be extended with cache and database support for distributed services
    """

    def __init__(self):
    def __init__(self, chat_mode):
        self.chat_mode: str = chat_mode
        self.messages: List[BaseMessage] = []
        self.start_date: str = ""
        self.chat_order: int = 0

@@ -43,12 +44,28 @@ class OnceConversation:

    def add_ai_message(self, message: str) -> None:
        """Add an AI message to the store"""

        has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
        if has_message:
            raise ValueError("Already Have Ai message")
            self.messages.append(AIMessage(content=message))
            self.__update_ai_message(message)
        else:
            self.messages.append(AIMessage(content=message))
        """ """

    def __update_ai_message(self, new_message: str) -> None:
        """
        stream out message update
        Args:
            new_message:

        Returns:

        """

        for item in self.messages:
            if item.type == "ai":
                item.content = new_message

    def add_view_message(self, message: str) -> None:
        """Add an AI message to the store"""

@@ -69,6 +86,13 @@ class OnceConversation:
        self.session_id = None


    def get_user_message(self):
        for once in self.messages:
            if isinstance(once, HumanMessage):
                return once.content
        return ""


def _conversation_to_dic(once: OnceConversation) -> dict:
    start_str: str = ""
    if once.start_date:

@@ -78,6 +102,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
        start_str = once.start_date

    return {
        "chat_mode": once.chat_mode,
        "chat_order": once.chat_order,
        "start_date": start_str,
        "cost": once.cost if once.cost else 0,

@@ -93,6 +118,7 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
def conversation_from_dict(once: dict) -> OnceConversation:
    conversation = OnceConversation()
    conversation.cost = once.get("cost", 0)
    conversation.chat_mode = once.get("chat_mode", "chat_normal")
    conversation.tokens = once.get("tokens", 0)
    conversation.start_date = once.get("start_date", "")
    conversation.chat_order = int(once.get("chat_order"))
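A small illustrative round-trip for the new chat_mode field (not part of the commit; assumes the remaining OnceConversation attributes, such as cost and tokens, are initialized elsewhere in the full file):

# Illustrative: chat_mode now travels through the dict serialization.
conv = OnceConversation("chat_normal")
conv.add_user_message("hello")
assert _conversation_to_dic(conv)["chat_mode"] == "chat_normal"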
@@ -1,15 +1,18 @@
import uuid

from fastapi import APIRouter, Request, Body, status
import json
import asyncio
import time
from fastapi import APIRouter, Request, Body, status, HTTPException, Response

from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List

from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene

@@ -17,6 +20,8 @@ from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR)
from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage)
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation

router = APIRouter()
CFG = Config()

@@ -28,32 +33,117 @@ async def validation_exception_handler(request: Request, exc: RequestValidationError):
    message = ""
    for error in exc.errors():
        message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
    return Result.faild(message)
    return Result.faild(msg=message)


@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]])
async def dialogue_list(user_id: str):
    #### TODO

    conversations = [ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
                     ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]")]

    return Result[ConversationVo].succ(conversations)
def __get_conv_user_message(conversations: dict):
    messages = conversations['messages']
    for item in messages:
        if item['type'] == "human":
            return item['data']['content']
    return ""


@router.post('/v1/chat/dialogue/new', response_model=Result[str])
async def dialogue_new(user_id: str):
@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
    # set CORS headers
    response.headers['Access-Control-Allow-Origin'] = '*'
    response.headers['Access-Control-Allow-Methods'] = 'GET'
    response.headers['Access-Control-Request-Headers'] = 'content-type'

    dialogues: List = []
    datas = DuckdbHistoryMemory.conv_list(user_id)

    for item in datas:
        conv_uid = item.get("conv_uid")
        messages = item.get("messages")
        conversations = json.loads(messages)

        first_conv: OnceConversation = conversations[0]
        conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv),
                                                 chat_mode=first_conv['chat_mode'])
        dialogues.append(conv_vo)

    return Result[ConversationVo].succ(dialogues)


@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
    scene_vos: List[ChatSceneVo] = []
    new_modes: List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
    for scene in new_modes:
        if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
            scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
            scene_vos.append(scene_vo)
    return Result.succ(scene_vos)


@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo])
async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None):
    unique_id = uuid.uuid1()
    return Result.succ(unique_id)
    return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))


def get_db_list():
    db = CFG.local_db
    dbs = db.get_database_list()
    params: dict = {}
    for name in dbs:
        params.update({name: name})
    return params


def plugins_select_info():
    plugins_infos: dict = {}
    for plugin in CFG.plugins:
        plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
    return plugins_infos


def knowledge_list():
    knowledge: dict = {}
    ### TODO
    return knowledge


@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
    if ChatScene.ChatDb.value == chat_mode:
        return Result.succ(get_db_list())
    elif ChatScene.ChatData.value == chat_mode:
        return Result.succ(get_db_list())
    elif ChatScene.ChatDashboard.value == chat_mode:
        return Result.succ(get_db_list())
    elif ChatScene.ChatExecution.value == chat_mode:
        return Result.succ(plugins_select_info())
    elif ChatScene.ChatKnowledge.value == chat_mode:
        return Result.succ(knowledge_list())
    else:
        return Result.succ(None)


@router.post('/v1/chat/dialogue/delete')
async def dialogue_delete(con_uid: str, user_id: str):
    #### TODO
async def dialogue_delete(con_uid: str):
    history_mem = DuckdbHistoryMemory(con_uid)
    history_mem.delete()
    return Result.succ(None)


@router.post('/v1/chat/completions', response_model=Result[MessageVo])
@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str):
    print(f"dialogue_history_messages:{con_uid}")
    message_vos: List[MessageVo] = []

    history_mem = DuckdbHistoryMemory(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        for once in history_messages:
            once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']]
            message_vos.extend(once_message_vos)
    return Result.succ(message_vos)


@router.post('/v1/chat/completions')
async def chat_completions(dialogue: ConversationVo = Body()):
    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")

@@ -65,22 +155,31 @@ async def chat_completions(dialogue: ConversationVo = Body()):
        "user_input": dialogue.user_input,
    }

    if ChatScene.ChatWithDbExecute == dialogue.chat_mode:
    if ChatScene.ChatDb == dialogue.chat_mode:
        chat_param.update("db_name", dialogue.select_param)
    elif ChatScene.ChatWithDbQA == dialogue.chat_mode:
    elif ChatScene.ChatData == dialogue.chat_mode:
        chat_param.update("db_name", dialogue.select_param)
    elif ChatScene.ChatDashboard == dialogue.chat_mode:
        chat_param.update("db_name", dialogue.select_param)
    elif ChatScene.ChatExecution == dialogue.chat_mode:
        chat_param.update("plugin_selector", dialogue.select_param)
    elif ChatScene.ChatNewKnowledge == dialogue.chat_mode:
    elif ChatScene.ChatKnowledge == dialogue.chat_mode:
        chat_param.update("knowledge_name", dialogue.select_param)
    elif ChatScene.ChatUrlKnowledge == dialogue.chat_mode:
        chat_param.update("url", dialogue.select_param)

    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
    if not chat.prompt_template.stream_out:
        return non_stream_response(chat)
    else:
        return stream_response(chat)
    # generator = stream_generator(chat)
    # result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
    # return result
    return StreamingResponse(stream_generator(chat), media_type="text/plain")


def stream_test():
    for message in ["Hello", "world", "how", "are", "you"]:
        yield message
    # yield json.dumps(Result.succ(message).__dict__).encode("utf-8")


def stream_generator(chat):

@@ -89,24 +188,28 @@ def stream_generator(chat):
        if chunk:
            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
            chat.current_message.add_ai_message(msg)
            messageVos = [message2Vo(element) for element in chat.current_message.messages]
            yield Result.succ(messageVos)
def stream_response(chat):
    logger.info("stream out start!")
    api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
    return api_response
            yield msg
            # chat.current_message.add_ai_message(msg)
            # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
            # json_text = json.dumps(vo.__dict__)
            # yield json_text.encode('utf-8')
    chat.memory.append(chat.current_message)


# def stream_response(chat):
#     logger.info("stream out start!")
#     api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
#     return api_response


def message2Vo(message: dict, order) -> MessageVo:
    # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
    return MessageVo(role=message['type'], context=message['data']['content'], order=order)

def message2Vo(message: BaseMessage) -> MessageVo:
    vo: MessageVo = MessageVo()
    vo.role = message.type
    vo.role = message.content
    vo.time_stamp = message.additional_kwargs.time_stamp if message.additional_kwargs["time_stamp"] else 0


def non_stream_response(chat):
    logger.info("not stream out, wait model response!")
    chat.nostream_call()
    messageVos = [message2Vo(element) for element in chat.current_message.messages]
    return Result.succ(messageVos)
    return chat.nostream_call()


@router.get('/v1/db/types', response_model=Result[str])
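A minimal smoke-test sketch for the new dialogue endpoints (illustrative, not part of the commit; assumes the FastAPI app object from pilot/server/webserver.py is importable as `app`):

# Illustrative smoke test for the new /v1 routes.
from fastapi.testclient import TestClient
from pilot.server.webserver import app

client = TestClient(app)
new_conv = client.post("/v1/chat/dialogue/new", params={"chat_mode": "chat_normal"})
print(new_conv.json())  # Result envelope carrying conv_uid and chat_mode
scenes = client.post("/v1/chat/dialogue/scenes")
print(scenes.json())    # Result envelope with the selectable chat scenes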
@@ -1,28 +1,33 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic
from typing import TypeVar, Union, List, Generic, Any

T = TypeVar('T')


class Result(Generic[T], BaseModel):
    success: bool
    err_code: str
    err_msg: str
    data: List[T]
    err_code: str = None
    err_msg: str = None
    data: T = None

    @classmethod
    def succ(cls, data: List[T]):
        return Result(True, None, None, data)
    def succ(cls, data: T):
        return Result(success=True, err_code=None, err_msg=None, data=data)

    @classmethod
    def faild(cls, msg):
        return Result(True, "E000X", msg, None)
        return Result(success=False, err_code="E000X", err_msg=msg, data=None)

    @classmethod
    def faild(cls, code, msg):
        return Result(True, code, msg, None)
        return Result(success=False, err_code=code, err_msg=msg, data=None)


class ChatSceneVo(BaseModel):
    chat_scene: str = Field(..., description="chat_scene")
    scene_name: str = Field(..., description="chat_scene name shown to the user")
    param_title: str = Field(..., description="chat_scene required parameter title")

class ConversationVo(BaseModel):
    """
    dialogue_uid

@@ -31,15 +36,21 @@ class ConversationVo(BaseModel):
    """
    user input
    """
    user_input: str
    user_input: str = ""
    """
    user
    """
    user_name: str = ""
    """
    the scene of chat
    """
    chat_mode: str = Field(..., description="the scene of chat")

    """
    chat scene select param
    """
    select_param: str
    select_param: str = None


class MessageVo(BaseModel):

@@ -51,7 +62,12 @@ class MessageVo(BaseModel):
    current message
    """
    context: str

    """message position order"""
    order: int

    """
    time the current message was sent
    """
    time_stamp: float
    time_stamp: Any = None
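Illustrative behavior of the reworked Result envelope (not part of the commit):

# Illustrative: Result now carries a single typed payload plus optional error fields.
ok = Result[str].succ("hello")           # success=True, data="hello"
err = Result.faild("something broke")    # success=False, err_code="E000X"
print(ok.dict())
print(err.dict())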
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import signal
import threading
import traceback
import argparse

@@ -12,12 +11,10 @@ import uuid

import gradio as gr


ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry

from pilot.scene.base_chat import BaseChat

@@ -25,8 +22,8 @@ from pilot.configs.config import Config
from pilot.configs.model_config import (
    DATASETS_DIR,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LOGDIR,
    LLM_MODEL_CONFIG,
    LOGDIR,
)

from pilot.conversation import (

@@ -35,11 +32,10 @@ from pilot.conversation import (
    chat_mode_title,
    default_conversation,
)
from pilot.common.plugins import scan_plugins, load_native_plugins

from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger
from pilot.vector_store.extract_tovec import (
    get_vector_storelist,

@@ -49,6 +45,20 @@ from pilot.vector_store.extract_tovec import (
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
from pilot.server.webserver_base import server_init


import uvicorn
from fastapi import BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler

# load plugins
CFG = Config()

@@ -95,6 +105,30 @@ knowledge_qa_type_list = [
    add_knowledge_base_dialogue,
]

def swagger_monkey_patch(*args, **kwargs):
    return get_swagger_ui_html(
        *args, **kwargs,
        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
    )

applications.get_swagger_ui_html = swagger_monkey_patch

app = FastAPI()
origins = ["*"]

# add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)

# app.mount("static", StaticFiles(directory="static"), name="static")
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)


def get_simlar(q):
    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))

@@ -216,7 +250,7 @@ def get_chat_mode(selected, param=None) -> ChatScene:
    else:
        mode = param
        if mode == conversation_types["default_knownledge"]:
            return ChatScene.ChatKnowledge
            return ChatScene.ChatDefaultKnowledge
        elif mode == conversation_types["custome"]:
            return ChatScene.ChatNewKnowledge
        elif mode == conversation_types["url"]:

@@ -286,7 +320,7 @@ def http_bot(
        "chat_session_id": state.conv_id,
        "user_input": state.last_user_input,
    }
    elif ChatScene.ChatKnowledge == scene:
    elif ChatScene.ChatDefaultKnowledge == scene:
        chat_param = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,

@@ -324,15 +358,14 @@ def http_bot(
        response = chat.stream_call()
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                state.messages[-1][
                    -1
                ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                    chunk, chat.skip_echo_len
                )
                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
                state.messages[-1][-1] = msg
                chat.current_message.add_ai_message(msg)
                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
        chat.memory.append(chat.current_message)
    except Exception as e:
        print(traceback.format_exc())
        state.messages[-1][-1] = "Error:" + str(e)
        state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

@@ -632,7 +665,7 @@ def knowledge_embedding_store(vs_id, files):
    )
    knowledge_embedding_client = KnowledgeEmbedding(
        file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
        model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
        model_name=LLM_MODEL_CONFIG["text2vec"],
        vector_store_config={
            "vector_store_name": vector_store_name["vs_name"],
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,

@@ -657,48 +690,36 @@ def signal_handler(sig, frame):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
    parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')

    # old version server config
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--share", default=False, action="store_true")


    # init server config
    args = parser.parse_args()
    logger.info(f"args: {args}")
    server_init(args)

    if args.new:
        import uvicorn
        uvicorn.run(app, host="0.0.0.0", port=5000)
    else:
        ### Compatibility mode starts the old version server by default
        demo = build_webdemo()
        demo.queue(
            concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
        ).launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share,
            max_threads=200,
        )

    # init config
    cfg = Config()

    load_native_plugins(cfg)
    dbs = cfg.local_db.get_database_list()
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()
    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

    # Load plugins and commands
    command_categories = [
        "pilot.commands.built_in.audio_text",
        "pilot.commands.built_in.image_gen",
    ]
    # exclude commands
    command_categories = [
        x for x in command_categories if x not in cfg.disabled_command_categories
    ]
    command_registry = CommandRegistry()
    for command_category in command_categories:
        command_registry.import_commands(command_category)

    cfg.command_registry = command_registry

    logger.info(args)
    demo = build_webdemo()
    demo.queue(
        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
    )
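For reference, the two launch modes after this change (illustrative invocations based on the argparse flags above):

    python pilot/server/webserver.py --new    # standalone FastAPI server on port 5000
    python pilot/server/webserver.py          # legacy Gradio web UI (the default)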