diff --git a/environment.yml b/environment.yml
index 3ec4dfd98..3ba070e56 100644
--- a/environment.yml
+++ b/environment.yml
@@ -60,3 +60,4 @@ dependencies:
     - gradio==3.24.1
     - gradio-client==0.0.8
     - wandb
+    - fschat==0.1.10
diff --git a/pilot/conversation.py b/pilot/conversation.py
index 1ee142762..3a702c615 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -3,7 +3,7 @@
 
 import dataclasses
 from enum import auto, Enum
-from typing import List, Tuple, Any
+from typing import List, Any
 
 
 class SeparatorStyle(Enum):
@@ -29,12 +29,12 @@ class Conversation:
 
     def get_prompt(self):
         if self.sep_style == SeparatorStyle.SINGLE:
-            ret = self.system
+            ret = self.system + self.sep
             for role, message in self.messages:
                 if message:
-                    ret += self.sep + " " + role + ": " + message
+                    ret += role + ": " + message + self.sep
                 else:
-                    ret += self.sep + " " + role + ":"
+                    ret += role + ":"
             return ret
 
         elif self.sep_style == SeparatorStyle.TWO:
@@ -56,7 +56,7 @@ class Conversation:
 
     def to_gradio_chatbot(self):
         ret = []
-        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
             if i % 2 == 0:
                 ret.append([msg, None])
             else:
@@ -133,19 +133,9 @@ conv_vicuna_v1 = Conversation(
     sep2="</s>",
 )
 
-conv_template = {
+default_conversation = conv_one_shot
+
+conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1": conv_vicuna_v1
 }
-
-
-def get_default_conv_template(model_name: str = "vicuna-13b"):
-    model_name = model_name.lower()
-    if "vicuna" in model_name:
-        return conv_vicuna_v1
-    return conv_one_shot
-
-
-def compute_skip_echo_len(prompt):
-    skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
-    return skip_echo_len
\ No newline at end of file
diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py
index 3012d7ef1..f664410e8 100644
--- a/pilot/server/vicuna_server.py
+++ b/pilot/server/vicuna_server.py
@@ -82,6 +82,7 @@ async def api_generate_stream(request: Request):
     global_counter += 1
     params = await request.json()
     print(model, tokenizer, params, DEVICE)
+
     if model_semaphore is None:
         model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 8d4cb52ad..8bd977cbb 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -14,8 +14,8 @@ from urllib.parse import urljoin
 
 from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
 from pilot.conversation import (
-    get_default_conv_template,
-    compute_skip_echo_len,
+    default_conversation,
+    conv_templates,
     SeparatorStyle
 )
 
@@ -43,29 +43,6 @@ priority = {
     "vicuna-13b": "aaa"
 }
 
-def set_global_vars(enable_moderation_):
-    global enable_moderation, models
-    enable_moderation = enable_moderation_
-
-def load_demo_single(url_params):
-    dropdown_update = gr.Dropdown.update(visible=True)
-    if "model" in url_params:
-        model = url_params["model"]
-        if model in models:
-            dropdown_update = gr.Dropdown.update(value=model, visible=True)
-
-    state = None
-    return (
-        state,
-        dropdown_update,
-        gr.Chatbot.update(visible=True),
-        gr.Textbox.update(visible=True),
-        gr.Button.update(visible=True),
-        gr.Row.update(visible=True),
-        gr.Accordion.update(visible=True),
-    )
-
-
 get_window_url_params = """
 function() {
     const params = new URLSearchParams(window.location.search);
@@ -78,10 +55,24 @@ function() {
     return url_params;
 }
 """
