支持插件的prompt

This commit is contained in:
tuyang.yhj
2023-05-12 17:22:13 +08:00
parent 5f1d327901
commit 7e1f95d19d
8 changed files with 59 additions and 53 deletions

View File

@@ -161,6 +161,8 @@ def post_process_code(code):
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
# MOCk
autogpt = True
print("是否是AUTO-GPT模式.", autogpt)
start_tstamp = time.time()
model_name = LLM_MODEL
@@ -175,25 +177,37 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
# TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
if len(state.messages) == state.offset + 2:
# 第一轮对话需要加入提示Prompt
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
new_state.conv_id = uuid.uuid4().hex
query = state.messages[-2][1]
# 第一轮对话需要加入提示Prompt
if(autogpt):
# autogpt模式的第一轮对话需要 构建专属prompt
cfg = Config()
first_prompt = FirstPrompt()
first_prompt.command_registry = cfg.command_registry
# prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文?
# 如果用户侧的问题跨度很大, 应该每一轮都加提示。
if db_selector:
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
new_state.append_message(new_state.roles[1], None)
state = new_state
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
logger.info("[TEST]:" + system_prompt)
template_name = "auto_dbgpt_one_shot"
new_state = conv_templates[template_name].copy()
new_state.append_message(role='USER', message=system_prompt)
else:
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
state = new_state
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
new_state.conv_id = uuid.uuid4().hex
if not autogpt:
# prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文?
# 如果用户侧的问题跨度很大, 应该每一轮都加提示。
if db_selector:
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
new_state.append_message(new_state.roles[1], None)
state = new_state
else:
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
state = new_state
if mode == conversation_types["default_knownledge"] and not db_selector:
query = state.messages[-2][1]
knqa = KnownLedgeBaseQA()
@@ -457,8 +471,6 @@ if __name__ == "__main__":
# 加载插件
cfg = Config()
cfg.plugins_dir = "123"
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# 加载插件可执行命令
@@ -474,15 +486,9 @@ if __name__ == "__main__":
for command_category in command_categories:
command_registry.import_commands(command_category)
cfg.command_registry =command_category
first_prompt =FirstPrompt(cfg= cfg)
first_prompt.command_registry = command_registry
system_prompt = first_prompt.construct_first_prompt( fisrt_message=["this is a test goal"])
logger.info("[TEST]:" + system_prompt)
logger.info(args)
demo = build_webdemo()
demo.queue(