同步服务接口 (sync service interface)

tuyang.yhj 2023-05-13 21:24:05 +08:00
parent d6d84c70b0
commit bde8503620
8 changed files with 151 additions and 22 deletions

View File

@@ -28,7 +28,7 @@ VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 4096
-VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
+VICUNA_MODEL_SERVER = "http://120.27.148.250:8000"

 # Load model config
 ISLOAD_8BIT = True
@@ -37,7 +37,7 @@ ISDEBUG = False
 DB_SETTINGS = {
     "user": "root",
-    "password": "aa12345678",
+    "password": "aa123456",
     "host": "127.0.0.1",
     "port": 3306
 }
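
DB_SETTINGS above is the MySQL connection config that the rest of the project reads. For reference, a minimal sketch of turning such a dict into a connection, assuming pymysql (this commit does not itself touch the connection code):

    # Illustrative only; pymysql and a reachable MySQL instance are assumed.
    import pymysql

    def get_connection(settings, database="mysql"):
        # Open a MySQL connection from a DB_SETTINGS-style dict.
        return pymysql.connect(
            host=settings["host"],
            port=settings["port"],
            user=settings["user"],
            password=settings["password"],
            database=database,
        )

    conn = get_connection(DB_SETTINGS)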

View File

@@ -159,6 +159,8 @@ auto_dbgpt_one_shot = Conversation(
 1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
 2. No user assistance
 3. Exclusively use the commands listed in double quotes e.g. "command name"
+DBScheme:
+
 Commands:
 1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
@@ -249,6 +251,11 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回
 default_conversation = conv_one_shot

+conversation_sql_mode ={
+    "auto_execute_ai_response": "直接执行结果",
+    "dont_execute_ai_response": "不直接执行结果"
+}
+
 conversation_types = {
     "native": "LLM原生对话",
     "default_knownledge": "默认知识库对话",

View File

@@ -69,11 +69,99 @@ def generate_stream(model, tokenizer, params, device,
     del past_key_values

 @torch.inference_mode()
-def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
+def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
+    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
+    prompt = params["prompt"]
+    l_prompt = len(prompt)
+    temperature = float(params.get("temperature", 1.0))
+    max_new_tokens = int(params.get("max_new_tokens", 2048))
+    stop_parameter = params.get("stop", None)
+    if stop_parameter == tokenizer.eos_token:
+        stop_parameter = None
+    stop_strings = []
+    if isinstance(stop_parameter, str):
+        stop_strings.append(stop_parameter)
+    elif isinstance(stop_parameter, list):
+        stop_strings = stop_parameter
+    elif stop_parameter is None:
+        pass
+    else:
+        raise TypeError("Stop parameter must be string or list of strings.")
+
+    input_ids = tokenizer(prompt).input_ids
+    output_ids = list(input_ids)
+
+    max_src_len = context_len - max_new_tokens - 8
+    input_ids = input_ids[-max_src_len:]
+
+    for i in range(max_new_tokens):
+        if i == 0:
+            out = model(
+                torch.as_tensor([input_ids], device=device), use_cache=True)
+            logits = out.logits
+            past_key_values = out.past_key_values
+        else:
+            attention_mask = torch.ones(
+                1, past_key_values[0][0].shape[-2] + 1, device=device)
+            out = model(input_ids=torch.as_tensor([[token]], device=device),
+                        use_cache=True,
+                        attention_mask=attention_mask,
+                        past_key_values=past_key_values)
+            logits = out.logits
+            past_key_values = out.past_key_values
+
+        last_token_logits = logits[0][-1]
+
+        if device == "mps":
+            # Switch to CPU by avoiding some bugs in mps backend.
+            last_token_logits = last_token_logits.float().to("cpu")
+
+        if temperature < 1e-4:
+            token = int(torch.argmax(last_token_logits))
+        else:
+            probs = torch.softmax(last_token_logits / temperature, dim=-1)
+            token = int(torch.multinomial(probs, num_samples=1))
+
+        output_ids.append(token)
+
+        if token == tokenizer.eos_token_id:
+            stopped = True
+        else:
+            stopped = False
+
+        output = tokenizer.decode(output_ids, skip_special_tokens=True)
+        print("Partial output:", output)
+        for stop_str in stop_strings:
+            # print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
+            pos = output.rfind(stop_str)
+            if pos != -1:
+                # print("Found stop str: ", output)
+                output = output[:pos]
+                # print("Trimmed output: ", output)
+                stopped = True
+                stop_word = stop_str
+                break
+            else:
+                pass
+                # print("Not found")
+
+        if stopped:
+            break
+
+    del past_key_values
+    if pos != -1:
+        return output[:pos]
+    return output
+
+
+@torch.inference_mode()
+def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
     prompt = params["prompt"]
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_parameter = params.get("stop", None)
     if stop_parameter == tokenizer.eos_token:
         stop_parameter = None
     stop_strings = []
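
The new generate_output decodes a complete answer in one call: it runs the prompt through the model once, then appends one token per step (argmax when the temperature is near zero, otherwise sampled) while reusing past_key_values, and stops on the eos token or on any configured stop string. A rough sketch of calling it, assuming a HuggingFace causal LM and tokenizer are already loaded (the names below are placeholders):

    # Illustrative call only; model, tokenizer and the device come from the serving code.
    params = {
        "prompt": "### Human: Write a SQL query that counts orders per day.\n### Assistant:",
        "temperature": 0.7,
        "max_new_tokens": 256,
        "stop": "###",   # a single stop string or a list of stop strings
    }
    answer = generate_output(model, tokenizer, params, device="cuda")
    print(answer)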

View File

@@ -39,6 +39,7 @@ class FirstPrompt:
     def construct_first_prompt(
         self,
         fisrt_message: [str]=[],
+        db_schemes: str=None,
         prompt_generator: Optional[PromptGenerator] = None
     ) -> str:
         """
@@ -88,6 +89,10 @@ class FirstPrompt:
         self.ai_goals = fisrt_message
         for i, goal in enumerate(self.ai_goals):
             full_prompt += f"{i+1}. {goal}\n"
+
+        if db_schemes:
+            full_prompt += f"DB SCHEME:\n\n"
+            full_prompt += f"{db_schemes}\n"
         # if self.api_budget > 0.0:
         #     full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
         self.prompt_generator = prompt_generator
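
With the new db_schemes argument, the first-round auto-mode prompt carries the schema of the selected database right after the enumerated goals, so the model can generate SQL against real tables. For a single goal, the appended part of the prompt roughly reads as follows (the create-table line stands in for whatever gen_sqlgen_conversation returns):

    1. Count the users registered in the last week.

    DB SCHEME:

    create table users (id int primary key, name varchar(64), created_at datetime);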

View File

@@ -93,14 +93,9 @@ async def api_generate_stream(request: Request):
     return StreamingResponse(generator, background=background_tasks)

 @app.post("/generate")
-def generate(prompt_request: PromptRequest):
-    params = {
-        "prompt": prompt_request.prompt,
-        "temperature": prompt_request.temperature,
-        "max_new_tokens": prompt_request.max_new_tokens,
-        "stop": prompt_request.stop
-    }
+def generate(prompt_request: Request):
+    params = request.json()
     print("Receive prompt: ", params["prompt"])
     output = generate_output(model, tokenizer, params, DEVICE)
     print("Output: ", output)

View File

@ -21,11 +21,15 @@ from pilot.plugins import scan_plugins
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.first_conversation_prompt import FirstPrompt from pilot.prompts.first_conversation_prompt import FirstPrompt
from pilot.prompts.generator import PromptGenerator
from pilot.commands.exception_not_commands import NotCommands
from pilot.conversation import ( from pilot.conversation import (
default_conversation, default_conversation,
conv_templates, conv_templates,
conversation_types, conversation_types,
conversation_sql_mode,
SeparatorStyle SeparatorStyle
) )
@@ -152,11 +156,13 @@ def post_process_code(code):
     code = sep.join(blocks)
     return code

-def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
-    # MOCk
-    autogpt = True
+def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+        print("AUTO DB-GPT模式.")
+    if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
+        print("标准DB-GPT模式.")

     print("是否是AUTO-GPT模式.", autogpt)

     start_tstamp = time.time()
     model_name = LLM_MODEL
@@ -167,17 +173,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return

+    cfg = Config()
+    first_prompt = FirstPrompt()
+
     # TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
     if len(state.messages) == state.offset + 2:
         query = state.messages[-2][1]
         # 第一轮对话需要加入提示Prompt
-        cfg = Config()
-        first_prompt = FirstPrompt()
         first_prompt.command_registry = cfg.command_registry

-        if(autogpt):
+        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
             # autogpt模式的第一轮对话需要 构建专属prompt
-            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
+            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
             logger.info("[TEST]:" + system_prompt)
             template_name = "auto_dbgpt_one_shot"
             new_state = conv_templates[template_name].copy()
@@ -218,7 +225,13 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
         "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
     }
     logger.info(f"Requert: \n{payload}")

+    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+        auto_db_gpt_response(first_prompt.prompt_generator, payload)
+    else:
+        stream_ai_response(payload)
+
+    def stream_ai_response(payload):
+        # 流式输出
     state.messages[-1][-1] = ""
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
@@ -264,6 +277,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
         }
         fout.write(json.dumps(data) + "\n")

+
+def auto_db_gpt_response( prompt: PromptGenerator, payload)->str:
+    response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
+                             headers=headers, json=payload, timeout=30)
+    print(response)
+    try:
+        plugin_resp = execute_ai_response_json(prompt, response)
+        print(plugin_resp)
+    except NotCommands as e:
+        print(str(e))
+    return "auto_db_gpt_response!"
+
 block_css = (
     code_highlight_css
     + """
@@ -280,6 +305,12 @@ block_css = (
     """
 )

+def change_sql_mode(sql_mode):
+    if sql_mode in ["直接执行结果"]:
+        return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
+
 def change_mode(mode):
     if mode in ["默认知识库对话", "LLM原生对话"]:
         return gr.update(visible=False)
@@ -325,6 +356,9 @@ def build_single_model_ui():
     tabs= gr.Tabs()
     with tabs:
         tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
+        sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
+        sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
+        sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
         with tab_sql:
             # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
@@ -383,7 +417,7 @@ def build_single_model_ui():
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -392,7 +426,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
@@ -400,7 +434,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list
     )

Binary file not shown.