Sync service interface

tuyang.yhj 2023-05-13 21:24:05 +08:00
parent d6d84c70b0
commit bde8503620
8 changed files with 151 additions and 22 deletions

View File

@@ -28,7 +28,7 @@ VECTOR_SEARCH_TOP_K = 3
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 4096
VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
VICUNA_MODEL_SERVER = "http://120.27.148.250:8000"
# Load model config
ISLOAD_8BIT = True
@@ -37,7 +37,7 @@ ISDEBUG = False
DB_SETTINGS = {
"user": "root",
"password": "aa12345678",
"password": "aa123456",
"host": "127.0.0.1",
"port": 3306
}
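
The two settings changed above are consumed elsewhere in the project. The following minimal sketch is not part of this commit; it only shows one plausible way they might be used, assuming the pymysql driver and the requests library:

import pymysql
import requests

# Values mirror the config shown above.
DB_SETTINGS = {"user": "root", "password": "aa123456", "host": "127.0.0.1", "port": 3306}
VICUNA_MODEL_SERVER = "http://120.27.148.250:8000"

# Open a MySQL connection with the configured credentials.
conn = pymysql.connect(
    host=DB_SETTINGS["host"],
    port=DB_SETTINGS["port"],
    user=DB_SETTINGS["user"],
    password=DB_SETTINGS["password"],
)

# Ask the model server for a completion (the /generate route appears later in this commit).
resp = requests.post(f"{VICUNA_MODEL_SERVER}/generate", json={"prompt": "SELECT 1;"}, timeout=30)
print(resp.text)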

View File

@@ -159,7 +159,9 @@ auto_dbgpt_one_shot = Conversation(
1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
2. No user assistance
3. Exclusively use the commands listed in double quotes e.g. "command name"
DBScheme:
Commands:
1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
2. execute_python_file: Execute Python File, args: "filename": "<filename>"
@@ -249,6 +251,11 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回
default_conversation = conv_one_shot
conversation_sql_mode ={
"auto_execute_ai_response": "直接执行结果",
"dont_execute_ai_response": "不直接执行结果"
}
conversation_types = {
"native": "LLM原生对话",
"default_knownledge": "默认知识库对话",

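conversation_sql_mode maps internal keys to the labels shown in the web UI. A small helper like the hypothetical sketch below (not in the commit) could map a selected label back to its key before branching:

from pilot.conversation import conversation_sql_mode

# Hypothetical helper: resolve the UI label back to the conversation_sql_mode key.
def resolve_sql_mode(label: str) -> str:
    for key, value in conversation_sql_mode.items():
        if value == label:
            return key
    # Fall back to the safe, non-executing mode.
    return "dont_execute_ai_response"
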
View File

@@ -69,11 +69,99 @@ def generate_stream(model, tokenizer, params, device,
del past_key_values
@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
prompt = params["prompt"]
l_prompt = len(prompt)
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_parameter = params.get("stop", None)
if stop_parameter == tokenizer.eos_token:
stop_parameter = None
stop_strings = []
if isinstance(stop_parameter, str):
stop_strings.append(stop_parameter)
elif isinstance(stop_parameter, list):
stop_strings = stop_parameter
elif stop_parameter is None:
pass
else:
raise TypeError("Stop parameter must be string or list of strings.")
input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids)
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
for i in range(max_new_tokens):
if i == 0:
out = model(
torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device=device)
out = model(input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
if device == "mps":
# Switch to CPU by avoiding some bugs in mps backend.
last_token_logits = last_token_logits.float().to("cpu")
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
if token == tokenizer.eos_token_id:
stopped = True
else:
stopped = False
output = tokenizer.decode(output_ids, skip_special_tokens=True)
print("Partial output:", output)
for stop_str in stop_strings:
# print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
pos = output.rfind(stop_str)
if pos != -1:
# print("Found stop str: ", output)
output = output[:pos]
# print("Trimmed output: ", output)
stopped = True
stop_word = stop_str
break
else:
pass
# print("Not found")
if stopped:
break
del past_key_values
if pos != -1:
return output[:pos]
return output
@torch.inference_mode()
def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_parameter = params.get("stop", None)
if stop_parameter == tokenizer.eos_token:
stop_parameter = None
stop_strings = []

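Since the hunk above rewrites generate_output around an explicit stop parameter, a call site would look roughly like the hedged sketch below (not from this commit). The function reads prompt, temperature, max_new_tokens and stop from the params dict, and the final trimming step expects at least one stop string to have been supplied:

# Hedged usage sketch for the rewritten generate_output.
params = {
    "prompt": "### Human: Write a SQL query that counts users.\n### Assistant:",
    "temperature": 0.7,
    "max_new_tokens": 256,
    "stop": "### Human:",   # trimmed off the tail of the completion if generated
}
answer = generate_output(model, tokenizer, params, device="cuda")
print(answer)
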
View File

@@ -39,6 +39,7 @@ class FirstPrompt:
def construct_first_prompt(
self,
fisrt_message: [str]=[],
db_schemes: str=None,
prompt_generator: Optional[PromptGenerator] = None
) -> str:
"""
@@ -88,6 +89,10 @@ class FirstPrompt:
self.ai_goals = fisrt_message
for i, goal in enumerate(self.ai_goals):
full_prompt += f"{i+1}. {goal}\n"
if db_schemes:
full_prompt += f"DB SCHEME:\n\n"
full_prompt += f"{db_schemes}\n"
# if self.api_budget > 0.0:
# full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
self.prompt_generator = prompt_generator

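The new db_schemes argument lets the first prompt embed a database schema. A hypothetical call, mirroring how webserver.py uses it later in this commit but with made-up schema text, might look like this:

# Hypothetical call; schema_text stands in for gen_sqlgen_conversation(dbname).
first_prompt = FirstPrompt()
first_prompt.command_registry = cfg.command_registry
schema_text = "users(id INT, name VARCHAR(64)); orders(id INT, user_id INT, amount DECIMAL(10,2))"
system_prompt = first_prompt.construct_first_prompt(
    fisrt_message=["Count the number of orders per user"],  # parameter name as spelled in the source
    db_schemes=schema_text,
)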
View File

@@ -93,14 +93,9 @@ async def api_generate_stream(request: Request):
return StreamingResponse(generator, background=background_tasks)
@app.post("/generate")
def generate(prompt_request: PromptRequest):
params = {
"prompt": prompt_request.prompt,
"temperature": prompt_request.temperature,
"max_new_tokens": prompt_request.max_new_tokens,
"stop": prompt_request.stop
}
def generate(prompt_request: Request):
params = request.json()
print("Receive prompt: ", params["prompt"])
output = generate_output(model, tokenizer, params, DEVICE)
print("Output: ", output)

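With the handler above now reading the request body directly, a client posts the same JSON fields the old PromptRequest carried. A minimal client-side sketch (an assumption, not part of the commit) using requests:

import requests

payload = {
    "prompt": "### Human: Hello\n### Assistant:",
    "temperature": 0.7,
    "max_new_tokens": 128,
    "stop": "###",
}
# VICUNA_MODEL_SERVER from the config change at the top of this commit.
resp = requests.post("http://120.27.148.250:8000/generate", json=payload, timeout=30)
print(resp.text)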
View File

@@ -21,11 +21,15 @@ from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.first_conversation_prompt import FirstPrompt
from pilot.prompts.generator import PromptGenerator
from pilot.commands.exception_not_commands import NotCommands
from pilot.conversation import (
default_conversation,
conv_templates,
conversation_types,
conversation_sql_mode,
SeparatorStyle
)
@ -152,11 +156,13 @@ def post_process_code(code):
code = sep.join(blocks)
return code
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
# MOCk
autogpt = True
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
print("AUTO DB-GPT模式.")
if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
print("标准DB-GPT模式.")
print("是否是AUTO-GPT模式.", autogpt)
start_tstamp = time.time()
model_name = LLM_MODEL
@@ -167,17 +173,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
cfg = Config()
first_prompt = FirstPrompt()
# TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
if len(state.messages) == state.offset + 2:
query = state.messages[-2][1]
# The first round of conversation needs the prompt prepended
cfg = Config()
first_prompt = FirstPrompt()
first_prompt.command_registry = cfg.command_registry
if(autogpt):
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
# In autogpt mode, the first round of conversation needs a dedicated prompt
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
logger.info("[TEST]:" + system_prompt)
template_name = "auto_dbgpt_one_shot"
new_state = conv_templates[template_name].copy()
@@ -218,7 +225,13 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
}
logger.info(f"Requert: \n{payload}")
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
auto_db_gpt_response(first_prompt.prompt_generator, payload)
else:
stream_ai_response(payload)
def stream_ai_response(payload):
# Streaming output
state.messages[-1][-1] = ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
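
The hunk above splits the response handling into two paths. In outline (a simplified sketch, not the literal code), auto mode makes a blocking call and lets the plugin layer act on the model's JSON answer, while the default path is a generator that streams partial chat updates:

# Simplified outline of the dispatch introduced above (sketch only).
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
    # Blocking: POST the payload to the model server, then run the returned
    # command through execute_ai_response_json.
    result = auto_db_gpt_response(first_prompt.prompt_generator, payload)
else:
    # Generator: yields (state, chatbot, buttons) tuples as tokens arrive.
    for update in stream_ai_response(payload):
        pass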
@@ -264,6 +277,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
}
fout.write(json.dumps(data) + "\n")
def auto_db_gpt_response( prompt: PromptGenerator, payload)->str:
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
headers=headers, json=payload, timeout=30)
print(response)
try:
plugin_resp = execute_ai_response_json(prompt, response)
print(plugin_resp)
except NotCommands as e:
print(str(e))
return "auto_db_gpt_response!"
block_css = (
code_highlight_css
+ """
@@ -280,6 +305,12 @@ block_css = (
"""
)
def change_sql_mode(sql_mode):
if sql_mode in ["直接执行结果"]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_mode(mode):
if mode in ["默认知识库对话", "LLM原生对话"]:
return gr.update(visible=False)
@@ -325,6 +356,9 @@ def build_single_model_ui():
tabs= gr.Tabs()
with tabs:
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
with tab_sql:
# TODO A selector to choose database
with gr.Row(elem_id="db_selector"):
@@ -383,7 +417,7 @@ def build_single_model_ui():
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, mode, db_selector, temperature, max_output_tokens],
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -392,7 +426,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, mode, db_selector, temperature, max_output_tokens],
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
@@ -400,7 +434,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, mode, db_selector, temperature, max_output_tokens],
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list
)
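
The new sql_mode radio button, the Markdown hint and change_sql_mode together implement a show/hide pattern in Gradio. The self-contained sketch below is not part of the commit; it reuses the same label strings to demonstrate the pattern in isolation:

import gradio as gr

def change_sql_mode(sql_mode):
    # Show the hint only when auto-execution is selected.
    if sql_mode == "直接执行结果":
        return gr.update(visible=True)
    return gr.update(visible=False)

with gr.Blocks() as demo:
    sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
    sql_vs_setting = gr.Markdown(
        "自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力",
        visible=False,
    )
    # Toggle the hint whenever the radio selection changes.
    sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)

demo.launch()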

Binary file not shown.