diff --git a/pilot/model/llm_out/chatglm_llm.py b/pilot/model/llm_out/chatglm_llm.py
index b451e910c..59032f9e8 100644
--- a/pilot/model/llm_out/chatglm_llm.py
+++ b/pilot/model/llm_out/chatglm_llm.py
@@ -29,6 +29,7 @@ def chatglm_generate_stream(
     generate_kwargs["temperature"] = temperature
 
     # TODO, Fix this
+    print(prompt)
     messages = prompt.split(stop)
     #
     # # Add history conversation
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 7a61c788f..663b87a7d 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -46,8 +46,26 @@ class BaseOutputParser(ABC):
         code = sep.join(blocks)
         return code
 
+    def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
+        data = json.loads(chunk.decode())
+
+        """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
+        """
+        if data["error_code"] == 0:
+            if CFG.LLM_MODEL in ["vicuna-13b", "guanaco"]:
+
+                output = data["text"][skip_echo_len:].strip()
+            else:
+                output = data["text"].strip()
+
+            output = self.__post_process_code(output)
+            return output
+        else:
+            output = data["text"] + f" (error_code: {data['error_code']})"
+            return output
+
     # TODO 后续和模型绑定
-    def _parse_model_stream_resp(self, response, sep: str, skip_echo_len):
+    def parse_model_stream_resp(self, response, skip_echo_len):
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
@@ -56,7 +74,7 @@ class BaseOutputParser(ABC):
                 """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
                 """
                 if data["error_code"] == 0:
-                    if CFG.LLM_MODEL in ["vicuna", "guanaco"]:
+                    if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
                         output = data["text"][skip_echo_len:].strip()
                     else:
                         output = data["text"].strip()
@@ -69,7 +87,7 @@ class BaseOutputParser(ABC):
                 )
                 yield output
 
-    def _parse_model_nostream_resp(self, response, sep: str):
+    def parse_model_nostream_resp(self, response, sep: str):
         text = response.text.strip()
         text = text.rstrip()
         text = text.lower()
@@ -96,19 +114,6 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])
 
