WEB API independent

This commit is contained in:
tuyang.yhj 2023-06-27 15:35:18 +08:00
parent 9d3000fb26
commit 3f7cc02426
33 changed files with 683 additions and 182 deletions

View File

@ -1,4 +1,5 @@
"""Utilities for formatting strings.""" """Utilities for formatting strings."""
import json
from string import Formatter from string import Formatter
from typing import Any, List, Mapping, Sequence, Union from typing import Any, List, Mapping, Sequence, Union
@ -36,3 +37,13 @@ class StrictFormatter(Formatter):
formatter = StrictFormatter() formatter = StrictFormatter()
class MyEncoder(json.JSONEncoder):
    """JSON encoder that additionally handles sets and plain objects.

    - ``set`` values are serialized as lists (JSON has no set type).
    - Objects exposing ``__dict__`` are serialized from their attributes.
    - Anything else falls through to the base encoder, which raises
      ``TypeError`` for unsupported types.
    """

    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        # Let the base class raise TypeError for unserializable objects.
        return super().default(obj)

View File

@ -6,6 +6,8 @@ import numpy as np
from matplotlib.font_manager import FontProperties from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar from pyecharts.charts import Bar
from pyecharts import options as opts from pyecharts import options as opts
from test_cls_1 import TestBase,Test1
from test_cls_2 import Test2
CFG = Config() CFG = Config()
@ -56,20 +58,29 @@ CFG = Config()
# #
# if __name__ == "__main__":
# def __extract_json(s):
# i = s.index("{")
# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
# for j, c in enumerate(s[i + 1 :], start=i + 1):
# if c == "}":
# count -= 1
# elif c == "{":
# count += 1
# if count == 0:
# break
# assert count == 0 # 检查是否找到最后一个'}'
# return s[i : j + 1]
#
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
# print(__extract_json(ss))
if __name__ == "__main__": if __name__ == "__main__":
test1 = Test1()
def __extract_json(s): test2 = Test2()
i = s.index("{") test1.write()
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 test1.test()
for j, c in enumerate(s[i + 1 :], start=i + 1): test2.write()
if c == "}": test1.test()
count -= 1 test2.test()
elif c == "{":
count += 1
if count == 0:
break
assert count == 0 # 检查是否找到最后一个'}'
return s[i : j + 1]
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
print(__extract_json(ss))

View File

@ -4,7 +4,7 @@ from test_cls_base import TestBase
class Test1(TestBase): class Test1(TestBase):
mode:str = "456"
def write(self): def write(self):
self.test_values.append("x") self.test_values.append("x")
self.test_values.append("y") self.test_values.append("y")

View File

@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class Test2(TestBase): class Test2(TestBase):
test_2_values:List = [] test_2_values:List = []
mode:str = "789"
def write(self): def write(self):
self.test_values.append(1) self.test_values.append(1)
self.test_values.append(2) self.test_values.append(2)

View File

@ -5,8 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class TestBase(BaseModel, ABC): class TestBase(BaseModel, ABC):
test_values: List = [] test_values: List = []
mode:str = "123"
def test(self): def test(self):
print(self.__class__.__name__ + ":" ) print(self.__class__.__name__ + ":" )
print(self.test_values) print(self.test_values)
print(self.mode)

View File

@ -0,0 +1,13 @@
import json
import os
import duckdb

# Default on-disk location for the chat-history DuckDB file; overridable via
# the DB_DUCKDB_PATH environment variable.
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")

# Ad-hoc inspection script: print the first 20 chat_history rows if the
# database file already exists (no table creation is attempted here).
if __name__ == "__main__":
    if os.path.isfile(duckdb_path):
        cursor = duckdb.connect(duckdb_path).cursor()
        cursor.execute("SELECT * FROM chat_history limit 20")
        data = cursor.fetchall()
        print(data)

View File

@ -32,3 +32,9 @@ class BaseChatHistoryMemory(ABC):
@abstractmethod @abstractmethod
def clear(self) -> None: def clear(self) -> None:
"""Clear session memory from the local file""" """Clear session memory from the local file"""
def conv_list(self, user_name:str=None) -> None:
    """Return the user's conversation list.

    Base implementation is a placeholder that returns ``None``; concrete
    history backends (e.g. a DuckDB-backed memory) are expected to
    override it and return a list of conversation rows.
    """
    pass

View File

