mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-10 12:42:34 +00:00
fix:merge tuyang branch and knowledge chat
1.fix knowledge chat 2.merge tuyang branch
This commit is contained in:
commit
45d183d50b
@ -19,7 +19,7 @@ const AgentPage = (props) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const { history, handleChatSubmit } = useAgentChat({
|
const { history, handleChatSubmit } = useAgentChat({
|
||||||
queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`,
|
queryAgentURL: `http://localhost:5000/v1/chat/completions`,
|
||||||
queryBody: {
|
queryBody: {
|
||||||
conv_uid: props.params?.agentId,
|
conv_uid: props.params?.agentId,
|
||||||
chat_mode: props.searchParams?.scene || 'chat_normal',
|
chat_mode: props.searchParams?.scene || 'chat_normal',
|
||||||
|
@ -16,7 +16,7 @@ const Item = styled(Sheet)(({ theme }) => ({
|
|||||||
|
|
||||||
const Agents = () => {
|
const Agents = () => {
|
||||||
const { handleChatSubmit, history } = useAgentChat({
|
const { handleChatSubmit, history } = useAgentChat({
|
||||||
queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`,
|
queryAgentURL: `http://localhost:5000/v1/chat/completions`,
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = [
|
const data = [
|
||||||
|
@ -2,7 +2,7 @@ import { message } from 'antd';
|
|||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
import { isPlainObject } from 'lodash';
|
import { isPlainObject } from 'lodash';
|
||||||
|
|
||||||
axios.defaults.baseURL = 'http://30.183.154.8:5000';
|
axios.defaults.baseURL = 'http://localhost:5000';
|
||||||
|
|
||||||
axios.defaults.timeout = 10000;
|
axios.defaults.timeout = 10000;
|
||||||
|
|
||||||
|
10
pilot/connections/rdbms/py_study/study_data.py
Normal file
10
pilot/connections/rdbms/py_study/study_data.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from pilot.common.sql_database import Database
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
connect = CFG.local_db.get_session("gpt-user")
|
||||||
|
datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
|
||||||
|
|
||||||
|
print(datas)
|
@ -6,9 +6,9 @@ default_db_path = os.path.join(os.getcwd(), "message")
|
|||||||
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
|
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if os.path.isfile(duckdb_path):
|
if os.path.isfile("../../../message/chat_history.db"):
|
||||||
cursor = duckdb.connect(duckdb_path).cursor()
|
cursor = duckdb.connect("../../../message/chat_history.db").cursor()
|
||||||
# cursor.execute("SELECT * FROM chat_history limit 20")
|
# cursor.execute("SELECT * FROM chat_history limit 20")
|
||||||
cursor.execute("SELECT * FROM chat_history limit 20")
|
cursor.execute("SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'")
|
||||||
data = cursor.fetchall()
|
data = cursor.fetchall()
|
||||||
print(data)
|
print(data)
|
@ -25,6 +25,11 @@ class BaseChatHistoryMemory(ABC):
|
|||||||
def messages(self) -> List[OnceConversation]: # type: ignore
|
def messages(self) -> List[OnceConversation]: # type: ignore
|
||||||
"""Retrieve the messages from the local file"""
|
"""Retrieve the messages from the local file"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(self, user_name:str) -> None:
|
||||||
|
"""Append the message to the record in the local file"""
|
||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def append(self, message: OnceConversation) -> None:
|
def append(self, message: OnceConversation) -> None:
|
||||||
"""Append the message to the record in the local file"""
|
"""Append the message to the record in the local file"""
|
||||||
|
@ -36,8 +36,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
if not result:
|
if not result:
|
||||||
# 如果表不存在,则创建新表
|
# 如果表不存在,则创建新表
|
||||||
self.connect.execute(
|
self.connect.execute(
|
||||||
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
|
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)")
|
||||||
)
|
|
||||||
|
|
||||||
def __get_messages_by_conv_uid(self, conv_uid: str):
|
def __get_messages_by_conv_uid(self, conv_uid: str):
|
||||||
cursor = self.connect.cursor()
|
cursor = self.connect.cursor()
|
||||||
@ -55,6 +54,17 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
return conversations
|
return conversations
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def create(self, chat_mode, summary: str, user_name: str) -> None:
|
||||||
|
try:
|
||||||
|
cursor = self.connect.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO chat_history(conv_uid, chat_mode summary, user_name, messages)VALUES(?,?,?,?,?)",
|
||||||
|
[self.chat_seesion_id, chat_mode, summary, user_name, ""])
|
||||||
|
cursor.commit()
|
||||||
|
self.connect.commit()
|
||||||
|
except Exception as e:
|
||||||
|
print("init create conversation log error!" + str(e))
|
||||||
|
|
||||||
def append(self, once_message: OnceConversation) -> None:
|
def append(self, once_message: OnceConversation) -> None:
|
||||||
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
|
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
|
||||||
conversations: List[OnceConversation] = []
|
conversations: List[OnceConversation] = []
|
||||||
@ -69,13 +79,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
|
"INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)",
|
||||||
[
|
[self.chat_seesion_id, once_message.chat_mode, once_message.get_user_conv().content, "",json.dumps(conversations, ensure_ascii=False)])
|
||||||
self.chat_seesion_id,
|
|
||||||
"",
|
|
||||||
json.dumps(conversations, ensure_ascii=False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
cursor.commit()
|
cursor.commit()
|
||||||
self.connect.commit()
|
self.connect.commit()
|
||||||
|
|
||||||
@ -125,5 +130,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
)
|
)
|
||||||
context = cursor.fetchone()
|
context = cursor.fetchone()
|
||||||
if context:
|
if context:
|
||||||
return json.loads(context[0])
|
if context[0]:
|
||||||
|
return json.loads(context[0])
|
||||||
return None
|
return None
|
||||||
|
@ -63,6 +63,39 @@ def __get_conv_user_message(conversations: dict):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def __new_conversation(chat_mode, user_id) -> ConversationVo:
|
||||||
|
unique_id = uuid.uuid1()
|
||||||
|
history_mem = DuckdbHistoryMemory(str(unique_id))
|
||||||
|
return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_list():
|
||||||
|
db = CFG.local_db
|
||||||
|
dbs = db.get_database_list()
|
||||||
|
params: dict = {}
|
||||||
|
for name in dbs:
|
||||||
|
params.update({name: name})
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def plugins_select_info():
|
||||||
|
plugins_infos: dict = {}
|
||||||
|
for plugin in CFG.plugins:
|
||||||
|
plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
|
||||||
|
return plugins_infos
|
||||||
|
|
||||||
|
|
||||||
|
def knowledge_list():
|
||||||
|
"""return knowledge space list"""
|
||||||
|
params: dict = {}
|
||||||
|
request = KnowledgeSpaceRequest()
|
||||||
|
spaces = knowledge_service.get_knowledge_space(request)
|
||||||
|
for space in spaces:
|
||||||
|
params.update({space.name: space.name})
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
||||||
async def dialogue_list(response: Response, user_id: str = None):
|
async def dialogue_list(response: Response, user_id: str = None):
|
||||||
# 设置CORS头部信息
|
# 设置CORS头部信息
|
||||||
@ -75,14 +108,15 @@ async def dialogue_list(response: Response, user_id: str = None):
|
|||||||
|
|
||||||
for item in datas:
|
for item in datas:
|
||||||
conv_uid = item.get("conv_uid")
|
conv_uid = item.get("conv_uid")
|
||||||
messages = item.get("messages")
|
summary = item.get("summary")
|
||||||
conversations = json.loads(messages)
|
chat_mode = item.get("chat_mode")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
first_conv: OnceConversation = conversations[0]
|
|
||||||
conv_vo: ConversationVo = ConversationVo(
|
conv_vo: ConversationVo = ConversationVo(
|
||||||
conv_uid=conv_uid,
|
conv_uid=conv_uid,
|
||||||
user_input=__get_conv_user_message(first_conv),
|
user_input=summary,
|
||||||
chat_mode=first_conv["chat_mode"],
|
chat_mode=chat_mode,
|
||||||
)
|
)
|
||||||
dialogues.append(conv_vo)
|
dialogues.append(conv_vo)
|
||||||
|
|
||||||
@ -113,39 +147,13 @@ async def dialogue_scenes():
|
|||||||
return Result.succ(scene_vos)
|
return Result.succ(scene_vos)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
|
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
|
||||||
async def dialogue_new(
|
async def dialogue_new(
|
||||||
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
|
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
|
||||||
):
|
):
|
||||||
unique_id = uuid.uuid1()
|
conv_vo = __new_conversation(chat_mode, user_id)
|
||||||
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
|
return Result.succ(conv_vo)
|
||||||
|
|
||||||
|
|
||||||
def get_db_list():
|
|
||||||
db = CFG.local_db
|
|
||||||
dbs = db.get_database_list()
|
|
||||||
params: dict = {}
|
|
||||||
for name in dbs:
|
|
||||||
params.update({name: name})
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def plugins_select_info():
|
|
||||||
plugins_infos: dict = {}
|
|
||||||
for plugin in CFG.plugins:
|
|
||||||
plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
|
|
||||||
return plugins_infos
|
|
||||||
|
|
||||||
|
|
||||||
def knowledge_list():
|
|
||||||
"""return knowledge space list"""
|
|
||||||
params: dict = {}
|
|
||||||
request = KnowledgeSpaceRequest()
|
|
||||||
spaces = knowledge_service.get_knowledge_space(request)
|
|
||||||
for space in spaces:
|
|
||||||
params.update({space.name: space.name})
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
|
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
|
||||||
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
|
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
|
||||||
@ -192,7 +200,8 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
if not dialogue.chat_mode:
|
if not dialogue.chat_mode:
|
||||||
dialogue.chat_mode = ChatScene.ChatNormal.value
|
dialogue.chat_mode = ChatScene.ChatNormal.value
|
||||||
if not dialogue.conv_uid:
|
if not dialogue.conv_uid:
|
||||||
dialogue.conv_uid = str(uuid.uuid1())
|
conv_vo = __new_conversation(dialogue.chat_mode, dialogue.user_name)
|
||||||
|
dialogue.conv_uid = conv_vo.conv_uid
|
||||||
|
|
||||||
global model_semaphore, global_counter
|
global model_semaphore, global_counter
|
||||||
global_counter += 1
|
global_counter += 1
|
||||||
@ -272,14 +281,15 @@ async def stream_generator(chat):
|
|||||||
else:
|
else:
|
||||||
for chunk in model_response:
|
for chunk in model_response:
|
||||||
if chunk:
|
if chunk:
|
||||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
||||||
chunk, chat.skip_echo_len
|
|
||||||
)
|
|
||||||
chat.current_message.add_ai_message(msg)
|
chat.current_message.add_ai_message(msg)
|
||||||
|
|
||||||
msg = msg.replace("\n", "\\n")
|
msg = msg.replace("\n", "\\n")
|
||||||
yield f"data:{msg}\n\n"
|
yield f"data:{msg}\n\n"
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
chat.current_message.add_ai_message(msg)
|
||||||
|
chat.current_message.add_view_message(msg)
|
||||||
chat.memory.append(chat.current_message)
|
chat.memory.append(chat.current_message)
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ class BaseOutputParser(ABC):
|
|||||||
yield output
|
yield output
|
||||||
|
|
||||||
def parse_model_nostream_resp(self, response, sep: str):
|
def parse_model_nostream_resp(self, response, sep: str):
|
||||||
text = response.text.strip()
|
text = response.strip()
|
||||||
text = text.rstrip()
|
text = text.rstrip()
|
||||||
text = text.strip(b"\x00".decode())
|
text = text.strip(b"\x00".decode())
|
||||||
respObj_ex = json.loads(text)
|
respObj_ex = json.loads(text)
|
||||||
|
@ -6,7 +6,7 @@ from pilot.common.schema import ExampleType
|
|||||||
|
|
||||||
|
|
||||||
class ExampleSelector(BaseModel, ABC):
|
class ExampleSelector(BaseModel, ABC):
|
||||||
examples_record: List[List]
|
examples_record: List[dict]
|
||||||
use_example: bool = False
|
use_example: bool = False
|
||||||
type: str = ExampleType.ONE_SHOT.value
|
type: str = ExampleType.ONE_SHOT.value
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ class ExampleSelector(BaseModel, ABC):
|
|||||||
else:
|
else:
|
||||||
return self.__few_shot_context(count)
|
return self.__few_shot_context(count)
|
||||||
|
|
||||||
def __few_shot_context(self, count: int = 2) -> List[List]:
|
def __few_shot_context(self, count: int = 2) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Use 2 or more examples, default 2
|
Use 2 or more examples, default 2
|
||||||
Returns: example text
|
Returns: example text
|
||||||
@ -26,7 +26,7 @@ class ExampleSelector(BaseModel, ABC):
|
|||||||
return need_use
|
return need_use
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __one_show_context(self) -> List:
|
def __one_show_context(self) -> dict:
|
||||||
"""
|
"""
|
||||||
Use one examples
|
Use one examples
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -86,6 +86,11 @@ class BaseChat(ABC):
|
|||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def __init_history_message(self):
|
||||||
|
self.history_message == self.memory.messages()
|
||||||
|
if not self.history_message:
|
||||||
|
self.memory.create(self.current_user_input, "")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
raise NotImplementedError("Not supported for this chat type.")
|
raise NotImplementedError("Not supported for this chat type.")
|
||||||
@ -102,12 +107,9 @@ class BaseChat(ABC):
|
|||||||
### Chat sequence advance
|
### Chat sequence advance
|
||||||
self.current_message.chat_order = len(self.history_message) + 1
|
self.current_message.chat_order = len(self.history_message) + 1
|
||||||
self.current_message.add_user_message(self.current_user_input)
|
self.current_message.add_user_message(self.current_user_input)
|
||||||
self.current_message.start_date = datetime.datetime.now().strftime(
|
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
"%Y-%m-%d %H:%M:%S"
|
|
||||||
)
|
|
||||||
# TODO
|
|
||||||
self.current_message.tokens = 0
|
|
||||||
|
|
||||||
|
self.current_message.tokens = 0
|
||||||
if self.prompt_template.template:
|
if self.prompt_template.template:
|
||||||
current_prompt = self.prompt_template.format(**input_values)
|
current_prompt = self.prompt_template.format(**input_values)
|
||||||
self.current_message.add_system_message(current_prompt)
|
self.current_message.add_system_message(current_prompt)
|
||||||
@ -146,8 +148,8 @@ class BaseChat(ABC):
|
|||||||
self.current_message.add_view_message(
|
self.current_message.add_view_message(
|
||||||
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
||||||
)
|
)
|
||||||
### store current conversation
|
### store current conversation
|
||||||
self.memory.append(self.current_message)
|
self.memory.append(self.current_message)
|
||||||
|
|
||||||
def nostream_call(self):
|
def nostream_call(self):
|
||||||
payload = self.__call_base()
|
payload = self.__call_base()
|
||||||
|
@ -65,6 +65,12 @@ class ChatDashboard(BaseChat):
|
|||||||
try:
|
try:
|
||||||
datas = self.database.run(self.db_connect, chart_item.sql)
|
datas = self.database.run(self.db_connect, chart_item.sql)
|
||||||
chart_data: ChartData = ChartData()
|
chart_data: ChartData = ChartData()
|
||||||
|
chart_data.chart_sql = chart_item['sql']
|
||||||
|
chart_data.chart_type = chart_item['showcase']
|
||||||
|
chart_data.chart_name = chart_item['title']
|
||||||
|
chart_data.chart_desc = chart_item['thoughts']
|
||||||
|
chart_data.column_name = datas[0]
|
||||||
|
chart_data.values =datas
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO 修复流程
|
# TODO 修复流程
|
||||||
print(str(e))
|
print(str(e))
|
||||||
|
@ -4,7 +4,9 @@ from typing import TypeVar, Union, List, Generic, Any
|
|||||||
|
|
||||||
class ChartData(BaseModel):
|
class ChartData(BaseModel):
|
||||||
chart_uid: str
|
chart_uid: str
|
||||||
|
chart_name: str
|
||||||
chart_type: str
|
chart_type: str
|
||||||
|
chart_desc: str
|
||||||
chart_sql: str
|
chart_sql: str
|
||||||
column_name: List
|
column_name: List
|
||||||
values: List
|
values: List
|
||||||
|
@ -54,4 +54,5 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_action(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
|
print(f"do_action:{prompt_response}")
|
||||||
return self.database.run(self.db_connect, prompt_response.sql)
|
return self.database.run(self.db_connect, prompt_response.sql)
|
||||||
|
37
pilot/scene/chat_db/auto_execute/example.py
Normal file
37
pilot/scene/chat_db/auto_execute/example.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from pilot.prompts.example_base import ExampleSelector
|
||||||
|
from pilot.common.schema import ExampleType
|
||||||
|
## Two examples are defined by default
|
||||||
|
EXAMPLES = [
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"type": "human", "data": {"content": "查询用户test1所在的城市", "example": True}},
|
||||||
|
{
|
||||||
|
"type": "ai",
|
||||||
|
"data": {
|
||||||
|
"content": """{
|
||||||
|
\"thoughts\": \"thought text\",
|
||||||
|
\"sql\": \"SELECT city FROM users where user_name='test1'\",
|
||||||
|
}""",
|
||||||
|
"example": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"type": "human", "data": {"content": "查询成都的用户的订单信息", "example": True}},
|
||||||
|
{
|
||||||
|
"type": "ai",
|
||||||
|
"data": {
|
||||||
|
"content": """{
|
||||||
|
\"thoughts\": \"thought text\",
|
||||||
|
\"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
|
||||||
|
}""",
|
||||||
|
"example": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
sql_data_example = ExampleSelector(examples_record=EXAMPLES, use_example=True, type=ExampleType.ONE_SHOT.value)
|
@ -6,8 +6,9 @@ import pandas as pd
|
|||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
class SqlAction(NamedTuple):
|
class SqlAction(NamedTuple):
|
||||||
sql: str
|
sql: str
|
||||||
thoughts: Dict
|
thoughts: Dict
|
||||||
@ -32,11 +33,16 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
if len(data) <= 1:
|
if len(data) <= 1:
|
||||||
data.insert(0, ["result"])
|
data.insert(0, ["result"])
|
||||||
df = pd.DataFrame(data[1:], columns=data[0])
|
df = pd.DataFrame(data[1:], columns=data[0])
|
||||||
table_style = """<style>
|
if CFG.NEW_SERVER_MODE:
|
||||||
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
|
html = df.to_html(index=False, escape=False, sparsify=False)
|
||||||
</style>"""
|
html = ''.join(html.split())
|
||||||
html_table = df.to_html(index=False, escape=False)
|
else:
|
||||||
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
|
table_style = """<style>
|
||||||
|
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
|
||||||
|
</style>"""
|
||||||
|
html_table = df.to_html(index=False, escape=False)
|
||||||
|
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
|
||||||
|
|
||||||
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
|
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
|
||||||
return view_text
|
return view_text
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import importlib
|
|
||||||
from pilot.prompts.prompt_new import PromptTemplate
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
|
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
|
||||||
from pilot.common.schema import SeparatorStyle
|
from pilot.common.schema import SeparatorStyle
|
||||||
|
from pilot.scene.chat_db.auto_execute.example import sql_data_example
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -12,35 +12,21 @@ PROMPT_SCENE_DEFINE = None
|
|||||||
|
|
||||||
|
|
||||||
_DEFAULT_TEMPLATE = """
|
_DEFAULT_TEMPLATE = """
|
||||||
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
You are a SQL expert. Given an input question, create a syntactically correct {dialect} query.
|
||||||
|
|
||||||
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
|
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
|
||||||
Use as few tables as possible when querying.
|
Use as few tables as possible when querying.
|
||||||
When generating insert, delete, update, or replace SQL, please make sure to use the data given by the human, and cannot use any unknown data. If you do not get enough information, speak to user: I don’t have enough data complete your request.
|
Only use the following tables schema to generate sql:
|
||||||
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
PROMPT_SUFFIX = """Only use the following tables generate sql:
|
|
||||||
{table_info}
|
{table_info}
|
||||||
|
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
||||||
|
|
||||||
Question: {input}
|
Question: {input}
|
||||||
|
|
||||||
"""
|
Rrespond in JSON format as following format:
|
||||||
|
|
||||||
PROMPT_RESPONSE = """You must respond in JSON format as following format:
|
|
||||||
{response}
|
{response}
|
||||||
|
|
||||||
Ensure the response is correct json and can be parsed by Python json.loads
|
Ensure the response is correct json and can be parsed by Python json.loads
|
||||||
"""
|
"""
|
||||||
|
|
||||||
RESPONSE_FORMAT = {
|
|
||||||
"thoughts": {
|
|
||||||
"reasoning": "reasoning",
|
|
||||||
"speak": "thoughts summary to say to user",
|
|
||||||
},
|
|
||||||
"sql": "SQL Query to run",
|
|
||||||
}
|
|
||||||
|
|
||||||
RESPONSE_FORMAT_SIMPLE = {
|
RESPONSE_FORMAT_SIMPLE = {
|
||||||
"thoughts": "thoughts summary to say to user",
|
"thoughts": "thoughts summary to say to user",
|
||||||
"sql": "SQL Query to run",
|
"sql": "SQL Query to run",
|
||||||
@ -55,10 +41,11 @@ prompt = PromptTemplate(
|
|||||||
input_variables=["input", "table_info", "dialect", "top_k", "response"],
|
input_variables=["input", "table_info", "dialect", "top_k", "response"],
|
||||||
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
|
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
|
||||||
template_define=PROMPT_SCENE_DEFINE,
|
template_define=PROMPT_SCENE_DEFINE,
|
||||||
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
|
template=_DEFAULT_TEMPLATE,
|
||||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||||
output_parser=DbChatOutputParser(
|
output_parser=DbChatOutputParser(
|
||||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||||
),
|
),
|
||||||
|
example_selector=sql_data_example
|
||||||
)
|
)
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
||||||
|
@ -63,6 +63,7 @@ class ChatWithPlugin(BaseChat):
|
|||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_action(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
|
print(f"do_action:{prompt_response}")
|
||||||
## plugin command run
|
## plugin command run
|
||||||
return execute_command(
|
return execute_command(
|
||||||
str(prompt_response.command.get("name")),
|
str(prompt_response.command.get("name")),
|
||||||
|
@ -2,8 +2,38 @@ from pilot.prompts.example_base import ExampleSelector
|
|||||||
|
|
||||||
## Two examples are defined by default
|
## Two examples are defined by default
|
||||||
EXAMPLES = [
|
EXAMPLES = [
|
||||||
[{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
|
{
|
||||||
[{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
|
"messages": [
|
||||||
|
{"type": "human", "data": {"content": "查询xxx", "example": True}},
|
||||||
|
{
|
||||||
|
"type": "ai",
|
||||||
|
"data": {
|
||||||
|
"content": """{
|
||||||
|
\"thoughts\": \"thought text\",
|
||||||
|
\"speak\": \"thoughts summary to say to user\",
|
||||||
|
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
|
||||||
|
}""",
|
||||||
|
"example": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"type": "human", "data": {"content": "查询xxx", "example": True}},
|
||||||
|
{
|
||||||
|
"type": "ai",
|
||||||
|
"data": {
|
||||||
|
"content": """{
|
||||||
|
\"thoughts\": \"thought text\",
|
||||||
|
\"speak\": \"thoughts summary to say to user\",
|
||||||
|
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
|
||||||
|
}""",
|
||||||
|
"example": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
|
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
|
||||||
|
@ -9,6 +9,7 @@ from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
|
|||||||
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
|
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
|
||||||
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
|
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
|
||||||
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
|
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
|
||||||
|
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
|
||||||
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
|
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ from pilot.configs.model_config import (
|
|||||||
LOGDIR,
|
LOGDIR,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.scene.chat_knowledge.default.prompt import prompt
|
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
||||||
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
@ -12,6 +12,7 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
|||||||
|
|
||||||
|
|
||||||
class NormalChatOutputParser(BaseOutputParser):
|
class NormalChatOutputParser(BaseOutputParser):
|
||||||
|
|
||||||
def parse_prompt_response(self, model_out_text) -> T:
|
def parse_prompt_response(self, model_out_text) -> T:
|
||||||
return model_out_text
|
return model_out_text
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user