-    def parse_model_server_out(self, response, skip_echo_len: int = 0):
-        """
-        parse the model server http response
-        Args:
-            response:
-
-        Returns:
-
-        """
-        if not self.is_stream_out:
-            return self._parse_model_nostream_resp(response, self.sep)
-        else:
-            return self._parse_model_stream_resp(response, self.sep)
 
     def parse_prompt_response(self, model_out_text) -> T:
         """
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index e1d178a12..8c8dba501 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -129,7 +129,7 @@ class BaseChat(ABC):
 
     def stream_call(self):
         payload = self.__call_base()
-        skip_echo_len = len(payload.get('prompt').replace("", " ")) + 1
+        self.skip_echo_len = len(payload.get('prompt').replace("", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
@@ -138,14 +138,16 @@ class BaseChat(ABC):
                 urljoin(CFG.MODEL_SERVER, "generate_stream"),
                 headers=headers,
                 json=payload,
+                stream=True,
                 timeout=120,
             )
+            return response;
 
-            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response, skip_echo_len)
+            # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
 
-            for resp_text_trunck in ai_response_text:
-                show_info = resp_text_trunck
-                yield resp_text_trunck + "▌"
+            # for resp_text_trunck in ai_response_text:
+            #     show_info = resp_text_trunck
+            #     yield resp_text_trunck + "▌"
 
             self.current_message.add_ai_message(show_info)
 
@@ -173,7 +175,7 @@ class BaseChat(ABC):
 
             ### output parse
             ai_response_text = (
-                self.prompt_template.output_parser.parse_model_server_out(response)
+                self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep)
             )
             self.current_message.add_ai_message(ai_response_text)
             prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index e51ef792c..6b385d9da 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -48,7 +48,6 @@ from pilot.scene.base import ChatScene
 from pilot.scene.chat_factory import ChatFactory
 from pilot.language.translation_handler import get_lang_text
 
-
 # 加载插件
 CFG = Config()
 logger = build_logger("webserver", LOGDIR + "webserver.log")
@@ -78,7 +77,6 @@ DB_SETTINGS = {
     "port": CFG.LOCAL_DB_PORT,
 }
 
-
 llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
 default_knowledge_base_dialogue = get_lang_text(
     "knowledge_qa_type_default_knowledge_base_dialogue"
 )
@@ -111,7 +109,7 @@ def gen_sqlgen_conversation(dbname):
     db_connect = CFG.local_db.get_session(dbname)
     schemas = CFG.local_db.table_simple_info(db_connect)
     for s in schemas:
-        message += s+ ";"
+        message += s + ";"
 
     return get_lang_text("sql_schema_info").format(dbname, message)
 
@@ -209,8 +207,8 @@ def post_process_code(code):
 def get_chat_mode(selected, param=None) -> ChatScene:
     if chat_mode_title['chat_use_plugin'] == selected:
         return ChatScene.ChatExecution
-    elif chat_mode_title['knowledge_qa'] == selected:
-        mode= param
+    elif chat_mode_title['knowledge_qa'] == selected:
+        mode = param
         if mode == conversation_types["default_knownledge"]:
             return ChatScene.ChatKnowledge
         elif mode == conversation_types["custome"]:
@@ -220,7 +218,7 @@ def get_chat_mode(selected, param=None) -> ChatScene:
         else:
             return ChatScene.ChatNormal
     else:
-        sql_mode= param
+        sql_mode = param
         if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
             return ChatScene.ChatWithDbExecute
         else:
@@ -234,10 +232,11 @@ def chatbot_callback(state, message):
 
 
 def http_bot(
-    state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name
+        state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input,
+        knowledge_name
 ):
-
-    logger.info(f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
+    logger.info(
+        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
     if chat_mode_title['knowledge_qa'] == selected:
         scene: ChatScene = get_chat_mode(selected, mode)
     elif chat_mode_title['chat_use_plugin'] == selected:
@@ -312,10 +311,11 @@ def http_bot(
     else:
         logger.info("stream out start!")
        try:
-            stream_gen = chat.stream_call()
-            for msg in stream_gen:
-                state.messages[-1][-1] = msg
-                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+            response = chat.stream_call()
+            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+                if chunk:
+                    state.messages[-1][-1] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk,chat.skip_echo_len)
+                    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
         except Exception as e:
             print(traceback.format_exc())
             state.messages[-1][-1] = "Error:" + str(e)
@@ -323,8 +323,8 @@
 
 
 block_css = (
-    code_highlight_css
-    + """
+        code_highlight_css
+        + """
 pre {
     white-space: pre-wrap; /* Since CSS 2.1 */
     white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -353,7 +353,6 @@ def change_mode(mode):
         return gr.update(visible=False)
 
 
-
 def build_single_model_ui():
     notice_markdown = get_lang_text("db_gpt_introduction")
     learn_more_markdown = get_lang_text("learn_more_markdown")
@@ -362,7 +361,7 @@ def build_single_model_ui():
     gr.Markdown(notice_markdown, elem_id="notice_markdown")
 
     with gr.Accordion(
-        get_lang_text("model_control_param"), open=False, visible=False
+            get_lang_text("model_control_param"), open=False, visible=False
     ) as parameter_row:
         temperature = gr.Slider(
             minimum=0.0,
@@ -458,13 +457,14 @@ def build_single_model_ui():
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
 
             url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
-            def show_url_input(evt:gr.SelectData):
+
+            def show_url_input(evt: gr.SelectData):
                 if evt.value == url_knowledge_dialogue:
                     return gr.update(visible=True)
                 else:
                     return gr.update(visible=False)
-            mode.select(fn=show_url_input, inputs=None, outputs=url_input)
 
+            mode.select(fn=show_url_input, inputs=None, outputs=url_input)
 
             with vs_setting:
                 vs_name = gr.Textbox(
@@ -516,7 +516,6 @@ def build_single_model_ui():
 
     params = [plugin_selected, mode, sql_mode, db_selector, url_input, vs_name]
 
-
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
@@ -529,7 +528,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
        http_bot,
-        [state, selected, temperature, max_output_tokens]+ params,
+        [state, selected, temperature, max_output_tokens] + params,
         [state, chatbot] + btn_list,
     )
 
@@ -537,7 +536,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, selected, temperature, max_output_tokens]+ params,
+        [state, selected, temperature, max_output_tokens] + params,
         [state, chatbot] + btn_list,
     )
     vs_add.click(
@@ -560,10 +559,10 @@ def build_single_model_ui():
 
 def build_webdemo():
     with gr.Blocks(
-        title=get_lang_text("database_smart_assistant"),
-        # theme=gr.themes.Base(),
-        theme=gr.themes.Default(),
-        css=block_css,
+            title=get_lang_text("database_smart_assistant"),
+            # theme=gr.themes.Base(),
+            theme=gr.themes.Default(),
+            css=block_css,
     ) as demo:
         url_params = gr.JSON(visible=False)
         (