@ -0,0 +1,109 @@
import json
import os
import duckdb
from typing import List
from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.scene.message import (
OnceConversation,
conversation_from_dict,
conversations_to_dict,
)
from pilot.common.formatting import MyEncoder
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
table_name = 'chat_history'
CFG = Config()
class DuckdbHistoryMemory(BaseChatHistoryMemory):
    """Chat-history storage backed by a local DuckDB database file.

    Each conversation is one row of ``chat_history`` keyed by ``conv_uid``;
    all of its messages are serialized into a single JSON ``messages`` column.
    """

    def __init__(self, chat_session_id: str):
        # (sic) attribute name kept for compatibility with existing callers.
        self.chat_seesion_id = chat_session_id
        os.makedirs(default_db_path, exist_ok=True)
        self.connect = duckdb.connect(duckdb_path)
        self.__init_chat_history_tables()

    def __init_chat_history_tables(self):
        # Create the history table on first use.  The existence check relies
        # on DuckDB's sqlite_master compatibility view — TODO confirm on the
        # DuckDB version in use.
        result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                                      [table_name]).fetchall()
        if not result:
            self.connect.execute(
                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")

    def __get_messages_by_conv_uid(self, conv_uid: str):
        """Return the ``(messages,)`` row for a conversation, or None if absent."""
        cursor = self.connect.cursor()
        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
        return cursor.fetchone()

    def messages(self) -> List[OnceConversation]:
        """Load and deserialize this session's conversation list ([] if new)."""
        context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
        if context:
            conversations: List[OnceConversation] = json.loads(context[0])
            return conversations
        return []

    def append(self, once_message: OnceConversation) -> None:
        """Append one conversation round and persist the whole message list."""
        context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
        conversations: List[OnceConversation] = []
        if context:
            # BUG FIX: was ``json.load(context)`` — json.load expects a file
            # object; the serialized text is the first column of the row.
            conversations = json.loads(context[0])
        # NOTE(review): rows loaded above are plain dicts while once_message is
        # an OnceConversation — confirm conversations_to_dict handles the mix.
        conversations.append(once_message)
        cursor = self.connect.cursor()
        if context:
            cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
                           [json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
        else:
            cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
                           [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
        cursor.commit()
        self.connect.commit()

    def clear(self) -> None:
        """Delete this session's history row."""
        cursor = self.connect.cursor()
        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
        cursor.commit()
        self.connect.commit()

    def delete(self) -> bool:
        """Delete this session's history row; always returns True."""
        cursor = self.connect.cursor()
        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
        cursor.commit()
        return True

    @staticmethod
    def conv_list(user_name: str = None) -> List[dict]:
        """List up to 20 conversations as dicts, optionally filtered by user.

        BUG FIX: the original signature was ``conv_list(cls, user_name=None)``
        under @staticmethod, so the first positional argument was bound to
        ``cls`` and the user-name filter was silently never applied.
        """
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            if user_name:
                cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
            else:
                cursor.execute("SELECT * FROM chat_history limit 20")
            # Map each row into a {column_name: value} dict.
            fields = [field[0] for field in cursor.description]
            data = []
            for row in cursor.fetchall():
                row_dict = {}
                for i, field in enumerate(fields):
                    row_dict[field] = row[i]
                data.append(row_dict)
            return data
        return []

    def get_messages(self) -> List[OnceConversation]:
        """Return the raw deserialized message list, or None when absent."""
        cursor = self.connect.cursor()
        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
        context = cursor.fetchone()
        if context:
            return json.loads(context[0])
        return None

View File

@ -11,13 +11,14 @@ from pilot.scene.message import (
conversation_from_dict, conversation_from_dict,
conversations_to_dict, conversations_to_dict,
) )
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
CFG = Config() CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory): class MemHistoryMemory(BaseChatHistoryMemory):
histroies_map = {} histroies_map = FixedSizeDict(100)
def __init__(self, chat_session_id: str): def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id self.chat_seesion_id = chat_session_id

Binary file not shown.

View File

View File

@ -1,12 +1,27 @@
from enum import Enum from enum import Enum
class Scene:
    """Descriptor for a chat scene: machine code, human-readable description,
    and whether the scene is internal-only (not directly user selectable)."""

    def __init__(self, code, describe, is_inner):
        self.code = code
        self.describe = describe
        self.is_inner = is_inner

    def __repr__(self):
        # Debug-friendly representation (original class had none).
        return f"Scene(code={self.code!r}, describe={self.describe!r}, is_inner={self.is_inner!r})"
class ChatScene(Enum): class ChatScene(Enum):
ChatWithDbExecute = "chat_with_db_execute" ChatWithDbExecute = "chat_with_db_execute"
ChatWithDbQA = "chat_with_db_qa" ChatWithDbQA = "chat_with_db_qa"
ChatExecution = "chat_execution" ChatExecution = "chat_execution"
ChatKnowledge = "chat_default_knowledge" ChatDefaultKnowledge = "chat_default_knowledge"
ChatNewKnowledge = "chat_new_knowledge" ChatNewKnowledge = "chat_new_knowledge"
ChatUrlKnowledge = "chat_url_knowledge" ChatUrlKnowledge = "chat_url_knowledge"
InnerChatDBSummary = "inner_chat_db_summary" InnerChatDBSummary = "inner_chat_db_summary"
ChatNormal = "chat_normal" ChatNormal = "chat_normal"
ChatDashboard = "chat_dashboard"
ChatKnowledge = "chat_knowledge"
ChatDb = "chat_db"
ChatData= "chat_data"
@staticmethod
def is_valid_mode(mode):
return any(mode == item.value for item in ChatScene)

View File

@ -24,6 +24,7 @@ from pilot.prompts.prompt_new import PromptTemplate
from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.memory.chat_history.mem_history import MemHistoryMemory from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.configs.model_config import LOGDIR, DATASETS_DIR from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import ( from pilot.utils import (
@ -59,8 +60,6 @@ class BaseChat(ABC):
def __init__( def __init__(
self, self,
temperature,
max_new_tokens,
chat_mode, chat_mode,
chat_session_id, chat_session_id,
current_user_input, current_user_input,
@ -70,17 +69,15 @@ class BaseChat(ABC):
self.current_user_input: str = current_user_input self.current_user_input: str = current_user_input
self.llm_model = CFG.LLM_MODEL self.llm_model = CFG.LLM_MODEL
### can configurable storage methods ### can configurable storage methods
self.memory = MemHistoryMemory(chat_session_id) self.memory = DuckdbHistoryMemory(chat_session_id)
### load prompt template ### load prompt template
self.prompt_template: PromptTemplate = CFG.prompt_templates[ self.prompt_template: PromptTemplate = CFG.prompt_templates[
self.chat_mode.value self.chat_mode.value
] ]
self.history_message: List[OnceConversation] = [] self.history_message: List[OnceConversation] = []
self.current_message: OnceConversation = OnceConversation() self.current_message: OnceConversation = OnceConversation(chat_mode.value)
self.current_tokens_used: int = 0 self.current_tokens_used: int = 0
self.temperature = temperature
self.max_new_tokens = max_new_tokens
### load chat_session_id's chat historys ### load chat_session_id's chat historys
self._load_history(self.chat_session_id) self._load_history(self.chat_session_id)
@ -99,15 +96,15 @@ class BaseChat(ABC):
pass pass
@abstractmethod @abstractmethod
def do_with_prompt_response(self, prompt_response): def do_action(self, prompt_response):
pass return prompt_response
def __call_base(self): def __call_base(self):
input_values = self.generate_input_values() input_values = self.generate_input_values()
### Chat sequence advance ### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1 self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input) self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now() self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# TODO # TODO
self.current_message.tokens = 0 self.current_message.tokens = 0
current_prompt = None current_prompt = None
@ -203,13 +200,10 @@ class BaseChat(ABC):
# }""" # }"""
self.current_message.add_ai_message(ai_response_text) self.current_message.add_ai_message(ai_response_text)
prompt_define_response = ( prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.do_with_prompt_response(prompt_define_response)
result = self.do_action(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"): if hasattr(prompt_define_response, "thoughts"):
if isinstance(prompt_define_response.thoughts, dict): if isinstance(prompt_define_response.thoughts, dict):

View File

View File

@ -0,0 +1,81 @@
import json
from typing import Dict, NamedTuple, List
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData
CFG = Config()
class ChatDashboard(BaseChat):
    """Chat scene that turns a user's analysis goal into a dashboard report.

    Asks the model for a set of chart specifications (SQL + title + chart
    type), executes each SQL against the selected database, and assembles
    the results into a ReportData.
    """

    chat_scene: str = ChatScene.ChatDashboard.value
    # Name of the report template being generated.
    report_name: str

    def __init__(
        self, chat_session_id, db_name, user_input, report_name
    ):
        """Create a dashboard chat bound to one database and report template.

        Args:
            chat_session_id: conversation identifier used for history storage.
            db_name: target database name; required.
            user_input: the analysis goal supplied by the user.
            report_name: which report template to fill.

        Raises:
            ValueError: if ``db_name`` is empty.
        """
        super().__init__(
            # NOTE(review): reuses the ChatWithDbExecute scene/prompt; confirm
            # whether ChatScene.ChatDashboard was intended here.
            chat_mode=ChatScene.ChatWithDbExecute,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
        )
        if not db_name:
            raise ValueError(
                f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
            )
        # BUG FIX: db_name was never stored, so get_session below raised
        # AttributeError on self.db_name.
        self.db_name = db_name
        self.report_name = report_name
        self.database = CFG.local_db
        # Obtain a connection/session for the chosen database.
        self.db_connect = self.database.get_session(self.db_name)
        self.top_k: int = 5

    def generate_input_values(self):
        """Build the prompt variables: user goal, SQL dialect and table schema."""
        try:
            from pilot.summary.db_summary_client import DBSummaryClient
        except ImportError:
            raise ValueError("Could not import DBSummaryClient. ")
        client = DBSummaryClient()
        input_values = {
            "input": self.current_user_input,
            "dialect": self.database.dialect,
            "table_info": self.database.table_simple_info(self.db_connect),
            "supported_chat_type": "",  # TODO
            # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
        }
        return input_values

    def do_action(self, prompt_response):
        """Run every chart SQL and collect the results into a ReportData.

        Charts whose SQL fails are logged and skipped.  TODO: record
        successful and failed charts separately.
        """
        report_data: ReportData = ReportData()
        chart_datas: List[ChartData] = []
        for chart_item in prompt_response:
            try:
                datas = self.database.run(self.db_connect, chart_item.sql)
                # NOTE(review): ChartData declares required fields, so a bare
                # ChartData() may fail validation — confirm intended population
                # from `datas` and `chart_item`.
                chart_data: ChartData = ChartData()
                # BUG FIX: append moved inside the try — previously a failed
                # query appended an unbound (or stale) chart_data.
                chart_datas.append(chart_data)
            except Exception as e:
                # TODO: repair flow for failed charts
                print(str(e))
        report_data.conv_uid = self.chat_session_id
        report_data.template_name = self.report_name
        report_data.template_introduce = None
        report_data.charts = chart_datas
        return report_data

View File

@ -0,0 +1,22 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
class ChartData(BaseModel):
    """One rendered chart: its source SQL, result columns/values and style.

    All fields default to None so an instance can be constructed empty and
    populated incrementally (ChatDashboard.do_action calls ``ChartData()``
    with no arguments; with required fields that raised a validation error).
    """

    chart_uid: str = None     # unique id of the chart
    chart_type: str = None    # e.g. "BarChart", "LineChart"
    chart_sql: str = None     # SQL that produced the data
    column_name: List = None  # result column names
    values: List = None       # result rows/values
    style: Any = None         # rendering style options
class ReportData(BaseModel):
    """A full dashboard report: metadata plus its charts.

    Fields default to None so the object can be built empty and filled in
    after the chart SQLs have run (ChatDashboard assigns them one by one,
    and explicitly sets template_introduce to None).
    """

    conv_uid: str = None            # owning conversation id
    template_name: str = None       # report template name
    template_introduce: str = None  # optional template description
    charts: List[ChartData] = None  # charts composing the report

View File

@ -0,0 +1,41 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
class ChartItem(NamedTuple):
    """One chart specification parsed from the model's JSON response."""

    sql: str       # data-analysis SQL to execute
    title:str      # human-readable chart title
    thoughts: str  # model's reasoning behind this analysis
    showcase:str   # chart type to render (e.g. "BarChart")
# Module logger; log file lives under LOGDIR as "ChatDashboardOutputParser.log"
# (assumes LOGDIR ends with a path separator — TODO confirm).
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
class ChatDashboardOutputParser(BaseOutputParser):
    """Parses the model's JSON list of chart specs into ChartItem tuples."""

    def __init__(self, sep: str, is_stream_out: bool):
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, model_out_text):
        """Clean the raw model output and convert it to a list of ChartItem.

        Raises json.JSONDecodeError / KeyError if the model's answer does not
        match the RESPONSE_FORMAT contract.
        """
        clean_str = super().parse_prompt_response(model_out_text)
        print("clean prompt response:", clean_str)
        response = json.loads(clean_str)
        # BUG FIX: was ``chart_items = List[ChartItem]`` — that assigns the
        # typing alias itself, not a list instance; appending to it raises.
        chart_items: List[ChartItem] = []
        for item in response:
            chart_items.append(
                ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"])
            )
        return chart_items

    def parse_view_response(self, speak, data) -> str:
        ### TODO
        return data

    @property
    def _type(self) -> str:
        return "chat_dashboard"

View File

@ -0,0 +1,49 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
CFG = Config()

PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provide a professional data analysis solution according to the following situations"""

# NOTE(review): the assignment below immediately discards the scene definition
# above, so the template is rendered without it.  Dead code — confirm which of
# the two is intended before removing either.
PROMPT_SCENE_DEFINE = None

# Prompt body: asks the model for analysis SQL + title + chart type + thoughts,
# answered as JSON matching RESPONSE_FORMAT below.
_DEFAULT_TEMPLATE = """
According to the structure definition in the following tables:
{table_info}
Provide a professional data analysis with as few dimensions as possible, and the upper limit does not exceed 8 dimensions.
Used to support goal: {input}
Use the chart display method in the following range:
{supported_chat_type}
give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""

# Expected shape of each element in the model's JSON answer.
RESPONSE_FORMAT = [{
    "sql": "data analysis SQL",
    "title": "Data Analysis Title",
    "showcase": "What type of charts to show",
    "thoughts": "Current thinking and value of data analysis"
}]

PROMPT_SEP = SeparatorStyle.SINGLE.value
# Dashboard responses are parsed as a whole, not streamed out token-by-token.
PROMPT_NEED_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    # NOTE(review): registered under ChatWithDbExecute, which overwrites the
    # auto-execute scene's prompt in CFG.prompt_templates — presumably should
    # be ChatScene.ChatDashboard.value.  Likewise DbChatOutputParser below
    # looks like a copy-paste where ChatDashboardOutputParser was intended.
    # Confirm against callers before changing either in isolation.
    template_scene=ChatScene.ChatWithDbExecute.value,
    input_variables=["input", "table_info", "dialect", "supported_chat_type"],
    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=DbChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -0,0 +1,9 @@
{
"title": "Sales Report",
"name": "sale_report",
"introduce": "",
"layout": "TODO",
"supported_chart_type":["HeatMap","sheet", "LineChart", "PieChart", "BarChart"],
"key_metrics":[],
"trends": []
}

View File

@ -22,12 +22,10 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(
self, temperature, max_new_tokens, chat_session_id, db_name, user_input self, chat_session_id, db_name, user_input
): ):
""" """ """ """
super().__init__( super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbExecute, chat_mode=ChatScene.ChatWithDbExecute,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input=user_input, current_user_input=user_input,
@ -57,5 +55,5 @@ class ChatWithDbAutoExecute(BaseChat):
} }
return input_values return input_values
def do_with_prompt_response(self, prompt_response): def do_action(self, prompt_response):
return self.database.run(self.db_connect, prompt_response.sql) return self.database.run(self.db_connect, prompt_response.sql)

View File

@ -20,12 +20,10 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(
self, temperature, max_new_tokens, chat_session_id, db_name, user_input self, chat_session_id, db_name, user_input
): ):
""" """ """ """
super().__init__( super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbQA, chat_mode=ChatScene.ChatWithDbQA,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input=user_input, current_user_input=user_input,
@ -66,5 +64,4 @@ class ChatWithDbQA(BaseChat):
} }
return input_values return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response

View File

@ -22,15 +22,11 @@ class ChatWithPlugin(BaseChat):
def __init__( def __init__(
self, self,
temperature,
max_new_tokens,
chat_session_id, chat_session_id,
user_input, user_input,
plugin_selector: str = None, plugin_selector: str = None,
): ):
super().__init__( super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatExecution, chat_mode=ChatScene.ChatExecution,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input=user_input, current_user_input=user_input,
@ -66,7 +62,7 @@ class ChatWithPlugin(BaseChat):
} }
return input_values return input_values
def do_with_prompt_response(self, prompt_response): def do_action(self, prompt_response):
## plugin command run ## plugin command run
return execute_command( return execute_command(
str(prompt_response.command.get("name")), str(prompt_response.command.get("name")),

View File

@ -30,12 +30,10 @@ class ChatNewKnowledge(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(
self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name self, chat_session_id, user_input, knowledge_name
): ):
""" """ """ """
super().__init__( super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNewKnowledge, chat_mode=ChatScene.ChatNewKnowledge,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input=user_input, current_user_input=user_input,
@ -67,8 +65,6 @@ class ChatNewKnowledge(BaseChat):
return input_values return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:

View File

@ -25,16 +25,14 @@ CFG = Config()
class ChatDefaultKnowledge(BaseChat): class ChatDefaultKnowledge(BaseChat):
chat_scene: str = ChatScene.ChatKnowledge.value chat_scene: str = ChatScene.ChatDefaultKnowledge.value
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input): def __init__(self, chat_session_id, user_input):
""" """ """ """
super().__init__( super().__init__(
temperature=temperature, chat_mode=ChatScene.ChatDefaultKnowledge,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatKnowledge,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input=user_input, current_user_input=user_input,
) )
@ -61,9 +59,8 @@ class ChatDefaultKnowledge(BaseChat):
) )
return input_values return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value return ChatScene.ChatDefaultKnowledge.value

View File

@ -39,7 +39,7 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate( prompt = PromptTemplate(
template_scene=ChatScene.ChatKnowledge.value, template_scene=ChatScene.ChatDefaultKnowledge.value,
input_variables=["context", "question"], input_variables=["context", "question"],
response_format=None, response_format=None,
template_define=PROMPT_SCENE_DEFINE, template_define=PROMPT_SCENE_DEFINE,

View File

@ -14,8 +14,6 @@ class InnerChatDBSummary(BaseChat):
def __init__( def __init__(
self, self,
temperature,
max_new_tokens,
chat_session_id, chat_session_id,
user_input, user_input,
db_select, db_select,
@ -23,8 +21,6 @@ class InnerChatDBSummary(BaseChat):
): ):
""" """ """ """
super().__init__( super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.InnerChatDBSummary, chat_mode=ChatScene.InnerChatDBSummary,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input=user_input, current_user_input=user_input,
@ -40,8 +36,6 @@ class InnerChatDBSummary(BaseChat):
} }
return input_values return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:

View File

@ -27,11 +27,9 @@ class ChatUrlKnowledge(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url): def __init__(self, chat_session_id, user_input, url):
""" """ """ """
super().__init__( super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatUrlKnowledge, chat_mode=ChatScene.ChatUrlKnowledge,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input=user_input, current_user_input=user_input,
@ -62,8 +60,6 @@ class ChatUrlKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input} input_values = {"context": context, "question": self.current_user_input}
return input_values return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:

View File

@ -18,11 +18,9 @@ class ChatNormal(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input): def __init__(self, chat_session_id, user_input):
""" """ """ """
super().__init__( super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNormal, chat_mode=ChatScene.ChatNormal,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input=user_input, current_user_input=user_input,
@ -32,7 +30,7 @@ class ChatNormal(BaseChat):
input_values = {"input": self.current_user_input} input_values = {"input": self.current_user_input}
return input_values return input_values
def do_with_prompt_response(self, prompt_response): def do_action(self, prompt_response):
return prompt_response return prompt_response
@property @property

View File

@ -25,7 +25,8 @@ class OnceConversation:
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
""" """
def __init__(self): def __init__(self, chat_mode):
self.chat_mode: str = chat_mode
self.messages: List[BaseMessage] = [] self.messages: List[BaseMessage] = []
self.start_date: str = "" self.start_date: str = ""
self.chat_order: int = 0 self.chat_order: int = 0
@ -43,12 +44,28 @@ class OnceConversation:
def add_ai_message(self, message: str) -> None: def add_ai_message(self, message: str) -> None:
"""Add an AI message to the store""" """Add an AI message to the store"""
has_message = any(isinstance(instance, AIMessage) for instance in self.messages) has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
if has_message: if has_message:
raise ValueError("Already Have Ai message") self.__update_ai_message(message)
else:
self.messages.append(AIMessage(content=message)) self.messages.append(AIMessage(content=message))
""" """ """ """
def __update_ai_message(self, new_message: str) -> None:
    """Overwrite the content of the AI message(s) in this conversation.

    Used during streaming output: each successive chunk replaces the
    previous partial AI answer instead of appending a new message.

    Args:
        new_message: the latest (cumulative) AI response text.
    """
    # Updates every message typed "ai"; a conversation is expected to hold
    # at most one (add_ai_message only appends when none exists yet).
    for item in self.messages:
        if item.type == "ai":
            item.content = new_message
def add_view_message(self, message: str) -> None: def add_view_message(self, message: str) -> None:
"""Add an AI message to the store""" """Add an AI message to the store"""
@ -69,6 +86,13 @@ class OnceConversation:
self.session_id = None self.session_id = None
def get_user_message(self):
    """Return the content of the first HumanMessage, or "" if none exists."""
    for once in self.messages:
        if isinstance(once, HumanMessage):
            return once.content
    return ""
def _conversation_to_dic(once: OnceConversation) -> dict: def _conversation_to_dic(once: OnceConversation) -> dict:
start_str: str = "" start_str: str = ""
if once.start_date: if once.start_date:
@ -78,6 +102,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
start_str = once.start_date start_str = once.start_date
return { return {
"chat_mode": once.chat_mode,
"chat_order": once.chat_order, "chat_order": once.chat_order,
"start_date": start_str, "start_date": start_str,
"cost": once.cost if once.cost else 0, "cost": once.cost if once.cost else 0,
@ -93,6 +118,7 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
def conversation_from_dict(once: dict) -> OnceConversation: def conversation_from_dict(once: dict) -> OnceConversation:
conversation = OnceConversation() conversation = OnceConversation()
conversation.cost = once.get("cost", 0) conversation.cost = once.get("cost", 0)
conversation.chat_mode = once.get("chat_mode", "chat_normal")
conversation.tokens = once.get("tokens", 0) conversation.tokens = once.get("tokens", 0)
conversation.start_date = once.get("start_date", "") conversation.start_date = once.get("start_date", "")
conversation.chat_order = int(once.get("chat_order")) conversation.chat_order = int(once.get("chat_order"))

View File

@ -1,15 +1,18 @@
import uuid import uuid
import json
from fastapi import APIRouter, Request, Body, status import asyncio
import time
from fastapi import APIRouter, Request, Body, status, HTTPException, Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List from typing import List
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
@ -17,6 +20,8 @@ from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR) from pilot.configs.model_config import (LOGDIR)
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage) from pilot.scene.base_message import (BaseMessage)
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
router = APIRouter() router = APIRouter()
CFG = Config() CFG = Config()
@ -28,32 +33,117 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
message = "" message = ""
for error in exc.errors(): for error in exc.errors():
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";" message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
return Result.faild(message) return Result.faild(msg=message)
@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]]) def __get_conv_user_message(conversations: dict):
async def dialogue_list(user_id: str): messages = conversations['messages']
#### TODO for item in messages:
if item['type'] == "human":
conversations = [ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"), return item['data']['content']
ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]")] return ""
return Result[ConversationVo].succ(conversations)
@router.post('/v1/chat/dialogue/new', response_model=Result[str]) @router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
async def dialogue_new(user_id: str): async def dialogue_list(response: Response, user_id: str = None):
# 设置CORS头部信息
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Methods'] = 'GET'
response.headers['Access-Control-Request-Headers'] = 'content-type'
dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id)
for item in datas:
conv_uid = item.get("conv_uid")
messages = item.get("messages")
conversations = json.loads(messages)
first_conv: OnceConversation = conversations[0]
conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv),
chat_mode=first_conv['chat_mode'])
dialogues.append(conv_vo)
return Result[ConversationVo].succ(dialogues)
@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = []
new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
for scene in new_modes:
if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
scene_vos.append(scene_vo)
return Result.succ(scene_vos)
@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo])
async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None):
unique_id = uuid.uuid1() unique_id = uuid.uuid1()
return Result.succ(unique_id) return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@router.post('/v1/chat/dialogue/delete') def get_db_list():
async def dialogue_delete(con_uid: str, user_id: str): db = CFG.local_db
#### TODO dbs = db.get_database_list()
params:dict = {}
for name in dbs:
params.update({name: name})
return params
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:
plugins_infos.update({f"{plugin._name}】=>{plugin._description}": plugin._name})
return plugins_infos
def knowledge_list():
knowledge: dict = {}
### TODO
return knowledge
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatDb.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatData.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatDashboard.value == chat_mode:
return Result.succ(get_db_list())
elif ChatScene.ChatExecution.value == chat_mode:
return Result.succ(plugins_select_info())
elif ChatScene.ChatKnowledge.value == chat_mode:
return Result.succ(knowledge_list())
else:
return Result.succ(None) return Result.succ(None)
@router.post('/v1/chat/completions', response_model=Result[MessageVo]) @router.post('/v1/chat/dialogue/delete')
async def dialogue_delete(con_uid: str):
history_mem = DuckdbHistoryMemory(con_uid)
history_mem.delete()
return Result.succ(None)
@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str):
print(f"dialogue_history_messages:{con_uid}")
message_vos: List[MessageVo] = []
history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']]
message_vos.extend(once_message_vos)
return Result.succ(message_vos)
@router.post('/v1/chat/completions')
async def chat_completions(dialogue: ConversationVo = Body()): async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}") print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
@ -65,22 +155,31 @@ async def chat_completions(dialogue: ConversationVo = Body()):
"user_input": dialogue.user_input, "user_input": dialogue.user_input,
} }
if ChatScene.ChatWithDbExecute == dialogue.chat_mode: if ChatScene.ChatDb == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param) chat_param.update("db_name", dialogue.select_param)
elif ChatScene.ChatWithDbQA == dialogue.chat_mode: elif ChatScene.ChatData == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param)
elif ChatScene.ChatDashboard == dialogue.chat_mode:
chat_param.update("db_name", dialogue.select_param) chat_param.update("db_name", dialogue.select_param)
elif ChatScene.ChatExecution == dialogue.chat_mode: elif ChatScene.ChatExecution == dialogue.chat_mode:
chat_param.update("plugin_selector", dialogue.select_param) chat_param.update("plugin_selector", dialogue.select_param)
elif ChatScene.ChatNewKnowledge == dialogue.chat_mode: elif ChatScene.ChatKnowledge == dialogue.chat_mode:
chat_param.update("knowledge_name", dialogue.select_param) chat_param.update("knowledge_name", dialogue.select_param)
elif ChatScene.ChatUrlKnowledge == dialogue.chat_mode:
chat_param.update("url", dialogue.select_param)
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
if not chat.prompt_template.stream_out: if not chat.prompt_template.stream_out:
return non_stream_response(chat) return non_stream_response(chat)
else: else:
return stream_response(chat) # generator = stream_generator(chat)
# result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
# return result
return StreamingResponse(stream_generator(chat), media_type="text/plain")
def stream_test():
for message in ["Hello", "world", "how", "are", "you"]:
yield message
# yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
def stream_generator(chat): def stream_generator(chat):
@ -89,24 +188,28 @@ def stream_generator(chat):
if chunk: if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg) chat.current_message.add_ai_message(msg)
messageVos = [message2Vo(element) for element in chat.current_message.messages] yield msg
yield Result.succ(messageVos) # chat.current_message.add_ai_message(msg)
def stream_response(chat): # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
logger.info("stream out start!") # json_text = json.dumps(vo.__dict__)
api_response = StreamingResponse(stream_generator(chat), media_type="application/json") # yield json_text.encode('utf-8')
return api_response chat.memory.append(chat.current_message)
# def stream_response(chat):
# logger.info("stream out start!")
# api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
# return api_response
def message2Vo(message: dict, order) -> MessageVo:
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
return MessageVo(role=message['type'], context=message['data']['content'], order=order)
def message2Vo(message:BaseMessage)->MessageVo:
vo:MessageVo = MessageVo()
vo.role = message.type
vo.role = message.content
vo.time_stamp = message.additional_kwargs.time_stamp if message.additional_kwargs["time_stamp"] else 0
def non_stream_response(chat): def non_stream_response(chat):
logger.info("not stream out, wait model response!") logger.info("not stream out, wait model response!")
chat.nostream_call() return chat.nostream_call()
messageVos = [message2Vo(element) for element in chat.current_message.messages]
return Result.succ(messageVos)
@router.get('/v1/db/types', response_model=Result[str]) @router.get('/v1/db/types', response_model=Result[str])

