fix: merge tuyang branch and knowledge chat

1. fix knowledge chat
2. merge tuyang branch
This commit is contained in:
aries_ckt 2023-06-29 20:03:39 +08:00
commit 45d183d50b
22 changed files with 200 additions and 95 deletions

View File

@ -19,7 +19,7 @@ const AgentPage = (props) => {
});
const { history, handleChatSubmit } = useAgentChat({
queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`,
queryAgentURL: `http://localhost:5000/v1/chat/completions`,
queryBody: {
conv_uid: props.params?.agentId,
chat_mode: props.searchParams?.scene || 'chat_normal',

View File

@ -16,7 +16,7 @@ const Item = styled(Sheet)(({ theme }) => ({
const Agents = () => {
const { handleChatSubmit, history } = useAgentChat({
queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`,
queryAgentURL: `http://localhost:5000/v1/chat/completions`,
});
const data = [

View File

@ -2,7 +2,7 @@ import { message } from 'antd';
import axios from 'axios';
import { isPlainObject } from 'lodash';
axios.defaults.baseURL = 'http://30.183.154.8:5000';
axios.defaults.baseURL = 'http://localhost:5000';
axios.defaults.timeout = 10000;

View File

@ -0,0 +1,10 @@
"""Ad-hoc smoke test: open a session on the configured local DB and dump the users table."""
from pilot.common.sql_database import Database
from pilot.configs.config import Config

CFG = Config()

if __name__ == "__main__":
    session = CFG.local_db.get_session("gpt-user")
    rows = CFG.local_db.run(session, "SELECT * FROM users; ")
    print(rows)

View File

@ -6,9 +6,9 @@ default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
if __name__ == "__main__":
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
if os.path.isfile("../../../message/chat_history.db"):
cursor = duckdb.connect("../../../message/chat_history.db").cursor()
# cursor.execute("SELECT * FROM chat_history limit 20")
cursor.execute("SELECT * FROM chat_history limit 20")
cursor.execute("SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'")
data = cursor.fetchall()
print(data)

View File

@ -25,6 +25,11 @@ class BaseChatHistoryMemory(ABC):
def messages(self) -> List[OnceConversation]: # type: ignore
"""Retrieve the messages from the local file"""
@abstractmethod
def create(self, user_name: str) -> None:
    """Create a new, empty conversation record in the backing store.

    NOTE(review): the concrete DuckdbHistoryMemory.create takes
    (chat_mode, summary, user_name) — this abstract signature looks
    out of date; confirm and align the interface.
    """
@abstractmethod
def append(self, message: OnceConversation) -> None:
"""Append the message to the record in the local file"""

View File

@ -36,8 +36,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if not result:
# 如果表不存在,则创建新表
self.connect.execute(
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
)
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)")
def __get_messages_by_conv_uid(self, conv_uid: str):
cursor = self.connect.cursor()
@ -55,6 +54,17 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return conversations
return []
def create(self, chat_mode, summary: str, user_name: str) -> None:
    """Insert a new, empty conversation row for this session.

    Fixes the original INSERT statement, which was missing the comma
    between the chat_mode and summary columns — the column list named
    4 columns against 5 placeholders, so every call raised a binding
    error that the except-block below silently swallowed.
    """
    try:
        cursor = self.connect.cursor()
        cursor.execute(
            "INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)"
            "VALUES(?,?,?,?,?)",
            [self.chat_seesion_id, chat_mode, summary, user_name, ""],
        )
        cursor.commit()
        self.connect.commit()
    except Exception as e:
        # Best-effort: creation failure is logged, not propagated.
        print("init create conversation log error" + str(e))
def append(self, once_message: OnceConversation) -> None:
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
conversations: List[OnceConversation] = []
@ -69,13 +79,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
)
else:
cursor.execute(
"INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
[
self.chat_seesion_id,
"",
json.dumps(conversations, ensure_ascii=False),
],
)
"INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)",
[self.chat_seesion_id, once_message.chat_mode, once_message.get_user_conv().content, "",json.dumps(conversations, ensure_ascii=False)])
cursor.commit()
self.connect.commit()
@ -125,5 +130,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
)
context = cursor.fetchone()
if context:
return json.loads(context[0])
if context[0]:
return json.loads(context[0])
return None

View File

@ -63,6 +63,39 @@ def __get_conv_user_message(conversations: dict):
return ""
def __new_conversation(chat_mode, user_id) -> ConversationVo:
    """Mint a new conversation id and return its view object.

    The DuckdbHistoryMemory constructor is invoked for its side effect
    (it ensures the chat_history table exists for the new uid); the
    instance itself was never used, so the dead binding was dropped.
    NOTE(review): user_id is accepted but not stored anywhere — confirm intent.
    """
    unique_id = uuid.uuid1()
    DuckdbHistoryMemory(str(unique_id))  # side effect only: prepare history storage
    return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
def get_db_list():
    """Return a {name: name} mapping for every database known to the local DB."""
    database = CFG.local_db
    return {name: name for name in database.get_database_list()}
def plugins_select_info():
    """Return a mapping of display label -> plugin name for loaded plugins.

    Fixes the label format string, which was missing the opening 【
    bracket and rendered as "name】=>description".
    """
    plugins_infos: dict = {}
    for plugin in CFG.plugins:
        plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
    return plugins_infos
def knowledge_list():
    """return knowledge space list"""
    spaces = knowledge_service.get_knowledge_space(KnowledgeSpaceRequest())
    result: dict = {}
    for space in spaces:
        result[space.name] = space.name
    return result
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
# 设置CORS头部信息
@ -75,14 +108,15 @@ async def dialogue_list(response: Response, user_id: str = None):
for item in datas:
conv_uid = item.get("conv_uid")
messages = item.get("messages")
conversations = json.loads(messages)
summary = item.get("summary")
chat_mode = item.get("chat_mode")
first_conv: OnceConversation = conversations[0]
conv_vo: ConversationVo = ConversationVo(
conv_uid=conv_uid,
user_input=__get_conv_user_message(first_conv),
chat_mode=first_conv["chat_mode"],
user_input=summary,
chat_mode=chat_mode,
)
dialogues.append(conv_vo)
@ -113,39 +147,13 @@ async def dialogue_scenes():
return Result.succ(scene_vos)
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
unique_id = uuid.uuid1()
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
def get_db_list():
db = CFG.local_db
dbs = db.get_database_list()
params: dict = {}
for name in dbs:
params.update({name: name})
return params
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:
plugins_infos.update({f"{plugin._name}】=>{plugin._description}": plugin._name})
return plugins_infos
def knowledge_list():
"""return knowledge space list"""
params: dict = {}
request = KnowledgeSpaceRequest()
spaces = knowledge_service.get_knowledge_space(request)
for space in spaces:
params.update({space.name: space.name})
return params
conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo)
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
@ -192,7 +200,8 @@ async def chat_completions(dialogue: ConversationVo = Body()):
if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value
if not dialogue.conv_uid:
dialogue.conv_uid = str(uuid.uuid1())
conv_vo = __new_conversation(dialogue.chat_mode, dialogue.user_name)
dialogue.conv_uid = conv_vo.conv_uid
global model_semaphore, global_counter
global_counter += 1
@ -272,14 +281,15 @@ async def stream_generator(chat):
else:
for chunk in model_response:
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.1)
chat.current_message.add_ai_message(msg)
chat.current_message.add_view_message(msg)
chat.memory.append(chat.current_message)

View File

@ -95,7 +95,7 @@ class BaseOutputParser(ABC):
yield output
def parse_model_nostream_resp(self, response, sep: str):
text = response.text.strip()
text = response.strip()
text = text.rstrip()
text = text.strip(b"\x00".decode())
respObj_ex = json.loads(text)

View File

@ -6,7 +6,7 @@ from pilot.common.schema import ExampleType
class ExampleSelector(BaseModel, ABC):
examples_record: List[List]
examples_record: List[dict]
use_example: bool = False
type: str = ExampleType.ONE_SHOT.value
@ -16,7 +16,7 @@ class ExampleSelector(BaseModel, ABC):
else:
return self.__few_shot_context(count)
def __few_shot_context(self, count: int = 2) -> List[List]:
def __few_shot_context(self, count: int = 2) -> List[dict]:
"""
Use 2 or more examples, default 2
Returns: example text
@ -26,7 +26,7 @@ class ExampleSelector(BaseModel, ABC):
return need_use
return None
def __one_show_context(self) -> List:
def __one_show_context(self) -> dict:
"""
Use one examples
Returns:

View File

@ -86,6 +86,11 @@ class BaseChat(ABC):
extra = Extra.forbid
arbitrary_types_allowed = True
def __init_history_message(self):
    """Load persisted conversation history; create an empty record on first use.

    Fixes the original `==`, which was a no-op comparison: the loaded
    messages were discarded and self.history_message was never set.
    """
    self.history_message = self.memory.messages()
    if not self.history_message:
        # NOTE(review): DuckdbHistoryMemory.create takes
        # (chat_mode, summary, user_name) — this 2-argument call looks
        # inconsistent with that signature; confirm against the memory
        # implementation actually in use before relying on it.
        self.memory.create(self.current_user_input, "")
@property
def chat_type(self) -> str:
raise NotImplementedError("Not supported for this chat type.")
@ -102,12 +107,9 @@ class BaseChat(ABC):
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
# TODO
self.current_message.tokens = 0
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.current_message.tokens = 0
if self.prompt_template.template:
current_prompt = self.prompt_template.format(**input_values)
self.current_message.add_system_message(current_prompt)
@ -146,8 +148,8 @@ class BaseChat(ABC):
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
### store current conversation
self.memory.append(self.current_message)
### store current conversation
self.memory.append(self.current_message)
def nostream_call(self):
payload = self.__call_base()

View File

@ -65,6 +65,12 @@ class ChatDashboard(BaseChat):
try:
datas = self.database.run(self.db_connect, chart_item.sql)
chart_data: ChartData = ChartData()
chart_data.chart_sql = chart_item['sql']
chart_data.chart_type = chart_item['showcase']
chart_data.chart_name = chart_item['title']
chart_data.chart_desc = chart_item['thoughts']
chart_data.column_name = datas[0]
chart_data.values =datas
except Exception as e:
# TODO 修复流程
print(str(e))

View File

@ -4,7 +4,9 @@ from typing import TypeVar, Union, List, Generic, Any
class ChartData(BaseModel):
    # One renderable chart produced for the dashboard scene.
    chart_uid: str     # unique id of this chart
    chart_name: str    # display title (filled from the model's "title" field)
    chart_type: str    # render style (filled from the model's "showcase" field)
    chart_desc: str    # description (filled from the model's "thoughts" field)
    chart_sql: str     # SQL statement that produced the data
    column_name: List  # column headers of the result set (first result row)
    values: List       # result rows returned by running chart_sql

View File

@ -54,4 +54,5 @@ class ChatWithDbAutoExecute(BaseChat):
return input_values
def do_action(self, prompt_response):
    """Run the SQL from the parsed model response against the connected DB and return the rows."""
    print(f"do_action:{prompt_response}")  # debug trace of the parsed response
    return self.database.run(self.db_connect, prompt_response.sql)

View File

@ -0,0 +1,37 @@
from pilot.prompts.example_base import ExampleSelector
from pilot.common.schema import ExampleType
## Two examples are defined by default
EXAMPLES = [
{
"messages": [
{"type": "human", "data": {"content": "查询用户test1所在的城市", "example": True}},
{
"type": "ai",
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"sql\": \"SELECT city FROM users where user_name='test1'\",
}""",
"example": True,
},
},
]
},
{
"messages": [
{"type": "human", "data": {"content": "查询成都的用户的订单信息", "example": True}},
{
"type": "ai",
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
}""",
"example": True,
},
},
]
},
]
sql_data_example = ExampleSelector(examples_record=EXAMPLES, use_example=True, type=ExampleType.ONE_SHOT.value)

