Merge remote-tracking branch 'origin/plugin_init' into source_embedding

chenketing 2023-05-14 22:07:47 +08:00
commit 88d23c2147
13 changed files with 336 additions and 121 deletions

View File

@ -42,9 +42,10 @@ def execute_ai_response_json(
cfg = Config()
try:
assistant_reply_json = fix_json_using_multiple_techniques(ai_response)
except (json.JSONDecodeError, ValueError) as e:
except (json.JSONDecodeError, ValueError, AttributeError) as e:
raise NotCommands("非可执行命令结构")
command_name, arguments = get_command(assistant_reply_json)
if cfg.speak_mode:
say_text(f"I want to execute {command_name}")
@ -105,11 +106,15 @@ def execute_command(
or command_name == command["name"].lower()
):
try:
# Remove arguments that are not defined for the command
diff_ags = list(set(arguments.keys()).difference(set(command['args'].keys())))
for arg_name in diff_ags:
del arguments[arg_name]
print(str(arguments))
return command["function"](**arguments)
except Exception as e:
return f"Error: {str(e)}"
raise NotCommands("非可用命令" + command)
raise NotCommands("非可用命令" + command_name)
def get_command(response_json: Dict):
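For readers skimming the hunk above: the new lines strip any argument that the command definition does not declare before invoking the handler. A minimal, self-contained sketch of that behavior (the command dict and arguments here are hypothetical, not taken from the repository):

command = {
    "name": "db_sql_executor",
    "args": {"sql": "<sql>"},
    "function": lambda sql: f"executed: {sql}",
}
arguments = {"sql": "SELECT 1;", "unexpected": "this key gets dropped"}

# Drop every argument that is not declared in the command definition
diff_args = list(set(arguments.keys()).difference(set(command["args"].keys())))
for arg_name in diff_args:
    del arguments[arg_name]

print(command["function"](**arguments))  # -> executed: SELECT 1;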

View File

@ -1,4 +1,5 @@
class NotCommands(Exception):
def __init__(self, message, error_code):
def __init__(self, message):
super().__init__(message)
self.error_code = error_code
self.message = message
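With error_code removed, the exception now carries only a message. A small illustrative usage (the call site is hypothetical):

from pilot.commands.exception_not_commands import NotCommands

try:
    raise NotCommands("Not an executable command structure")
except NotCommands as e:
    print(e.message)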

View File

@ -105,3 +105,7 @@ class Config(metaclass=Singleton):
def set_speak_mode(self, value: bool) -> None:
"""Set the speak mode value."""
self.speak_mode = value
def set_last_plugin_return(self, value: bool) -> None:
"""Set the speak mode value."""
self.last_plugin_return = value
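Illustrative use of the new setter, assuming Config is the singleton from pilot.configs.config and that the stored value is the plain-text result of the last plugin call (as used later in the webserver changes):

from pilot.configs.config import Config

cfg = Config()
cfg.set_last_plugin_return("SQL executed, 42 rows returned")  # hypothetical plugin result
print(cfg.last_plugin_return)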

View File

@ -37,7 +37,7 @@ ISDEBUG = False
DB_SETTINGS = {
"user": "root",
"password": "aa12345678",
"password": "aa123456",
"host": "127.0.0.1",
"port": 3306
}
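A hedged sketch of how DB_SETTINGS might be consumed; the pymysql dependency and the database name are assumptions, not part of this diff:

import pymysql  # assumed driver, not pinned by this commit

# DB_SETTINGS is the dict defined above
conn = pymysql.connect(
    user=DB_SETTINGS["user"],
    password=DB_SETTINGS["password"],
    host=DB_SETTINGS["host"],
    port=DB_SETTINGS["port"],
    database="gpt-user",  # assumed; the name appears in the prompt example elsewhere in this commit
)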

View File

@ -159,7 +159,12 @@ auto_dbgpt_one_shot = Conversation(
1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
2. No user assistance
3. Exclusively use the commands listed in double quotes e.g. "command name"
Schema:
The Schema of database gpt-user is as follows: users(city,create_time,email,last_login_time,phone,user_name);
Commands:
1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
2. execute_python_file: Execute Python File, args: "filename": "<filename>"
@ -168,21 +173,9 @@ auto_dbgpt_one_shot = Conversation(
5. list_files: List Files in Directory, args: "directory": "<directory>"
6. read_file: Read file, args: "filename": "<filename>"
7. write_to_file: Write to file, args: "filename": "<filename>", "text": "<text>"
8. ob_sql_executor: "Execute SQL in OB Database.", args: "sql": "<sql>"
Resources:
1. Internet access for searches and information gathering.
2. Long Term memory management.
3. vicuna powered Agents for delegation of simple tasks.
Performance Evaluation:
1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities.
2. Constructively self-criticize your big-picture behavior constantly.
3. Reflect on past decisions and strategies to refine your approach.
4. Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.
5. Write all code to a file.
You should only respond in JSON format as described below
8. db_sql_executor: "Execute SQL in Database.", args: "sql": "<sql>"
You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
Response Format:
{
"thoughts": {
@ -199,7 +192,6 @@ auto_dbgpt_one_shot = Conversation(
}
}
}
Ensure the response can be parsed by Python json.loads
"""
),
(
@ -207,16 +199,16 @@ auto_dbgpt_one_shot = Conversation(
"""
{
"thoughts": {
"text": "To answer how many users by query database we need to write SQL query to get the count of the distinct users from the database. We can use ob_sql_executor command to execute the SQL query in database.",
"text": "To answer how many users by query database we need to write SQL query to get the count of the distinct users from the database. We can use db_sql_executor command to execute the SQL query in database.",
"reasoning": "We can use the sql_executor command to execute the SQL query for getting count of distinct users from the users database. We can select the count of the distinct users from the users table.",
"plan": "- Write SQL query to get count of distinct users from users database\n- Use ob_sql_executor to execute the SQL query in OB database\n- Parse the SQL result to get the count\n- Respond with the count as the answer",
"plan": "- Write SQL query to get count of distinct users from users database\n- Use db_sql_executor to execute the SQL query in OB database\n- Parse the SQL result to get the count\n- Respond with the count as the answer",
"criticism": "None",
"speak": "To get the number of users in users, I will execute an SQL query in OB database using the ob_sql_executor command and respond with the count."
"speak": "To get the number of users in users, I will execute an SQL query in OB database using the db_sql_executor command and respond with the count."
},
"command": {
"name": "ob_sql_executor",
"name": "db_sql_executor",
"args": {
"sql": "SELECT COUNT(DISTINCT(*)) FROM users ;"
"sql": "SELECT COUNT(DISTINCT(user_name)) FROM users ;"
}
}
}
@ -249,6 +241,11 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回
default_conversation = conv_one_shot
conversation_sql_mode ={
"auto_execute_ai_response": "直接执行结果",
"dont_execute_ai_response": "不直接执行结果"
}
conversation_types = {
"native": "LLM原生对话",
"default_knownledge": "默认知识库对话",

View File

@ -69,11 +69,77 @@ def generate_stream(model, tokenizer, params, device,
del past_key_values
@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
prompt = params["prompt"]
l_prompt = len(prompt)
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids)
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
for i in range(max_new_tokens):
if i == 0:
out = model(
torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device=device)
out = model(input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
if device == "mps":
# Switch to CPU by avoiding some bugs in mps backend.
last_token_logits = last_token_logits.float().to("cpu")
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
if token == tokenizer.eos_token_id:
stopped = True
else:
stopped = False
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
output = tokenizer.decode(output_ids, skip_special_tokens=True)
pos = output.rfind(stop_str, l_prompt)
if pos != -1:
output = output[:pos]
stopped = True
return output
if stopped:
break
del past_key_values
@torch.inference_mode()
def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_parameter = params.get("stop", None)
if stop_parameter == tokenizer.eos_token:
stop_parameter = None
stop_strings = []
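For orientation, a sketch of the params dict that these generate_* functions consume; the concrete values are illustrative only:

params = {
    "prompt": "A chat between a user and an assistant. USER: how many users are there? ASSISTANT:",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "stop": "###",
}
# output = generate_output(model, tokenizer, params, device="cuda")  # model/tokenizer assumed loaded elsewhere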

View File

@ -7,10 +7,10 @@ from pathlib import Path
import distro
import yaml
from pilot.configs.config import Config
from pilot.prompts.prompt import build_default_prompt_generator
from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER
class FirstPrompt:
class AutoModePrompt:
"""
"""
@ -35,10 +35,56 @@ class FirstPrompt:
self.prompt_generator = None
self.command_registry = None
def construct_follow_up_prompt(
self,
user_input:[str],
last_auto_return: str = None,
prompt_generator: Optional[PromptGenerator] = None
)-> str:
"""
Build the full prompt from the user's follow-up conversation input
Args:
self:
prompt_generator:
Returns:
"""
prompt_start = (
DEFAULT_PROMPT_OHTER
)
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()
prompt_generator.goals = user_input
prompt_generator.command_registry = self.command_registry
# Load the commands available from plugins
cfg = Config()
for plugin in cfg.plugins:
if not plugin.can_handle_post_prompt():
continue
prompt_generator = plugin.post_prompt(prompt_generator)
full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
if not self.ai_goals :
self.ai_goals = user_input
for i, goal in enumerate(self.ai_goals):
full_prompt += f"{i+1}.根据提供的Schema信息, {goal}\n"
# if last_auto_return == None:
# full_prompt += f"{cfg.last_plugin_return}\n\n"
# else:
# full_prompt += f"{last_auto_return}\n\n"
full_prompt += f"Constraints:\n\n{DEFAULT_TRIGGERING_PROMPT}\n"
full_prompt += """Based on the above definition, answer the current goal and ensure that the response meets both the current constraints and the above definition and constraints"""
self.prompt_generator = prompt_generator
return full_prompt
def construct_first_prompt(
self,
fisrt_message: [str]=[],
db_schemes: str=None,
prompt_generator: Optional[PromptGenerator] = None
) -> str:
"""
@ -56,10 +102,6 @@ class FirstPrompt:
" simple strategies with no legal complications."
""
)
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()
prompt_generator.goals = fisrt_message
@ -83,11 +125,14 @@ class FirstPrompt:
# Construct full prompt
full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
if not self.ai_goals :
self.ai_goals = fisrt_message
for i, goal in enumerate(self.ai_goals):
full_prompt += f"{i+1}. {goal}\n"
full_prompt += f"{i+1}.根据提供的Schema信息,{goal}\n"
if db_schemes:
full_prompt += f"\nSchema:\n\n"
full_prompt += f"{db_schemes}"
# if self.api_budget > 0.0:
# full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
self.prompt_generator = prompt_generator
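A hedged sketch of how the two builders above are wired together, mirroring the webserver changes later in this commit (the goal text and schema string are illustrative):

from pilot.configs.config import Config
from pilot.prompts.auto_mode_prompt import AutoModePrompt

cfg = Config()
auto_prompt = AutoModePrompt()
auto_prompt.command_registry = cfg.command_registry

# First round: build the dedicated prompt, optionally including schema info
system_prompt = auto_prompt.construct_first_prompt(
    fisrt_message=["How many users signed up last month?"],
    db_schemes="users(city,create_time,email,last_login_time,phone,user_name)",
)

# Follow-up rounds: reuse the lighter follow-up prompt
follow_up_prompt = auto_prompt.construct_follow_up_prompt(["Now break that down by city"])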

View File

@ -149,7 +149,7 @@ class PromptGenerator:
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
"Performance Evaluation:\n"
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
"You should only respond in JSON format as described below \nResponse"
f" Format: \n{formatted_response_format} \nEnsure the response can be"
" parsed by Python json.loads"
"You should only respond in JSON format as described below and ensure the"
"response can be parsed by Python json.loads \nResponse"
f" Format: \n{formatted_response_format}"
)

View File

@ -6,9 +6,12 @@ from pilot.prompts.generator import PromptGenerator
CFG = Config()
DEFAULT_TRIGGERING_PROMPT = (
"Determine which next command to use, and respond using the format specified above:"
"Determine which next command to use, and respond using the format specified above"
)
DEFAULT_PROMPT_OHTER = (
"Previous response was excellent. Please response according to the requirements based on the new goal"
)
def build_default_prompt_generator() -> PromptGenerator:
"""
@ -23,27 +26,37 @@ def build_default_prompt_generator() -> PromptGenerator:
prompt_generator = PromptGenerator()
# Add constraints to the PromptGenerator object
prompt_generator.add_constraint(
"~4000 word limit for short term memory. Your short term memory is short, so"
" immediately save important information to files."
)
# prompt_generator.add_constraint(
# "~4000 word limit for short term memory. Your short term memory is short, so"
# " immediately save important information to files."
# )
prompt_generator.add_constraint(
"If you are unsure how you previously did something or want to recall past"
" events, thinking about similar events will help you remember."
)
prompt_generator.add_constraint("No user assistance")
# prompt_generator.add_constraint("No user assistance")
prompt_generator.add_constraint(
'Only output one correct JSON response at a time'
)
prompt_generator.add_constraint(
'Exclusively use the commands listed in double quotes e.g. "command name"'
)
prompt_generator.add_constraint(
'If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition'
)
prompt_generator.add_constraint(
'The generated command args need to comply with the definition of the command'
)
# Add resources to the PromptGenerator object
prompt_generator.add_resource(
"Internet access for searches and information gathering."
)
prompt_generator.add_resource("Long Term memory management.")
prompt_generator.add_resource(
"GPT-3.5 powered Agents for delegation of simple tasks."
)
# prompt_generator.add_resource(
# "Internet access for searches and information gathering."
# )
# prompt_generator.add_resource("Long Term memory management.")
# prompt_generator.add_resource(
# "DB-GPT powered Agents for delegation of simple tasks."
# )
# prompt_generator.add_resource("File output.")
# Add performance evaluations to the PromptGenerator object
@ -57,9 +70,9 @@ def build_default_prompt_generator() -> PromptGenerator:
prompt_generator.add_performance_evaluation(
"Reflect on past decisions and strategies to refine your approach."
)
prompt_generator.add_performance_evaluation(
"Every command has a cost, so be smart and efficient. Aim to complete tasks in"
" the least number of steps."
)
# prompt_generator.add_performance_evaluation(
# "Every command has a cost, so be smart and efficient. Aim to complete tasks in"
# " the least number of steps."
# )
# prompt_generator.add_performance_evaluation("Write all code to a file.")
return prompt_generator

View File

@ -36,8 +36,8 @@ class PromptRequest(BaseModel):
prompt: str
temperature: float
max_new_tokens: int
stop: Optional[List[str]] = None
model: str
stop: str = None
class StreamRequest(BaseModel):
model: str
@ -101,11 +101,17 @@ def generate(prompt_request: PromptRequest):
"stop": prompt_request.stop
}
print("Receive prompt: ", params["prompt"])
output = generate_output(model, tokenizer, params, DEVICE)
print("Output: ", output)
return {"response": output}
response = []
rsp_str = ""
output = generate_stream_gate(params)
for rsp in output:
# rsp = rsp.decode("utf-8")
rsp_str = str(rsp, "utf-8")
print("[TEST: output]:", rsp_str)
response.append(rsp_str)
return {"response": rsp_str}
@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
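With the endpoint now collecting the streamed chunks server-side, a client call could look roughly like this; the host/port and field values are assumptions, and the exact required fields follow the PromptRequest model above:

import requests

payload = {
    "model": "vicuna-13b",  # assumed model name
    "prompt": "USER: how many users are there? ASSISTANT:",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "stop": "###",
}
resp = requests.post("http://127.0.0.1:8000/generate", json=payload, timeout=120)
print(resp.json()["response"])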

View File

@ -20,12 +20,16 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.first_conversation_prompt import FirstPrompt
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.commands.exception_not_commands import NotCommands
from pilot.conversation import (
default_conversation,
conv_templates,
conversation_types,
conversation_sql_mode,
SeparatorStyle
)
@ -152,11 +156,13 @@ def post_process_code(code):
code = sep.join(blocks)
return code
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
# Mock
autogpt = True
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
print("AUTO DB-GPT模式.")
if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
print("标准DB-GPT模式.")
print("是否是AUTO-GPT模式.", autogpt)
start_tstamp = time.time()
model_name = LLM_MODEL
@ -167,28 +173,26 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
cfg = Config()
auto_prompt = AutoModePrompt()
auto_prompt.command_registry = cfg.command_registry
# TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
if len(state.messages) == state.offset + 2:
query = state.messages[-2][1]
# The first round of conversation needs the prompt added
cfg = Config()
first_prompt = FirstPrompt()
first_prompt.command_registry = cfg.command_registry
if(autogpt):
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
# In autogpt mode, the first round of conversation needs a dedicated prompt
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
logger.info("[TEST]:" + system_prompt)
template_name = "auto_dbgpt_one_shot"
new_state = conv_templates[template_name].copy()
new_state.append_message(role='USER', message=system_prompt)
# new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
else:
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
new_state.conv_id = uuid.uuid4().hex
if not autogpt:
# Add a context hint to the prompt so the conversation is grounded in existing knowledge. Should the context hint only go in the first round, or be added in every round?
# If the user's questions span a wide range, the hint should be added in every round.
if db_selector:
@ -198,7 +202,20 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
new_state.conv_id = uuid.uuid4().hex
state = new_state
else:
### Follow-up conversation
query = state.messages[-2][1]
# Follow-up rounds also need the prompt added
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
## Get the return value of the last plugin call
follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
state.messages[0][0] = ""
state.messages[0][1] = ""
state.messages[-2][1] = follow_up_prompt
if mode == conversation_types["default_knownledge"] and not db_selector:
query = state.messages[-2][1]
knqa = KnownLedgeBaseQA()
@ -219,50 +236,101 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
}
logger.info(f"Requert: \n{payload}")
state.messages[-1][-1] = ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
headers=headers, json=payload, timeout=120)
try:
# Stream output
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
headers=headers, json=payload, stream=True, timeout=20)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][skip_echo_len:].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
print(response.json())
print(str(response))
try:
# response = """{"thoughts":{"text":"thought","reasoning":"reasoning","plan":"- short bulleted\n- list that conveys\n- long-term plan","criticism":"constructive self-criticism","speak":"thoughts summary to say to user"},"command":{"name":"db_sql_executor","args":{"sql":"select count(*) as user_count from users u where create_time >= DATE_SUB(NOW(), INTERVAL 1 MONTH);"}}}"""
# response = response.replace("\n", "\\n")
# response = """{"thoughts":{"text":"In order to get the number of users who have grown in the last three days, I need to analyze the create\_time of each user and see if it is within the last three days. I will use the SQL query to filter the users who have created their account in the last three days.","reasoning":"I can use the SQL query to filter the users who have created their account in the last three days. I will get the current date and then subtract three days from it, and then use this as the filter for the query. This will give me the number of users who have created their account in the last three days.","plan":"- Get the current date and subtract three days from it\n- Use the SQL query to filter the users who have created their account in the last three days\n- Count the number of users who match the filter to get the number of users who have grown in the last three days","criticism":"None"},"command":{"name":"db_sql_executor","args":{"sql":"SELECT COUNT(DISTINCT(ID)) FROM users WHERE create_time >= DATE_SUB(NOW(), INTERVAL 3 DAY);"}}}"""
# response = response.replace("\n", "\\)
text = response.text.strip()
text = text.rstrip()
respObj = json.loads(text)
xx = respObj['response']
xx = xx.strip(b'\x00'.decode())
respObj_ex = json.loads(xx)
if respObj_ex['error_code'] == 0:
ai_response = None
all_text = respObj_ex['text']
### Parse the returned text to extract the AI reply
tmpResp = all_text.split(state.sep)
last_index = -1
for i in range(len(tmpResp)):
if tmpResp[i].find('ASSISTANT:') != -1:
last_index = i
ai_response = tmpResp[last_index]
ai_response = ai_response.replace("ASSISTANT:", "")
ai_response = ai_response.replace("\n", "")
ai_response = ai_response.replace("\_", "_")
print(ai_response)
if ai_response == None:
state.messages[-1][-1] = "ASSISTANT未能正确回复回复结果为:\n" + all_text
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else:
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
cfg.set_last_plugin_return(plugin_resp)
print(plugin_resp)
state.messages[-1][-1] = "Model推理信息:\n"+ ai_response +"\n\nDB-GPT执行结果:\n" + plugin_resp
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
except NotCommands as e:
print("命令执行:" + e.message)
state.messages[-1][-1] = "命令执行:" + e.message +"\n模型输出:\n" + str(ai_response)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else:
# Stream output
state.messages[-1][-1] = ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
try:
# Stream output
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
headers=headers, json=payload, stream=True, timeout=20)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][skip_echo_len:].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
else:
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
# Record the run log
finish_tstamp = time.time()
logger.info(f"{output}")
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
# Record the run log
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
block_css = (
code_highlight_css
@ -280,6 +348,12 @@ block_css = (
"""
)
def change_sql_mode(sql_mode):
if sql_mode in ["直接执行结果"]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_mode(mode):
if mode in ["默认知识库对话", "LLM原生对话"]:
return gr.update(visible=False)
@ -316,8 +390,8 @@ def build_single_model_ui():
max_output_tokens = gr.Slider(
minimum=0,
maximum=4096,
value=2048,
maximum=1024,
value=1024,
step=64,
interactive=True,
label="最大输出Token数",
@ -333,7 +407,11 @@ def build_single_model_ui():
choices=dbs,
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True).style(container=False)
show_label=True).style(container=False)
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
with tab_auto:
gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
@ -383,7 +461,7 @@ def build_single_model_ui():
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, mode, db_selector, temperature, max_output_tokens],
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@ -392,7 +470,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, mode, db_selector, temperature, max_output_tokens],
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
@ -400,7 +478,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, mode, db_selector, temperature, max_output_tokens],
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list
)

Binary file not shown.