-
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-    return load_demo_single(url_params)
+
+    dropdown_update = gr.Dropdown.update(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown.update(
+                value=model, visible=True)
+
+    state = default_conversation.copy()
+    return (state,
+            dropdown_update,
+            gr.Chatbot.update(visible=True),
+            gr.Textbox.update(visible=True),
+            gr.Button.update(visible=True),
+            gr.Row.update(visible=True),
+            gr.Accordion.update(visible=True))
 
 def get_conv_log_filename():
     t = datetime.datetime.now()
@@ -100,29 +91,23 @@ def clear_history(request: gr.Request):
     state = None
     return (state, [], "") + (disable_btn,) * 5
 
-def add_text(state, text, request: gr.Request):
-    logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
-    if state is None:
-        state = get_default_conv_template("vicuna").copy()
-
+def add_text(state, text, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
     if len(text) <= 0:
         state.skip_next = True
         return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
-
-    # TODO 根据ChatGPT来进行模型评估
-    if enable_moderation:
+    if args.moderate:
         flagged = violates_moderation(text)
         if flagged:
-            logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
             state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
+            return (state, state.to_gradio_chatbot(), moderation_msg) + (
+                no_change_btn,) * 5
 
-    text = text[:1536]  # ?
-    state.append_message(state.roles[0], text)
+    text = text[:1536]  # Hard cut-off
+    state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False
-
     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
 
 
 def post_process_code(code):
@@ -136,104 +121,80 @@ def post_process_code(code):
     return code
 
 
 def http_bot(state, temperature, max_new_tokens, request: gr.Request):
-    logger.info(f"http_bot. ip: {request.client.host}")
     start_tstamp = time.time()
-
-    model_name = LLM_MODEL
-    temperature = float(temperature)
-    max_new_tokens = int(max_new_tokens)
+    model_name = LLM_MODEL
 
     if state.skip_next:
+        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return
 
     if len(state.messages) == state.offset + 2:
-        # 第一轮对话
-        new_state = get_default_conv_template(model_name).copy()
+        # First round of conversation
+
+        template_name = "conv_one_shot"
+        new_state = conv_templates[template_name].copy()
         new_state.conv_id = uuid.uuid4().hex
         new_state.append_message(new_state.roles[0], state.messages[-2][1])
         new_state.append_message(new_state.roles[1], None)
         state = new_state
-
     prompt = state.get_prompt()
-    skip_echo_len = compute_skip_echo_len(prompt)
+    skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
-    logger.info(f"State: {state}")
+    # Make requests
     payload = {
         "model": model_name,
         "prompt": prompt,
-        "temperature": temperature,
-        "max_new_tokens": max_new_tokens,
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
+        "temperature": float(temperature),
+        "max_new_tokens": int(max_new_tokens),
+        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
     }
+    logger.info(f"Request: \n{payload}")
 
-    logger.info(f"Request: \n {payload}")
     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
     try:
-        response = requests.post(
-            url=urljoin(vicuna_model_server, "worker_generate_stream"),
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=20,
-        )
+        # Stream output
+        response = requests.post(urljoin(vicuna_model_server, "worker_generate_stream"),
+                                 headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
-                logger.info(f"Response: {data}")
                 if data["error_code"] == 0:
-                    output = data["text"][skip_echo_len].strip()
+                    output = data["text"][skip_echo_len:].strip()
                     output = post_process_code(output)
                     state.messages[-1][-1] = output + "▌"
                     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                 else:
-                    output = data["text"] + f" (error_code): {data['error_code']}"
+                    output = data["text"] + f" (error_code: {data['error_code']})"
                     state.messages[-1][-1] = output
-                    yield (state, state.to_gradio_chatbot()) + (
-                        disable_btn,
-                        disable_btn,
-                        disable_btn,
-                        enable_btn,
-                        enable_btn
-                    )
-
+                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
                 time.sleep(0.02)
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-        yield (state, state.to_gradio_chatbot()) + (
-            disable_btn,
-            disable_btn,
-            disable_btn,
-            enable_btn,
-            enable_btn
-        )
+        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
         return
 
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-
+
     finish_tstamp = time.time()
     logger.info(f"{output}")
 
-    with open(get_conv_log_filename(), "a") as flog:
+    with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(finish_tstamp, 4),
             "type": "chat",
             "model": model_name,
-            "gen_params": {
-                "temperature": temperature,
-                "max_new_tokens": max_new_tokens,
-            },
             "start": round(start_tstamp, 4),
             "finish": round(finish_tstamp, 4),
             "state": state.dict(),
             "ip": request.client.host,
         }
-        flog.write(json.dumps(data), + "\n")
+        fout.write(json.dumps(data) + "\n")
 
 block_css = (
     code_highlight_css
@@ -382,8 +343,6 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info(f"args: {args}")
 
-    set_global_vars(args.moderate)
-    logger.info(args)
 
     demo = build_webdemo()
     demo.queue(
diff --git a/requirements.txt b/requirements.txt
index dd7bf5189..50354b3b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -49,4 +49,5 @@ umap-learn
 notebook
 gradio==3.24.1
 gradio-client==0.0.8
-wandb
\ No newline at end of file
+wandb
+fschat==0.1.10
\ No newline at end of file