mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-09 12:18:12 +00:00

commit 084656245e (parent 661a7b5697)
model server fix message model
@@ -29,6 +29,7 @@ def chatglm_generate_stream(
     generate_kwargs["temperature"] = temperature

     # TODO, Fix this
+    print(prompt)
     messages = prompt.split(stop)
     #
     # # Add history conversation
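Note: `messages = prompt.split(stop)` is the message model this file relies on: the webserver flattens the whole conversation into one prompt string with turns separated by the stop token, and the model side splits it back apart. A minimal sketch of that round trip (the stop token and prompt layout here are illustrative assumptions, not taken from this commit):

    # Hypothetical flat prompt; turns are delimited by the stop token.
    stop = "###"
    prompt = "A helpful assistant.###Human: hi###Assistant: hello###Human: and 1+1?"

    messages = prompt.split(stop)   # what chatglm_generate_stream does
    print(messages[-1])             # current turn: "Human: and 1+1?"
    print(messages[1:-1])           # earlier history, to re-pair into (user, ai) turns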
@@ -46,8 +46,26 @@ class BaseOutputParser(ABC):
         code = sep.join(blocks)
         return code

+    def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
+        data = json.loads(chunk.decode())
+
+        """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
+        """
+        if data["error_code"] == 0:
+            if CFG.LLM_MODEL in ["vicuna-13b", "guanaco"]:
+                output = data["text"][skip_echo_len:].strip()
+            else:
+                output = data["text"].strip()
+
+            output = self.__post_process_code(output)
+            return output
+        else:
+            output = data["text"] + f" (error_code: {data['error_code']})"
+            return output
+
     # TODO: bind this to the model later
-    def _parse_model_stream_resp(self, response, sep: str, skip_echo_len):
+    def parse_model_stream_resp(self, response, skip_echo_len):

         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
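The new `parse_model_stream_resp_ex` handles one chunk at a time instead of owning the whole response loop. A minimal, self-contained sketch of the wire format it assumes (the endpoint URL and payload fields are illustrative; the server streams NUL-delimited JSON documents):

    import json
    import requests

    # Illustrative endpoint and payload; see BaseChat.stream_call below.
    resp = requests.post(
        "http://localhost:8000/generate_stream",
        json={"prompt": "hi", "temperature": 0.7},
        stream=True,
        timeout=120,
    )

    skip_echo_len = 0  # characters of echoed prompt to trim from each chunk
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            # Each chunk is a standalone JSON document: {"text": ..., "error_code": ...}
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                print(data["text"][skip_echo_len:].strip())
            else:
                print(f'{data["text"]} (error_code: {data["error_code"]})')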
@@ -56,7 +74,7 @@ class BaseOutputParser(ABC):
                 """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
                 """
                 if data["error_code"] == 0:
-                    if CFG.LLM_MODEL in ["vicuna", "guanaco"]:
+                    if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
                         output = data["text"][skip_echo_len:].strip()
                     else:
                         output = data["text"].strip()
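Loosening the exact list membership to a substring test lets versioned model names hit the vicuna/guanaco branch; a quick illustration with made-up values:

    for name in ["vicuna-13b", "vicuna-7b-v1.1", "guanaco-33b", "chatglm-6b"]:
        print(name, "vicuna" in name or "guanaco" in name)
    # vicuna-13b True / vicuna-7b-v1.1 True / guanaco-33b True / chatglm-6b False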
@@ -69,7 +87,7 @@ class BaseOutputParser(ABC):
                 )
                 yield output

-    def _parse_model_nostream_resp(self, response, sep: str):
+    def parse_model_nostream_resp(self, response, sep: str):
         text = response.text.strip()
         text = text.rstrip()
         text = text.lower()
@@ -96,19 +114,6 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])

-    def parse_model_server_out(self, response, skip_echo_len: int = 0):
-        """
-        parse the model server http response
-        Args:
-            response:
-
-        Returns:
-
-        """
-        if not self.is_stream_out:
-            return self._parse_model_nostream_resp(response, self.sep)
-        else:
-            return self._parse_model_stream_resp(response, self.sep)

     def parse_prompt_response(self, model_out_text) -> T:
         """
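With `parse_model_server_out` deleted, the stream/no-stream choice the parser used to make now lives at the call sites. A toy sketch of the relocated dispatch (the class and values are made up purely to illustrate the shape):

    class FakeParser:
        def parse_model_nostream_resp(self, response, sep):
            return f"nostream parse of {response!r}"

        def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
            return f"stream parse of {chunk!r}"

    parser = FakeParser()
    is_stream_out = True  # previously self.is_stream_out inside the parser
    if not is_stream_out:
        print(parser.parse_model_nostream_resp("full response", "###"))
    else:
        for chunk in [b'{"text": "a"}', b'{"text": "ab"}']:
            print(parser.parse_model_stream_resp_ex(chunk, 0))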
@@ -129,7 +129,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()

-        skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 1
+        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
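`skip_echo_len` is how many leading characters to drop from each streamed `text` field so the echoed prompt is not rendered back to the user; storing it on `self` lets the UI layer reuse it after `stream_call` returns. A toy illustration of the trimming mechanics (the prompt and padding are made up, and the commit does not explain why the offset moved from `+ 1` to `+ 11`):

    prompt = "A chat.</s>Human: hi</s>"
    skip_echo_len = len(prompt.replace("</s>", " ")) + 11

    # The streamed "text" begins with an echo of the prompt; slicing from
    # skip_echo_len keeps only the newly generated part.
    streamed = prompt.replace("</s>", " ") + "x" * 11 + "Hello there!"
    print(streamed[skip_echo_len:])  # -> "Hello there!"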
@@ -138,14 +138,16 @@ class BaseChat(ABC):
                 urljoin(CFG.MODEL_SERVER, "generate_stream"),
                 headers=headers,
                 json=payload,
+                stream=True,
                 timeout=120,
             )
+            return response;

-            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response, skip_echo_len)
+            # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)

-            for resp_text_trunck in ai_response_text:
-                show_info = resp_text_trunck
-                yield resp_text_trunck + "▌"
+            # for resp_text_trunck in ai_response_text:
+            #     show_info = resp_text_trunck
+            #     yield resp_text_trunck + "▌"

             self.current_message.add_ai_message(show_info)

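`stream_call` now returns the raw `requests` response (opened with `stream=True`) instead of yielding parsed text, so everything after `return response;` in this method is dead code — the trailing `add_ai_message(show_info)` can never execute, and `show_info` is no longer defined. A condensed sketch of the new contract from the consumer's side (`chat` is assumed to be a constructed `BaseChat` subclass; the real call site is the `http_bot` hunk below):

    def render_stream(chat):
        # Sketch (assumed helper): drive the UI from the raw streaming response.
        response = chat.stream_call()  # raw requests.Response, stream=True
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                yield chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                    chunk, chat.skip_echo_len
                )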
@@ -173,7 +175,7 @@ class BaseChat(ABC):

         ### output parse
         ai_response_text = (
-            self.prompt_template.output_parser.parse_model_server_out(response)
+            self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep)
         )
         self.current_message.add_ai_message(ai_response_text)
         prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
@@ -48,7 +48,6 @@ from pilot.scene.base import ChatScene
 from pilot.scene.chat_factory import ChatFactory
 from pilot.language.translation_handler import get_lang_text

-
 # Load plugins
 CFG = Config()
 logger = build_logger("webserver", LOGDIR + "webserver.log")
@@ -78,7 +77,6 @@ DB_SETTINGS = {
     "port": CFG.LOCAL_DB_PORT,
 }

-
 llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
 default_knowledge_base_dialogue = get_lang_text(
     "knowledge_qa_type_default_knowledge_base_dialogue"
@@ -111,7 +109,7 @@ def gen_sqlgen_conversation(dbname):
     db_connect = CFG.local_db.get_session(dbname)
     schemas = CFG.local_db.table_simple_info(db_connect)
     for s in schemas:
-        message += s+ ";"
+        message += s + ";"
     return get_lang_text("sql_schema_info").format(dbname, message)


@@ -209,8 +207,8 @@ def post_process_code(code):
 def get_chat_mode(selected, param=None) -> ChatScene:
     if chat_mode_title['chat_use_plugin'] == selected:
         return ChatScene.ChatExecution
     elif chat_mode_title['knowledge_qa'] == selected:
-        mode= param
+        mode = param
         if mode == conversation_types["default_knownledge"]:
             return ChatScene.ChatKnowledge
         elif mode == conversation_types["custome"]:
@@ -220,7 +218,7 @@ def get_chat_mode(selected, param=None) -> ChatScene:
         else:
             return ChatScene.ChatNormal
     else:
-        sql_mode= param
+        sql_mode = param
         if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
             return ChatScene.ChatWithDbExecute
         else:
@@ -234,10 +232,11 @@ def chatbot_callback(state, message):


 def http_bot(
-    state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name
+        state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input,
+        knowledge_name
 ):
-    logger.info(f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
+    logger.info(
+        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
     if chat_mode_title['knowledge_qa'] == selected:
         scene: ChatScene = get_chat_mode(selected, mode)
     elif chat_mode_title['chat_use_plugin'] == selected:
@@ -312,10 +311,11 @@ def http_bot(
     else:
         logger.info("stream out start!")
         try:
-            stream_gen = chat.stream_call()
-            for msg in stream_gen:
-                state.messages[-1][-1] = msg
-                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+            response = chat.stream_call()
+            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+                if chunk:
+                    state.messages[-1][-1] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk,chat.skip_echo_len)
+                    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
         except Exception as e:
             print(traceback.format_exc())
             state.messages[-1][-1] = "Error:" + str(e)
@@ -323,8 +323,8 @@ def http_bot(


 block_css = (
     code_highlight_css
     + """
 pre {
     white-space: pre-wrap; /* Since CSS 2.1 */
     white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -353,7 +353,6 @@ def change_mode(mode):
         return gr.update(visible=False)


-
 def build_single_model_ui():
     notice_markdown = get_lang_text("db_gpt_introduction")
     learn_more_markdown = get_lang_text("learn_more_markdown")
@@ -362,7 +361,7 @@ def build_single_model_ui():
     gr.Markdown(notice_markdown, elem_id="notice_markdown")

     with gr.Accordion(
         get_lang_text("model_control_param"), open=False, visible=False
     ) as parameter_row:
         temperature = gr.Slider(
             minimum=0.0,
@@ -458,13 +457,14 @@ def build_single_model_ui():
         mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

         url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
-        def show_url_input(evt:gr.SelectData):
+
+        def show_url_input(evt: gr.SelectData):
             if evt.value == url_knowledge_dialogue:
                 return gr.update(visible=True)
             else:
                 return gr.update(visible=False)
-        mode.select(fn=show_url_input, inputs=None, outputs=url_input)

+        mode.select(fn=show_url_input, inputs=None, outputs=url_input)

         with vs_setting:
             vs_name = gr.Textbox(
@@ -516,7 +516,6 @@ def build_single_model_ui():

     params = [plugin_selected, mode, sql_mode, db_selector, url_input, vs_name]

-
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
@@ -529,7 +528,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, selected, temperature, max_output_tokens]+ params,
+        [state, selected, temperature, max_output_tokens] + params,
         [state, chatbot] + btn_list,
     )

@@ -537,7 +536,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, selected, temperature, max_output_tokens]+ params,
+        [state, selected, temperature, max_output_tokens] + params,
         [state, chatbot] + btn_list,
     )
     vs_add.click(
@@ -560,10 +559,10 @@ def build_single_model_ui():

 def build_webdemo():
     with gr.Blocks(
         title=get_lang_text("database_smart_assistant"),
         # theme=gr.themes.Base(),
         theme=gr.themes.Default(),
         css=block_css,
     ) as demo:
         url_params = gr.JSON(visible=False)
         (