View File

@ -1,28 +1,33 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic from typing import TypeVar, Union, List, Generic, Any
T = TypeVar('T') T = TypeVar('T')
class Result(Generic[T], BaseModel): class Result(Generic[T], BaseModel):
success: bool success: bool
err_code: str err_code: str = None
err_msg: str err_msg: str = None
data: List[T] data: T = None
@classmethod @classmethod
def succ(cls, data: List[T]): def succ(cls, data: T):
return Result(True, None, None, data) return Result(success=True, err_code=None, err_msg=None, data=data)
@classmethod @classmethod
def faild(cls, msg): def faild(cls, msg):
return Result(True, "E000X", msg, None) return Result(success=False, err_code="E000X", err_msg=msg, data=None)
@classmethod @classmethod
def faild(cls, code, msg): def faild(cls, code, msg):
return Result(True, code, msg, None) return Result(success=False, err_code=code, err_msg=msg, data=None)
class ChatSceneVo(BaseModel):
chat_scene: str = Field(..., description="chat_scene")
scene_name: str = Field(..., description="chat_scene name show for user")
param_title: str = Field(..., description="chat_scene required parameter title")
class ConversationVo(BaseModel): class ConversationVo(BaseModel):
""" """
dialogue_uid dialogue_uid
@ -31,15 +36,21 @@ class ConversationVo(BaseModel):
""" """
user input user input
""" """
user_input: str user_input: str = ""
"""
user
"""
user_name: str = ""
""" """
the scene of chat the scene of chat
""" """
chat_mode: str = Field(..., description="the scene of chat ") chat_mode: str = Field(..., description="the scene of chat ")
""" """
chat scene select param chat scene select param
""" """
select_param: str select_param: str = None
class MessageVo(BaseModel): class MessageVo(BaseModel):
@ -51,7 +62,12 @@ class MessageVo(BaseModel):
current message current message
""" """
context: str context: str
""" message postion order """
order: int
""" """
time the current message was sent time the current message was sent
""" """
time_stamp: float time_stamp: Any = None

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import signal
import threading import threading
import traceback import traceback
import argparse import argparse
@ -12,12 +11,10 @@ import uuid
import gradio as gr import gradio as gr
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH) sys.path.append(ROOT_PATH)
from pilot.summary.db_summary_client import DBSummaryClient from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
@ -25,8 +22,8 @@ from pilot.configs.config import Config
from pilot.configs.model_config import ( from pilot.configs.model_config import (
DATASETS_DIR, DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH, KNOWLEDGE_UPLOAD_ROOT_PATH,
LOGDIR,
LLM_MODEL_CONFIG, LLM_MODEL_CONFIG,
LOGDIR,
) )
from pilot.conversation import ( from pilot.conversation import (
@ -35,11 +32,10 @@ from pilot.conversation import (
chat_mode_title, chat_mode_title,
default_conversation, default_conversation,
) )
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.server.gradio_css import code_highlight_css from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.vector_store.extract_tovec import ( from pilot.vector_store.extract_tovec import (
get_vector_storelist, get_vector_storelist,
@ -49,6 +45,20 @@ from pilot.vector_store.extract_tovec import (
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text from pilot.language.translation_handler import get_lang_text
from pilot.server.webserver_base import server_init
import uvicorn
from fastapi import BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
# 加载插件 # 加载插件
CFG = Config() CFG = Config()
@ -95,6 +105,30 @@ knowledge_qa_type_list = [
add_knowledge_base_dialogue, add_knowledge_base_dialogue,
] ]
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args, **kwargs,
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
)
applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI()
origins = ["*"]
# 添加跨域中间件
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
# app.mount("static", StaticFiles(directory="static"), name="static")
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
def get_simlar(q): def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md")) docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
@ -216,7 +250,7 @@ def get_chat_mode(selected, param=None) -> ChatScene:
else: else:
mode = param mode = param
if mode == conversation_types["default_knownledge"]: if mode == conversation_types["default_knownledge"]:
return ChatScene.ChatKnowledge return ChatScene.ChatDefaultKnowledge
elif mode == conversation_types["custome"]: elif mode == conversation_types["custome"]:
return ChatScene.ChatNewKnowledge return ChatScene.ChatNewKnowledge
elif mode == conversation_types["url"]: elif mode == conversation_types["url"]:
@ -286,7 +320,7 @@ def http_bot(
"chat_session_id": state.conv_id, "chat_session_id": state.conv_id,
"user_input": state.last_user_input, "user_input": state.last_user_input,
} }
elif ChatScene.ChatKnowledge == scene: elif ChatScene.ChatDefaultKnowledge == scene:
chat_param = { chat_param = {
"temperature": temperature, "temperature": temperature,
"max_new_tokens": max_new_tokens, "max_new_tokens": max_new_tokens,
@ -324,15 +358,14 @@ def http_bot(
response = chat.stream_call() response = chat.stream_call()
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk: if chunk:
state.messages[-1][ msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-1 state.messages[-1][-1] =msg
] = chat.prompt_template.output_parser.parse_model_stream_resp_ex( chat.current_message.add_ai_message(msg)
chunk, chat.skip_echo_len
)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.memory.append(chat.current_message)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
state.messages[-1][-1] = "Error:" + str(e) state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
@ -632,7 +665,7 @@ def knowledge_embedding_store(vs_id, files):
) )
knowledge_embedding_client = KnowledgeEmbedding( knowledge_embedding_client = KnowledgeEmbedding(
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG["text2vec"],
vector_store_config={ vector_store_config={
"vector_store_name": vector_store_name["vs_name"], "vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@ -657,42 +690,25 @@ def signal_handler(sig, frame):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument(
"--model-list-mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument("--share", default=False, action="store_true") parser.add_argument("--share", default=False, action="store_true")
# init server config
args = parser.parse_args() args = parser.parse_args()
logger.info(f"args: {args}") server_init(args)
# init config if args.new:
cfg = Config() import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
load_native_plugins(cfg) else:
dbs = cfg.local_db.get_database_list() ### Compatibility mode starts the old version server by default
signal.signal(signal.SIGINT, signal_handler)
async_db_summery()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# Loader plugins and commands
command_categories = [
"pilot.commands.built_in.audio_text",
"pilot.commands.built_in.image_gen",
]
# exclude commands
command_categories = [
x for x in command_categories if x not in cfg.disabled_command_categories
]
command_registry = CommandRegistry()
for command_category in command_categories:
command_registry.import_commands(command_category)
cfg.command_registry = command_registry
logger.info(args)
demo = build_webdemo() demo = build_webdemo()
demo.queue( demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
@ -702,3 +718,8 @@ if __name__ == "__main__":
share=args.share, share=args.share,
max_threads=200, max_threads=200,
) )