model server fix message model

yhjun1026 2023-06-01 18:03:04 +08:00
parent 661a7b5697
commit 084656245e
4 changed files with 55 additions and 48 deletions

View File

@@ -29,6 +29,7 @@ def chatglm_generate_stream(
     generate_kwargs["temperature"] = temperature
     # TODO, Fix this
+    print(prompt)
     messages = prompt.split(stop)
     #
     # # Add history conversation
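Note: the hunk above rebuilds the conversation history by splitting the flattened prompt on the stop token. A minimal sketch of that contract follows; the stop token and prompt values are hypothetical, not taken from this commit:

    # Hypothetical values for illustration only; the real prompt and stop
    # token come from the caller of chatglm_generate_stream.
    stop = "###"
    prompt = "###user: hello###assistant: hi there###user: and you?"

    # Same split as in the hunk: each fragment is one conversation turn.
    messages = prompt.split(stop)       # ['', 'user: hello', 'assistant: hi there', 'user: and you?']
    turns = [m for m in messages if m]  # drop the leading empty fragment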

View File

@@ -46,8 +46,26 @@ class BaseOutputParser(ABC):
        code = sep.join(blocks)
        return code

+    def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
+        data = json.loads(chunk.decode())
+        """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
+        """
+        if data["error_code"] == 0:
+            if CFG.LLM_MODEL in ["vicuna-13b", "guanaco"]:
+                output = data["text"][skip_echo_len:].strip()
+            else:
+                output = data["text"].strip()
+            output = self.__post_process_code(output)
+            return output
+        else:
+            output = data["text"] + f" (error_code: {data['error_code']})"
+            return output
+
    # TODO: bind this to the model later
-    def _parse_model_stream_resp(self, response, sep: str, skip_echo_len):
+    def parse_model_stream_resp(self, response, skip_echo_len):
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
@@ -56,7 +74,7 @@ class BaseOutputParser(ABC):
                """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
                """
                if data["error_code"] == 0:
-                    if CFG.LLM_MODEL in ["vicuna", "guanaco"]:
+                    if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
                        output = data["text"][skip_echo_len:].strip()
                    else:
                        output = data["text"].strip()
@@ -69,7 +87,7 @@ class BaseOutputParser(ABC):
                    )
                    yield output

-    def _parse_model_nostream_resp(self, response, sep: str):
+    def parse_model_nostream_resp(self, response, sep: str):
        text = response.text.strip()
        text = text.rstrip()
        text = text.lower()
@@ -96,19 +114,6 @@ class BaseOutputParser(ABC):
        else:
            raise ValueError("Model server error!code=" + respObj_ex["error_code"])

-    def parse_model_server_out(self, response, skip_echo_len: int = 0):
-        """
-        parse the model server http response
-        Args:
-            response:
-        Returns:
-        """
-        if not self.is_stream_out:
-            return self._parse_model_nostream_resp(response, self.sep)
-        else:
-            return self._parse_model_stream_resp(response, self.sep)
-
    def parse_prompt_response(self, model_out_text) -> T:
        """

View File

@@ -129,7 +129,7 @@ class BaseChat(ABC):
    def stream_call(self):
        payload = self.__call_base()

-        skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 1
+        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
        logger.info(f"Requert: \n{payload}")
        ai_response_text = ""
        try:
@@ -138,14 +138,16 @@ class BaseChat(ABC):
                urljoin(CFG.MODEL_SERVER, "generate_stream"),
                headers=headers,
                json=payload,
+                stream=True,
                timeout=120,
            )
-            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response, skip_echo_len)
-            for resp_text_trunck in ai_response_text:
-                show_info = resp_text_trunck
-                yield resp_text_trunck + "▌"
+            return response;
+            # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
+            # for resp_text_trunck in ai_response_text:
+            #     show_info = resp_text_trunck
+            #     yield resp_text_trunck + "▌"

            self.current_message.add_ai_message(show_info)
@@ -173,7 +175,7 @@ class BaseChat(ABC):
            ### output parse
            ai_response_text = (
-                self.prompt_template.output_parser.parse_model_server_out(response)
+                self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep)
            )
            self.current_message.add_ai_message(ai_response_text)
            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
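Note: with stream=True set and the early return response;, stream_call now hands back the raw streaming requests.Response instead of yielding parsed text; the old generator body survives only as comments. A minimal sketch of the consuming side, mirroring the webserver change below; chat stands for any already-constructed BaseChat subclass:

    # chat is assumed to be an already-constructed BaseChat subclass.
    response = chat.stream_call()  # raw streaming HTTP response

    # Chunks are separated by NUL bytes, matching the server's framing.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            text = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                chunk, chat.skip_echo_len
            )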

View File

@@ -48,7 +48,6 @@ from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text

# Load plugins
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
@@ -78,7 +77,6 @@ DB_SETTINGS = {
    "port": CFG.LOCAL_DB_PORT,
}

llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
default_knowledge_base_dialogue = get_lang_text(
    "knowledge_qa_type_default_knowledge_base_dialogue"
@@ -111,7 +109,7 @@ def gen_sqlgen_conversation(dbname):
    db_connect = CFG.local_db.get_session(dbname)
    schemas = CFG.local_db.table_simple_info(db_connect)
    for s in schemas:
-        message += s+ ";"
+        message += s + ";"
    return get_lang_text("sql_schema_info").format(dbname, message)
@@ -209,8 +207,8 @@ def post_process_code(code):
def get_chat_mode(selected, param=None) -> ChatScene:
    if chat_mode_title['chat_use_plugin'] == selected:
        return ChatScene.ChatExecution
    elif chat_mode_title['knowledge_qa'] == selected:
-        mode= param
+        mode = param
        if mode == conversation_types["default_knownledge"]:
            return ChatScene.ChatKnowledge
        elif mode == conversation_types["custome"]:
@@ -220,7 +218,7 @@ def get_chat_mode(selected, param=None) -> ChatScene:
        else:
            return ChatScene.ChatNormal
    else:
-        sql_mode= param
+        sql_mode = param
        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
            return ChatScene.ChatWithDbExecute
        else:
@@ -234,10 +232,11 @@ def chatbot_callback(state, message):
def http_bot(
-        state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name
+        state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input,
+        knowledge_name
):
-    logger.info(f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
+    logger.info(
+        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
    if chat_mode_title['knowledge_qa'] == selected:
        scene: ChatScene = get_chat_mode(selected, mode)
    elif chat_mode_title['chat_use_plugin'] == selected:
@@ -312,10 +311,11 @@ def http_bot(
    else:
        logger.info("stream out start!")
        try:
-            stream_gen = chat.stream_call()
-            for msg in stream_gen:
-                state.messages[-1][-1] = msg
-                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+            response = chat.stream_call()
+            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+                if chunk:
+                    state.messages[-1][-1] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
        except Exception as e:
            print(traceback.format_exc())
            state.messages[-1][-1] = "Error:" + str(e)
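Note: the loop above implies a framing contract on the server side: each partial generation is one JSON object terminated by b"\0", which is exactly what iter_lines(delimiter=b"\0") splits on. A hedged sketch of a producer that satisfies that contract; the function and its generator argument are hypothetical, since this commit does not touch the server:

    import json

    def frame_stream(partial_texts):
        # partial_texts is any iterator of progressively longer outputs.
        for text in partial_texts:
            yield json.dumps({"text": text, "error_code": 0}).encode() + b"\0"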
@@ -323,8 +323,8 @@ def http_bot(
block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
@@ -353,7 +353,6 @@ def change_mode(mode):
        return gr.update(visible=False)

def build_single_model_ui():
    notice_markdown = get_lang_text("db_gpt_introduction")
    learn_more_markdown = get_lang_text("learn_more_markdown")
@@ -362,7 +361,7 @@ def build_single_model_ui():
    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Accordion(
        get_lang_text("model_control_param"), open=False, visible=False
    ) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
@@ -458,13 +457,14 @@ def build_single_model_ui():
    mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

    url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)

-    def show_url_input(evt:gr.SelectData):
+    def show_url_input(evt: gr.SelectData):
        if evt.value == url_knowledge_dialogue:
            return gr.update(visible=True)
        else:
            return gr.update(visible=False)

    mode.select(fn=show_url_input, inputs=None, outputs=url_input)

    with vs_setting:
        vs_name = gr.Textbox(
@@ -516,7 +516,6 @@ def build_single_model_ui():
    params = [plugin_selected, mode, sql_mode, db_selector, url_input, vs_name]

    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
@@ -529,7 +528,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
-        [state, selected, temperature, max_output_tokens]+ params,
+        [state, selected, temperature, max_output_tokens] + params,
        [state, chatbot] + btn_list,
    )
@@ -537,7 +536,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
-        [state, selected, temperature, max_output_tokens]+ params,
+        [state, selected, temperature, max_output_tokens] + params,
        [state, chatbot] + btn_list,
    )
    vs_add.click(
@@ -560,10 +559,10 @@ def build_single_model_ui():
def build_webdemo():
    with gr.Blocks(
        title=get_lang_text("database_smart_assistant"),
        # theme=gr.themes.Base(),
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)
        (