mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-06 19:04:24 +00:00

commit bde8503620 (parent d6d84c70b0)
同步服务接口 (Sync the service interface)
@@ -28,7 +28,7 @@ VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 4096
-VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
+VICUNA_MODEL_SERVER = "http://120.27.148.250:8000"

 # Load model config
 ISLOAD_8BIT = True
@@ -37,7 +37,7 @@ ISDEBUG = False

 DB_SETTINGS = {
     "user": "root",
-    "password": "aa12345678",
+    "password": "aa123456",
     "host": "127.0.0.1",
     "port": 3306
 }
@@ -159,6 +159,8 @@ auto_dbgpt_one_shot = Conversation(
 1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
 2. No user assistance
 3. Exclusively use the commands listed in double quotes e.g. "command name"
+DBScheme:
+

 Commands:
 1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
@@ -249,6 +251,11 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回

 default_conversation = conv_one_shot

+conversation_sql_mode = {
+    "auto_execute_ai_response": "直接执行结果",
+    "dont_execute_ai_response": "不直接执行结果"
+}
+
 conversation_types = {
     "native": "LLM原生对话",
     "default_knownledge": "默认知识库对话",
@@ -69,11 +69,99 @@ def generate_stream(model, tokenizer, params, device,
     del past_key_values

 @torch.inference_mode()
-def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
+def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
+    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
+    prompt = params["prompt"]
+    l_prompt = len(prompt)
+    temperature = float(params.get("temperature", 1.0))
+    max_new_tokens = int(params.get("max_new_tokens", 2048))
+    stop_parameter = params.get("stop", None)
+
+    if stop_parameter == tokenizer.eos_token:
+        stop_parameter = None
+    stop_strings = []
+    if isinstance(stop_parameter, str):
+        stop_strings.append(stop_parameter)
+    elif isinstance(stop_parameter, list):
+        stop_strings = stop_parameter
+    elif stop_parameter is None:
+        pass
+    else:
+        raise TypeError("Stop parameter must be string or list of strings.")
+
+    input_ids = tokenizer(prompt).input_ids
+    output_ids = list(input_ids)
+
+    max_src_len = context_len - max_new_tokens - 8
+    input_ids = input_ids[-max_src_len:]
+
+    for i in range(max_new_tokens):
+        if i == 0:
+            out = model(
+                torch.as_tensor([input_ids], device=device), use_cache=True)
+            logits = out.logits
+            past_key_values = out.past_key_values
+        else:
+            attention_mask = torch.ones(
+                1, past_key_values[0][0].shape[-2] + 1, device=device)
+            out = model(input_ids=torch.as_tensor([[token]], device=device),
+                        use_cache=True,
+                        attention_mask=attention_mask,
+                        past_key_values=past_key_values)
+            logits = out.logits
+            past_key_values = out.past_key_values
+
+        last_token_logits = logits[0][-1]
+
+        if device == "mps":
+            # Switch to CPU by avoiding some bugs in mps backend.
+            last_token_logits = last_token_logits.float().to("cpu")
+
+        if temperature < 1e-4:
+            token = int(torch.argmax(last_token_logits))
+        else:
+            probs = torch.softmax(last_token_logits / temperature, dim=-1)
+            token = int(torch.multinomial(probs, num_samples=1))
+
+        output_ids.append(token)
+
+        if token == tokenizer.eos_token_id:
+            stopped = True
+        else:
+            stopped = False
+
+        output = tokenizer.decode(output_ids, skip_special_tokens=True)
+        print("Partial output:", output)
+        for stop_str in stop_strings:
+            # print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
+            pos = output.rfind(stop_str)
+            if pos != -1:
+                # print("Found stop str: ", output)
+                output = output[:pos]
+                # print("Trimmed output: ", output)
+                stopped = True
+                stop_word = stop_str
+                break
+            else:
+                pass
+                # print("Not found")
+        if stopped:
+            break
+
+    del past_key_values
+    if pos != -1:
+        return output[:pos]
+    return output
+
+
+@torch.inference_mode()
+def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
     prompt = params["prompt"]
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_parameter = params.get("stop", None)

     if stop_parameter == tokenizer.eos_token:
         stop_parameter = None
     stop_strings = []
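Note: the rewritten generate_output reads a plain params dict rather than a request object. A minimal sketch of the shape it expects, assuming a model and tokenizer are already loaded elsewhere (the values are illustrative):

# Keys mirror the ones read inside generate_output in the hunk above.
params = {
    "prompt": "SELECT * FROM users LIMIT 10;",  # free-form prompt text
    "temperature": 0.7,                          # values below 1e-4 fall back to greedy argmax
    "max_new_tokens": 512,                       # must leave room inside context_len=4096 (8 tokens reserved)
    "stop": "###",                               # a string or list of strings; tokenizer.eos_token is treated as None
}
# output = generate_output(model, tokenizer, params, device="cuda")  # model/tokenizer assumed loaded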
@@ -39,6 +39,7 @@ class FirstPrompt:
     def construct_first_prompt(
         self,
         fisrt_message: [str]=[],
+        db_schemes: str=None,
         prompt_generator: Optional[PromptGenerator] = None
     ) -> str:
         """
@@ -88,6 +89,10 @@ class FirstPrompt:
         self.ai_goals = fisrt_message
         for i, goal in enumerate(self.ai_goals):
             full_prompt += f"{i+1}. {goal}\n"
+        if db_schemes:
+            full_prompt += f"DB SCHEME:\n\n"
+            full_prompt += f"{db_schemes}\n"
+
         # if self.api_budget > 0.0:
         #     full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
         self.prompt_generator = prompt_generator
@@ -93,14 +93,9 @@ async def api_generate_stream(request: Request):
     return StreamingResponse(generator, background=background_tasks)

 @app.post("/generate")
-def generate(prompt_request: PromptRequest):
-    params = {
-        "prompt": prompt_request.prompt,
-        "temperature": prompt_request.temperature,
-        "max_new_tokens": prompt_request.max_new_tokens,
-        "stop": prompt_request.stop
-    }
-
+def generate(prompt_request: Request):
+    params = request.json()
     print("Receive prompt: ", params["prompt"])
     output = generate_output(model, tokenizer, params, DEVICE)
     print("Output: ", output)
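Note: as committed, the new handler body reads request.json() while its parameter is named prompt_request, so the endpoint expects a JSON body with the same keys that generate_output consumes. A minimal client sketch, assuming the model server at VICUNA_MODEL_SERVER is reachable (prompt text and timeout are illustrative; auto_db_gpt_response in the web server hunk below makes the same call with extra headers):

import requests
from urllib.parse import urljoin

VICUNA_MODEL_SERVER = "http://120.27.148.250:8000"

payload = {
    "prompt": "show tables;",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "stop": "###",
}
# POST to the /generate endpoint defined above and print the raw response body.
res = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"), json=payload, timeout=30)
print(res.text)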
@@ -21,11 +21,15 @@ from pilot.plugins import scan_plugins
 from pilot.configs.config import Config
 from pilot.commands.command_mange import CommandRegistry
 from pilot.prompts.first_conversation_prompt import FirstPrompt
+from pilot.prompts.generator import PromptGenerator
+
+from pilot.commands.exception_not_commands import NotCommands

 from pilot.conversation import (
     default_conversation,
     conv_templates,
     conversation_types,
+    conversation_sql_mode,
     SeparatorStyle
 )

@@ -152,11 +156,13 @@ def post_process_code(code):
         code = sep.join(blocks)
     return code

-def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
-    # MOCk
-    autogpt = True
+def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+        print("AUTO DB-GPT模式.")
+    if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
+        print("标准DB-GPT模式.")
     print("是否是AUTO-GPT模式.", autogpt)

     start_tstamp = time.time()
     model_name = LLM_MODEL

@@ -167,17 +173,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return

+    cfg = Config()
+    first_prompt = FirstPrompt()
+
     # TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
     if len(state.messages) == state.offset + 2:
         query = state.messages[-2][1]
         # 第一轮对话需要加入提示Prompt
-        cfg = Config()
-        first_prompt = FirstPrompt()
         first_prompt.command_registry = cfg.command_registry
-        if(autogpt):
+        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
             # autogpt模式的第一轮对话需要 构建专属prompt
-            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
+            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
             logger.info("[TEST]:" + system_prompt)
             template_name = "auto_dbgpt_one_shot"
             new_state = conv_templates[template_name].copy()
@@ -218,7 +225,13 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
         "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
     }
     logger.info(f"Requert: \n{payload}")
+    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+        auto_db_gpt_response(first_prompt.prompt_generator, payload)
+    else:
+        stream_ai_response(payload)

+    def stream_ai_response(payload):
+        # 流式输出
     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

@@ -264,6 +277,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
             }
             fout.write(json.dumps(data) + "\n")

+
+def auto_db_gpt_response( prompt: PromptGenerator, payload)->str:
+    response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
+                             headers=headers, json=payload, timeout=30)
+    print(response)
+    try:
+        plugin_resp = execute_ai_response_json(prompt, response)
+        print(plugin_resp)
+    except NotCommands as e:
+        print(str(e))
+    return "auto_db_gpt_response!"
+
 block_css = (
     code_highlight_css
     + """
@@ -280,6 +305,12 @@ block_css = (
     """
 )

+def change_sql_mode(sql_mode):
+    if sql_mode in ["直接执行结果"]:
+        return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
+
 def change_mode(mode):
     if mode in ["默认知识库对话", "LLM原生对话"]:
         return gr.update(visible=False)
@@ -325,6 +356,9 @@ def build_single_model_ui():
         tabs= gr.Tabs()
         with tabs:
             tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
+            sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
+            sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
+            sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
             with tab_sql:
                 # TODO A selector to choose database
                 with gr.Row(elem_id="db_selector"):
@@ -383,7 +417,7 @@ def build_single_model_ui():
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -392,7 +426,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )

@@ -400,7 +434,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list
     )

BIN plugins/DB-GPT-SQL-Execution-Plugin.zip (new file)
Binary file not shown.