mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 18:17:45 +00:00
update:merge dev
This commit is contained in:
@@ -18,9 +18,10 @@ import requests
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
from pilot.commands.command import execute_ai_response_json
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.commands.exception_not_commands import NotCommands
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
@@ -29,7 +30,6 @@ from pilot.configs.model_config import (
|
||||
LOGDIR,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
from pilot.server.vectordb_qa import KnownLedgeBaseQA
|
||||
from pilot.connections.mysql import MySQLOperator
|
||||
from pilot.conversation import (
|
||||
SeparatorStyle,
|
||||
@@ -41,15 +41,22 @@ from pilot.conversation import (
|
||||
)
|
||||
from pilot.plugins import scan_plugins
|
||||
from pilot.prompts.auto_mode_prompt import AutoModePrompt
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from pilot.server.gradio_css import code_highlight_css
|
||||
from pilot.server.gradio_patch import Chatbot as grChatbot
|
||||
from pilot.server.vectordb_qa import KnownLedgeBaseQA
|
||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||
from pilot.utils import build_logger, server_error_msg
|
||||
from pilot.vector_store.extract_tovec import (
|
||||
get_vector_storelist,
|
||||
knownledge_tovec_st,
|
||||
load_knownledge_from_doc,
|
||||
)
|
||||
|
||||
from pilot.commands.command import execute_ai_response_json
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
|
||||
@@ -69,6 +76,7 @@ priority = {"vicuna-13b": "aaa"}
|
||||
|
||||
# 加载插件
|
||||
CFG = Config()
|
||||
CHAT_FACTORY = ChatFactory()
|
||||
|
||||
DB_SETTINGS = {
|
||||
"user": CFG.LOCAL_DB_USER,
|
||||
@@ -125,6 +133,10 @@ def load_demo(url_params, request: gr.Request):
|
||||
gr.Dropdown.update(choices=dbs)
|
||||
|
||||
state = default_conversation.copy()
|
||||
|
||||
unique_id = uuid.uuid1()
|
||||
state.conv_id = str(unique_id)
|
||||
|
||||
return (
|
||||
state,
|
||||
dropdown_update,
|
||||
@@ -166,6 +178,8 @@ def add_text(state, text, request: gr.Request):
|
||||
state.append_message(state.roles[0], text)
|
||||
state.append_message(state.roles[1], None)
|
||||
state.skip_next = False
|
||||
### TODO
|
||||
state.last_user_input = text
|
||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||
|
||||
|
||||
@@ -180,18 +194,42 @@ def post_process_code(code):
|
||||
return code
|
||||
|
||||
|
||||
def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
|
||||
if mode == conversation_types["default_knownledge"] and not db_selector:
|
||||
return ChatScene.ChatKnowledge
|
||||
elif mode == conversation_types["custome"] and not db_selector:
|
||||
return ChatScene.ChatNewKnowledge
|
||||
elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
|
||||
return ChatScene.ChatWithDb
|
||||
|
||||
elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
|
||||
return ChatScene.ChatExecution
|
||||
else:
|
||||
return ChatScene.ChatNormal
|
||||
|
||||
|
||||
def http_bot(
|
||||
state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
|
||||
):
|
||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
||||
print("AUTO DB-GPT模式.")
|
||||
if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
|
||||
print("标准DB-GPT模式.")
|
||||
print("是否是AUTO-GPT模式.", autogpt)
|
||||
|
||||
logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
|
||||
start_tstamp = time.time()
|
||||
scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
|
||||
print(f"当前对话模式:{scene.value}")
|
||||
model_name = CFG.LLM_MODEL
|
||||
|
||||
if ChatScene.ChatWithDb == scene:
|
||||
logger.info("基于DB对话走新的模式!")
|
||||
chat_param = {
|
||||
"chat_session_id": state.conv_id,
|
||||
"db_name": db_selector,
|
||||
"user_input": state.last_user_input,
|
||||
}
|
||||
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
|
||||
chat.call()
|
||||
state.messages[-1][-1] = f"{chat.current_ai_response()}"
|
||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||
|
||||
else:
|
||||
dbname = db_selector
|
||||
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
|
||||
if state.skip_next:
|
||||
@@ -303,7 +341,9 @@ def http_bot(
|
||||
"prompt": prompt,
|
||||
"temperature": float(temperature),
|
||||
"max_new_tokens": int(max_new_tokens),
|
||||
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
|
||||
"stop": state.sep
|
||||
if state.sep_style == SeparatorStyle.SINGLE
|
||||
else state.sep2,
|
||||
}
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
|
||||
@@ -392,23 +432,13 @@ def http_bot(
|
||||
output = data["text"] + f" (error_code: {data['error_code']})"
|
||||
state.messages[-1][-1] = output
|
||||
yield (state, state.to_gradio_chatbot()) + (
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
enable_btn,
|
||||
enable_btn,
|
||||
)
|
||||
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
||||
return
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
|
||||
yield (state, state.to_gradio_chatbot()) + (
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
enable_btn,
|
||||
enable_btn,
|
||||
)
|
||||
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
||||
return
|
||||
|
||||
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
||||
@@ -685,7 +715,8 @@ if __name__ == "__main__":
|
||||
# 配置初始化
|
||||
cfg = Config()
|
||||
|
||||
# dbs = get_database_list()
|
||||
dbs = cfg.local_db.get_database_list()
|
||||
|
||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||
|
||||
# 加载插件可执行命令
|
||||
|
Reference in New Issue
Block a user