View File

@ -6,8 +6,9 @@ import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
CFG = Config()
class SqlAction(NamedTuple):
sql: str
thoughts: Dict
@ -32,11 +33,16 @@ class DbChatOutputParser(BaseOutputParser):
if len(data) <= 1:
data.insert(0, ["result"])
df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style>
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
if CFG.NEW_SERVER_MODE:
html = df.to_html(index=False, escape=False, sparsify=False)
html = ''.join(html.split())
else:
table_style = """<style>
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text

View File

@ -1,10 +1,10 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_db.auto_execute.example import sql_data_example
CFG = Config()
@ -12,35 +12,21 @@ PROMPT_SCENE_DEFINE = None
_DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
You are a SQL expert. Given an input question, create a syntactically correct {dialect} query.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
Use as few tables as possible when querying.
When generating insert, delete, update, or replace SQL, please make sure to use the data given by the human, and cannot use any unknown data. If you do not get enough information, speak to user: I dont have enough data complete your request.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
"""
PROMPT_SUFFIX = """Only use the following tables generate sql:
Only use the following tables schema to generate sql:
{table_info}
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Question: {input}
"""
PROMPT_RESPONSE = """You must respond in JSON format as following format:
Rrespond in JSON format as following format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""
RESPONSE_FORMAT = {
"thoughts": {
"reasoning": "reasoning",
"speak": "thoughts summary to say to user",
},
"sql": "SQL Query to run",
}
RESPONSE_FORMAT_SIMPLE = {
"thoughts": "thoughts summary to say to user",
"sql": "SQL Query to run",
@ -55,10 +41,11 @@ prompt = PromptTemplate(
input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=DbChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
example_selector=sql_data_example
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -63,6 +63,7 @@ class ChatWithPlugin(BaseChat):
return input_values
def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
## plugin command run
return execute_command(
str(prompt_response.command.get("name")),

View File

@ -2,8 +2,38 @@ from pilot.prompts.example_base import ExampleSelector
## Two examples are defined by default
EXAMPLES = [
[{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
[{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
{
"messages": [
{"type": "human", "data": {"content": "查询xxx", "example": True}},
{
"type": "ai",
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"speak\": \"thoughts summary to say to user\",
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""",
"example": True,
},
},
]
},
{
"messages": [
{"type": "human", "data": {"content": "查询xxx", "example": True}},
{
"type": "ai",
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"speak\": \"thoughts summary to say to user\",
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""",
"example": True,
},
},
]
},
]
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

View File

@ -9,6 +9,7 @@ from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary

View File

@ -18,7 +18,7 @@ from pilot.configs.model_config import (
LOGDIR,
)
from pilot.scene.chat_knowledge.default.prompt import prompt
from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
CFG = Config()

View File

@ -12,6 +12,7 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text