Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-01 16:18:27 +00:00)

Commit: bde8503620
Parent: d6d84c70b0
Message: Sync service interface (同步服务接口)
@@ -28,7 +28,7 @@ VECTOR_SEARCH_TOP_K = 3
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 4096
-VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
+VICUNA_MODEL_SERVER = "http://120.27.148.250:8000"

# Load model config
ISLOAD_8BIT = True
@@ -37,7 +37,7 @@ ISDEBUG = False

DB_SETTINGS = {
    "user": "root",
-    "password": "aa12345678",
+    "password": "aa123456",
    "host": "127.0.0.1",
    "port": 3306
}
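
For context, a minimal sketch of how settings in this shape can be consumed. This is not part of the commit: pymysql is an assumed client (the project may use a different MySQL driver), and the credentials are simply the values from the hunk above.

    # Illustrative only: connect with the DB_SETTINGS shape from the hunk above.
    import pymysql

    DB_SETTINGS = {
        "user": "root",
        "password": "aa123456",
        "host": "127.0.0.1",
        "port": 3306,
    }

    conn = pymysql.connect(**DB_SETTINGS)
    try:
        with conn.cursor() as cursor:
            cursor.execute("SHOW DATABASES")
            print(cursor.fetchall())
    finally:
        conn.close()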
@@ -159,7 +159,9 @@ auto_dbgpt_one_shot = Conversation(
1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
2. No user assistance
3. Exclusively use the commands listed in double quotes e.g. "command name"

+DBScheme:
+

Commands:
1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
2. execute_python_file: Execute Python File, args: "filename": "<filename>"
@@ -249,6 +251,11 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回

default_conversation = conv_one_shot

+conversation_sql_mode = {
+    "auto_execute_ai_response": "直接执行结果",
+    "dont_execute_ai_response": "不直接执行结果"
+}

conversation_types = {
    "native": "LLM原生对话",
    "default_knownledge": "默认知识库对话",
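
The two labels above are what the web UI radio control returns; a minimal sketch (not from the repository) of the branch they drive, mirroring the http_bot change later in this commit:

    # Illustrative only: branching on the radio value.
    conversation_sql_mode = {
        "auto_execute_ai_response": "直接执行结果",
        "dont_execute_ai_response": "不直接执行结果",
    }

    def handle(sql_mode: str) -> str:
        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
            return "execute the generated SQL automatically"
        return "only stream the model response"

    print(handle("直接执行结果"))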
@@ -69,11 +69,99 @@ def generate_stream(model, tokenizer, params, device,
    del past_key_values


@torch.inference_mode()
-def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
+def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
    prompt = params["prompt"]
    l_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_parameter = params.get("stop", None)

    if stop_parameter == tokenizer.eos_token:
        stop_parameter = None
    stop_strings = []
    if isinstance(stop_parameter, str):
        stop_strings.append(stop_parameter)
    elif isinstance(stop_parameter, list):
        stop_strings = stop_parameter
    elif stop_parameter is None:
        pass
    else:
        raise TypeError("Stop parameter must be string or list of strings.")

    input_ids = tokenizer(prompt).input_ids
    output_ids = list(input_ids)

    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]

    for i in range(max_new_tokens):
        if i == 0:
            out = model(
                torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            attention_mask = torch.ones(
                1, past_key_values[0][0].shape[-2] + 1, device=device)
            out = model(input_ids=torch.as_tensor([[token]], device=device),
                        use_cache=True,
                        attention_mask=attention_mask,
                        past_key_values=past_key_values)
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]

        if device == "mps":
            # Switch to CPU by avoiding some bugs in mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-4:
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True
        else:
            stopped = False

        output = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Partial output:", output)
        for stop_str in stop_strings:
            # print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
            pos = output.rfind(stop_str)
            if pos != -1:
                # print("Found stop str: ", output)
                output = output[:pos]
                # print("Trimmed output: ", output)
                stopped = True
                stop_word = stop_str
                break
            else:
                pass
                # print("Not found")
        if stopped:
            break

    del past_key_values
    if pos != -1:
        return output[:pos]
    return output


@torch.inference_mode()
def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_parameter = params.get("stop", None)

    if stop_parameter == tokenizer.eos_token:
        stop_parameter = None
    stop_strings = []
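
A usage sketch for the rebuilt generate_output above, not part of the commit. The checkpoint name and chat format are assumptions; any HuggingFace-style causal LM and tokenizer fit the signature.

    # Illustrative only: load a causal LM and call generate_output from this module.
    # "lmsys/vicuna-13b-v1.1" is an assumed checkpoint, not pinned by this commit.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_PATH = "lmsys/vicuna-13b-v1.1"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)

    params = {
        "prompt": "USER: Which tables look related to orders? ASSISTANT:",
        "temperature": 0.7,
        "max_new_tokens": 256,
        "stop": "###",
    }
    print(generate_output(model, tokenizer, params, DEVICE))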
@@ -39,6 +39,7 @@ class FirstPrompt:
    def construct_first_prompt(
        self,
        fisrt_message: [str]=[],
+        db_schemes: str=None,
        prompt_generator: Optional[PromptGenerator] = None
    ) -> str:
        """
@@ -88,6 +89,10 @@ class FirstPrompt:
        self.ai_goals = fisrt_message
        for i, goal in enumerate(self.ai_goals):
            full_prompt += f"{i+1}. {goal}\n"
+        if db_schemes:
+            full_prompt += f"DB SCHEME:\n\n"
+            full_prompt += f"{db_schemes}\n"

        # if self.api_budget > 0.0:
        #     full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
        self.prompt_generator = prompt_generator
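
A standalone illustration (not the repository's class) of what the added branch appends to the prompt: when schema text is supplied, it is written under a "DB SCHEME:" header after the goals.

    # Mirrors the two f-string lines added above; names here are illustrative.
    from typing import Optional

    def append_db_scheme(full_prompt: str, db_schemes: Optional[str]) -> str:
        # Append the database schema block only when schema text was provided.
        if db_schemes:
            full_prompt += "DB SCHEME:\n\n"
            full_prompt += f"{db_schemes}\n"
        return full_prompt

    print(append_db_scheme("GOALS:\n1. analyze sales\n", "users(id INT, name VARCHAR(64))"))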
@@ -93,14 +93,9 @@ async def api_generate_stream(request: Request):
    return StreamingResponse(generator, background=background_tasks)

@app.post("/generate")
-def generate(prompt_request: PromptRequest):
-    params = {
-        "prompt": prompt_request.prompt,
-        "temperature": prompt_request.temperature,
-        "max_new_tokens": prompt_request.max_new_tokens,
-        "stop": prompt_request.stop
-    }
+def generate(prompt_request: Request):
+
+    params = request.json()
    print("Receive prompt: ", params["prompt"])
    output = generate_output(model, tokenizer, params, DEVICE)
    print("Output: ", output)
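
Note that the new handler reads the body via request.json() although the parameter is named prompt_request, and Starlette's Request.json() is a coroutine. A hedged sketch of an equivalent async endpoint follows; it is not the repository's code, and model, tokenizer, DEVICE and generate_output are assumed to be the module-level objects of this server file.

    # Illustrative only: async variant of the same endpoint with the FastAPI/Starlette Request object.
    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.post("/generate")
    async def generate(prompt_request: Request):
        # Request.json() is a coroutine, so it must be awaited.
        params = await prompt_request.json()
        print("Receive prompt: ", params["prompt"])
        output = generate_output(model, tokenizer, params, DEVICE)
        print("Output: ", output)
        return output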
@@ -21,11 +21,15 @@ from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry
+from pilot.prompts.first_conversation_prompt import FirstPrompt
+from pilot.prompts.generator import PromptGenerator

+from pilot.commands.exception_not_commands import NotCommands

from pilot.conversation import (
    default_conversation,
    conv_templates,
    conversation_types,
+    conversation_sql_mode,
    SeparatorStyle
)
@@ -152,11 +156,13 @@ def post_process_code(code):
    code = sep.join(blocks)
    return code

-def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):

    # MOCk
    autogpt = True
+def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+        print("AUTO DB-GPT模式.")
+    if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
+        print("标准DB-GPT模式.")
    print("是否是AUTO-GPT模式.", autogpt)

    start_tstamp = time.time()
    model_name = LLM_MODEL
@@ -167,17 +173,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

-    cfg = Config()
-    first_prompt = FirstPrompt()


    # TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
    if len(state.messages) == state.offset + 2:
        query = state.messages[-2][1]
+        # The first round of conversation needs the prompt added
+        cfg = Config()
+        first_prompt = FirstPrompt()
        first_prompt.command_registry = cfg.command_registry
-        if(autogpt):
+        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
            # In autogpt mode, the first round of conversation needs a dedicated prompt
-            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
+            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname))
            logger.info("[TEST]:" + system_prompt)
            template_name = "auto_dbgpt_one_shot"
            new_state = conv_templates[template_name].copy()
@@ -218,7 +225,13 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
    }
    logger.info(f"Requert: \n{payload}")
+    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+        auto_db_gpt_response(first_prompt.prompt_generator, payload)
+    else:
+        stream_ai_response(payload)
+
+    def stream_ai_response(payload):
        # Streaming output
        state.messages[-1][-1] = "▌"
        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
@@ -264,6 +277,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
        }
        fout.write(json.dumps(data) + "\n")


+def auto_db_gpt_response(prompt: PromptGenerator, payload) -> str:
+    response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
+                             headers=headers, json=payload, timeout=30)
+    print(response)
+    try:
+        plugin_resp = execute_ai_response_json(prompt, response)
+        print(plugin_resp)
+    except NotCommands as e:
+        print(str(e))
+    return "auto_db_gpt_response!"

block_css = (
    code_highlight_css
    + """
@@ -280,6 +305,12 @@ block_css = (
"""
)

+def change_sql_mode(sql_mode):
+    if sql_mode in ["直接执行结果"]:
+        return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
+
def change_mode(mode):
    if mode in ["默认知识库对话", "LLM原生对话"]:
        return gr.update(visible=False)
@@ -325,6 +356,9 @@ def build_single_model_ui():
    tabs = gr.Tabs()
    with tabs:
        tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
+        sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
+        sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
+        sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
        with tab_sql:
            # TODO A selector to choose database
            with gr.Row(elem_id="db_selector"):
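
A standalone sketch of the same Radio-to-Markdown wiring outside the full UI, assuming Gradio 3.x semantics (gr.update toggling visibility from a change callback); it is illustrative only, not part of the commit.

    # Illustrative only: minimal Gradio demo of the sql_mode radio toggling the hint text.
    import gradio as gr

    def change_sql_mode(sql_mode):
        # Show the auto-execution hint only when "直接执行结果" is selected.
        return gr.update(visible=(sql_mode == "直接执行结果"))

    with gr.Blocks() as demo:
        sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
        sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力", visible=False)
        sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)

    if __name__ == "__main__":
        demo.launch()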
@@ -383,7 +417,7 @@ def build_single_model_ui():
    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -392,7 +426,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
@@ -400,7 +434,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list
    )
BIN  plugins/DB-GPT-SQL-Execution-Plugin.zip  (new file)
Binary file not shown.