diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index fff4ad60d..81b390734 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -28,7 +28,7 @@ VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 4096
-VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
+VICUNA_MODEL_SERVER = "http://120.27.148.250:8000"
 
 # Load model config
 ISLOAD_8BIT = True
@@ -37,7 +37,7 @@ ISDEBUG = False
 
 DB_SETTINGS = {
     "user": "root",
-    "password": "aa12345678",
+    "password": "aa123456",
     "host": "127.0.0.1",
     "port": 3306
 }
\ No newline at end of file
diff --git a/pilot/conversation.py b/pilot/conversation.py
index 5f76c814f..16e33534b 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -159,7 +159,9 @@ auto_dbgpt_one_shot = Conversation(
             1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
             2. No user assistance
             3. Exclusively use the commands listed in double quotes e.g. "command name"
-
+            DBScheme:
+
+
             Commands:
             1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
             2. execute_python_file: Execute Python File, args: "filename": "<file>"
@@ -249,6 +251,11 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回
 
 default_conversation = conv_one_shot
 
+conversation_sql_mode ={
+    "auto_execute_ai_response": "直接执行结果",
+    "dont_execute_ai_response": "不直接执行结果"
+}
+
 conversation_types = {
     "native": "LLM原生对话",
     "default_knownledge": "默认知识库对话",
diff --git a/pilot/model/inference.py b/pilot/model/inference.py
index c62b0e255..109fe7e98 100644
--- a/pilot/model/inference.py
+++ b/pilot/model/inference.py
@@ -69,11 +69,99 @@ def generate_stream(model, tokenizer, params, device,
     del past_key_values
 
 @torch.inference_mode()
-def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
+def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
+    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
+    prompt = params["prompt"]
+    l_prompt = len(prompt)
+    temperature = float(params.get("temperature", 1.0))
+    max_new_tokens = int(params.get("max_new_tokens", 2048))
+    stop_parameter = params.get("stop", None)
+
+    if stop_parameter == tokenizer.eos_token:
+        stop_parameter = None
+    stop_strings = []
+    if isinstance(stop_parameter, str):
+        stop_strings.append(stop_parameter)
+    elif isinstance(stop_parameter, list):
+        stop_strings = stop_parameter
+    elif stop_parameter is None:
+        pass
+    else:
+        raise TypeError("Stop parameter must be string or list of strings.")
+
+
+    input_ids = tokenizer(prompt).input_ids
+    output_ids = list(input_ids)
+
+    max_src_len = context_len - max_new_tokens - 8
+    input_ids = input_ids[-max_src_len:]
+
+    for i in range(max_new_tokens):
+        if i == 0:
+            out = model(
+                torch.as_tensor([input_ids], device=device), use_cache=True)
+            logits = out.logits
+            past_key_values = out.past_key_values
+        else:
+            attention_mask = torch.ones(
+                1, past_key_values[0][0].shape[-2] + 1, device=device)
+            out = model(input_ids=torch.as_tensor([[token]], device=device),
+                        use_cache=True,
+                        attention_mask=attention_mask,
+                        past_key_values=past_key_values)
+            logits = out.logits
+            past_key_values = out.past_key_values
+
+        last_token_logits = logits[0][-1]
+
+        if device == "mps":
+            # Switch to CPU by avoiding some bugs in mps backend.
+            last_token_logits = last_token_logits.float().to("cpu")
+
+        if temperature < 1e-4:
+            token = int(torch.argmax(last_token_logits))
+        else:
+            probs = torch.softmax(last_token_logits / temperature, dim=-1)
+            token = int(torch.multinomial(probs, num_samples=1))
+
+        output_ids.append(token)
+
+        if token == tokenizer.eos_token_id:
+            stopped = True
+        else:
+            stopped = False
+
+
+        output = tokenizer.decode(output_ids, skip_special_tokens=True)
+        print("Partial output:", output)
+        for stop_str in stop_strings:
+            # print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
+            pos = output.rfind(stop_str)
+            if pos != -1:
+                # print("Found stop str: ", output)
+                output = output[:pos]
+                # print("Trimmed output: ", output)
+                stopped = True
+                stop_word = stop_str
+                break
+            else:
+                pass
+                # print("Not found")
+        if stopped:
+            break
+
+    del past_key_values
+    if pos != -1:
+        return output[:pos]
+    return output
+
+@torch.inference_mode()
+def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
     prompt = params["prompt"]
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_parameter = params.get("stop", None)
+
     if stop_parameter == tokenizer.eos_token:
         stop_parameter = None
     stop_strings = []
diff --git a/pilot/prompts/first_conversation_prompt.py b/pilot/prompts/first_conversation_prompt.py
index 9b9afc025..36c7d5c4d 100644
--- a/pilot/prompts/first_conversation_prompt.py
+++ b/pilot/prompts/first_conversation_prompt.py
@@ -39,6 +39,7 @@ class FirstPrompt:
     def construct_first_prompt(
         self,
         fisrt_message: [str]=[],
+        db_schemes: str=None,
         prompt_generator: Optional[PromptGenerator] = None
     ) -> str:
         """
@@ -88,6 +89,10 @@ class FirstPrompt:
             self.ai_goals = fisrt_message
         for i, goal in enumerate(self.ai_goals):
             full_prompt += f"{i+1}. {goal}\n"
+        if db_schemes:
+            full_prompt += f"DB SCHEME:\n\n"
+            full_prompt += f"{db_schemes}\n"
+
         # if self.api_budget > 0.0:
         #     full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
         self.prompt_generator = prompt_generator
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 95781b69b..190571afe 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -93,14 +93,9 @@ async def api_generate_stream(request: Request):
     return StreamingResponse(generator, background=background_tasks)
 
 @app.post("/generate")
-def generate(prompt_request: PromptRequest):
-    params = {
-        "prompt": prompt_request.prompt,
-        "temperature": prompt_request.temperature,
-        "max_new_tokens": prompt_request.max_new_tokens,
-        "stop": prompt_request.stop
-    }
+def generate(prompt_request: Request):
+    params = request.json()
 
     print("Receive prompt: ", params["prompt"])
     output = generate_output(model, tokenizer, params, DEVICE)
     print("Output: ", output)
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index f394bd9e7..40d96b1ef 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -21,11 +21,15 @@ from pilot.plugins import scan_plugins
 from pilot.configs.config import Config
 from pilot.commands.command_mange import CommandRegistry
 from pilot.prompts.first_conversation_prompt import FirstPrompt
+from pilot.prompts.generator import PromptGenerator
+
+from pilot.commands.exception_not_commands import NotCommands
 
 from pilot.conversation import (
     default_conversation,
     conv_templates,
     conversation_types,
+    conversation_sql_mode,
     SeparatorStyle
 )
 
@@ -152,11 +156,13 @@ def post_process_code(code):
         code = sep.join(blocks)
     return code
 
-def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
-
-    # MOCk
-    autogpt = True
+def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+        print("AUTO DB-GPT模式.")
+    if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
+        print("标准DB-GPT模式.")
     print("是否是AUTO-GPT模式.", autogpt)
+
     start_tstamp = time.time()
 
     model_name = LLM_MODEL
@@ -167,17 +173,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
+    cfg = Config()
+    first_prompt = FirstPrompt()
+
     # TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
     if len(state.messages) == state.offset + 2:
         query = state.messages[-2][1]
 
         # 第一轮对话需要加入提示Prompt
-        cfg = Config()
-        first_prompt = FirstPrompt()
         first_prompt.command_registry = cfg.command_registry
-        if(autogpt):
+        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
             # autogpt模式的第一轮对话需要 构建专属prompt
-            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
+            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
             logger.info("[TEST]:" + system_prompt)
             template_name = "auto_dbgpt_one_shot"
             new_state = conv_templates[template_name].copy()
@@ -218,7 +225,13 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
         "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
     }
     logger.info(f"Requert: \n{payload}")
 
+    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+        auto_db_gpt_response(first_prompt.prompt_generator, payload)
+    else:
+        stream_ai_response(payload)
+
+def stream_ai_response(payload):
     # 流式输出
     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
@@ -264,6 +277,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
         }
         fout.write(json.dumps(data) + "\n")
 
+
+def auto_db_gpt_response( prompt: PromptGenerator, payload)->str:
+    response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
+                             headers=headers, json=payload, timeout=30)
+    print(response)
+    try:
+        plugin_resp = execute_ai_response_json(prompt, response)
+        print(plugin_resp)
+    except NotCommands as e:
+        print(str(e))
+    return "auto_db_gpt_response!"
+
 block_css = (
     code_highlight_css
     + """
@@ -280,6 +305,12 @@ block_css = (
 """
 )
 
+def change_sql_mode(sql_mode):
+    if sql_mode in ["直接执行结果"]:
+        return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
+
 def change_mode(mode):
     if mode in ["默认知识库对话", "LLM原生对话"]:
         return gr.update(visible=False)
@@ -325,6 +356,9 @@ def build_single_model_ui():
     tabs= gr.Tabs()
     with tabs:
         tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
+        sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
+        sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
+        sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
         with tab_sql:
             # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
@@ -383,7 +417,7 @@ def build_single_model_ui():
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -392,7 +426,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
 
@@ -400,7 +434,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list
     )
 
diff --git a/plugins/DB-GPT-SQL-Execution-Plugin.zip b/plugins/DB-GPT-SQL-Execution-Plugin.zip
new file mode 100644
index 000000000..5327a3514
Binary files /dev/null and b/plugins/DB-GPT-SQL-Execution-Plugin.zip differ
diff --git a/plugins/Db-GPT-SimpleChart-Plugin.zip b/plugins/Db-GPT-SimpleChart-Plugin.zip
index 03d995339..5cdb06a57 100644
Binary files a/plugins/Db-GPT-SimpleChart-Plugin.zip and b/plugins/Db-GPT-SimpleChart-Plugin.zip differ