fix problem

csunny 2023-04-30 23:58:32 +08:00
parent 3f13fe3a7f
commit 4c7ba5021a
5 changed files with 60 additions and 108 deletions

View File

@@ -60,3 +60,4 @@ dependencies:
   - gradio==3.24.1
   - gradio-client==0.0.8
   - wandb
+  - fschat==0.1.10

View File

@@ -3,7 +3,7 @@
 import dataclasses
 from enum import auto, Enum
-from typing import List, Tuple, Any
+from typing import List, Any

 class SeparatorStyle(Enum):
@@ -29,12 +29,12 @@ class Conversation:
     def get_prompt(self):
         if self.sep_style == SeparatorStyle.SINGLE:
-            ret = self.system
+            ret = self.system + self.sep
             for role, message in self.messages:
                 if message:
-                    ret += self.sep + " " + role + ": " + message
+                    ret += role + ": " + message + self.sep
                 else:
-                    ret += self.sep + " " + role + ":"
+                    ret += role + ":"
             return ret
         elif self.sep_style == SeparatorStyle.TWO:
@@ -56,7 +56,7 @@ class Conversation:
     def to_gradio_chatbot(self):
         ret = []
-        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
             if i % 2 == 0:
                 ret.append([msg, None])
             else:
@@ -133,19 +133,9 @@ conv_vicuna_v1 = Conversation(
     sep2="</s>",
 )

-conv_template = {
+default_conversation = conv_one_shot
+
+conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1": conv_vicuna_v1
 }

-def get_default_conv_template(model_name: str = "vicuna-13b"):
-    model_name = model_name.lower()
-    if "vicuna" in model_name:
-        return conv_vicuna_v1
-    return conv_one_shot
-
-def compute_skip_echo_len(prompt):
-    skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
-    return skip_echo_len
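The SeparatorStyle.SINGLE change above moves the separator from before each role to after the system prompt and after each message. A minimal, self-contained sketch of the two layouts (made-up system/sep values, not the project's Conversation class):

def old_prompt(system, sep, messages):
    # old behavior: separator precedes each role; system has no trailing sep
    ret = system
    for role, message in messages:
        if message:
            ret += sep + " " + role + ": " + message
        else:
            ret += sep + " " + role + ":"
    return ret

def new_prompt(system, sep, messages):
    # new behavior: system gets a trailing sep; each message is followed by sep
    ret = system + sep
    for role, message in messages:
        if message:
            ret += role + ": " + message + sep
        else:
            ret += role + ":"
    return ret

msgs = [("Human", "Hello"), ("Assistant", None)]
print(old_prompt("SYSTEM", "###", msgs))  # SYSTEM### Human: Hello### Assistant:
print(new_prompt("SYSTEM", "###", msgs))  # SYSTEM###Human: Hello###Assistant: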

View File

@@ -82,6 +82,7 @@ async def api_generate_stream(request: Request):
     global_counter += 1
     params = await request.json()
     print(model, tokenizer, params, DEVICE)
+
     if model_semaphore is None:
         model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()
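The semaphore in this handler implements a simple concurrency cap: it is created lazily on the first request, and every generation must acquire it before running. A hedged sketch of the pattern (names are illustrative, not the server's exact code):

import asyncio

LIMIT_MODEL_CONCURRENCY = 5
model_semaphore = None  # created lazily so it binds to the running event loop

async def generate_stream(params):
    global model_semaphore
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()  # blocks once 5 generations are in flight
    try:
        ...  # run the model and stream tokens
    finally:
        model_semaphore.release()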

View File

@@ -14,8 +14,8 @@ from urllib.parse import urljoin
 from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
 from pilot.conversation import (
-    get_default_conv_template,
-    compute_skip_echo_len,
+    default_conversation,
+    conv_templates,
     SeparatorStyle
 )
@@ -43,29 +43,6 @@ priority = {
     "vicuna-13b": "aaa"
 }

-def set_global_vars(enable_moderation_):
-    global enable_moderation, models
-    enable_moderation = enable_moderation_
-
-def load_demo_single(url_params):
-    dropdown_update = gr.Dropdown.update(visible=True)
-    if "model" in url_params:
-        model = url_params["model"]
-        if model in models:
-            dropdown_update = gr.Dropdown.update(value=model, visible=True)
-
-    state = None
-    return (
-        state,
-        dropdown_update,
-        gr.Chatbot.update(visible=True),
-        gr.Textbox.update(visible=True),
-        gr.Button.update(visible=True),
-        gr.Row.update(visible=True),
-        gr.Accordion.update(visible=True),
-    )

 get_window_url_params = """
 function() {
     const params = new URLSearchParams(window.location.search);
@@ -78,10 +55,24 @@ function() {
     return url_params;
 }
 """

 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-    return load_demo_single(url_params)
+
+    dropdown_update = gr.Dropdown.update(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown.update(
+                value=model, visible=True)
+
+    state = default_conversation.copy()
+    return (state,
+            dropdown_update,
+            gr.Chatbot.update(visible=True),
+            gr.Textbox.update(visible=True),
+            gr.Button.update(visible=True),
+            gr.Row.update(visible=True),
+            gr.Accordion.update(visible=True))

 def get_conv_log_filename():
     t = datetime.datetime.now()
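Note that load_demo now seeds each session with default_conversation.copy() instead of None. A small sketch of why the copy matters (assuming the import from the diff above resolves); without it, every browser session would share and mutate one Conversation object:

from pilot.conversation import default_conversation

a = default_conversation.copy()
b = default_conversation.copy()
a.append_message(a.roles[0], "hi")
# b and the shared template are unaffected:
assert len(b.messages) == len(default_conversation.messages)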
@@ -100,29 +91,23 @@ def clear_history(request: gr.Request):
     state = None
     return (state, [], "") + (disable_btn,) * 5

 def add_text(state, text, request: gr.Request):
-    logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
-    if state is None:
-        state = get_default_conv_template("vicuna").copy()
+    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

     if len(text) <= 0:
         state.skip_next = True
         return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5

-    if args.moderate:
+    # TODO: use ChatGPT to evaluate the model responses
+    if enable_moderation:
         flagged = violates_moderation(text)
         if flagged:
-            logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
             state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
+            return (state, state.to_gradio_chatbot(), moderation_msg) + (
+                no_change_btn,) * 5

-    text = text[:1536]  # ?
+    text = text[:1536]  # Hard cut-off
     state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
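The (no_change_btn,) * 5 and (disable_btn,) * 5 tails above follow a Gradio convention: a handler returns one update per output component, and the last five outputs here are action buttons. A hedged sketch of how those sentinels are typically defined (assumed; their definitions are not shown in this diff):

import gradio as gr

no_change_btn = gr.Button.update()               # leave the button as-is
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

# e.g. add_text returns: (state, chatbot_pairs, textbox_value) + (disable_btn,) * 5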
def post_process_code(code): def post_process_code(code):
@@ -136,81 +121,61 @@ def post_process_code(code):
     return code

 def http_bot(state, temperature, max_new_tokens, request: gr.Request):
+    logger.info(f"http_bot. ip: {request.client.host}")
     start_tstamp = time.time()
     model_name = LLM_MODEL
-    temperature = float(temperature)
-    max_new_tokens = int(max_new_tokens)

     if state.skip_next:
+        # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return

     if len(state.messages) == state.offset + 2:
-        # 第一轮对话
-        new_state = get_default_conv_template(model_name).copy()
+        # First round of conversation
+        template_name = "conv_one_shot"
+        new_state = conv_templates[template_name].copy()
         new_state.conv_id = uuid.uuid4().hex
         new_state.append_message(new_state.roles[0], state.messages[-2][1])
         new_state.append_message(new_state.roles[1], None)
         state = new_state

     prompt = state.get_prompt()
-    skip_echo_len = compute_skip_echo_len(prompt)
+    skip_echo_len = len(prompt.replace("</s>", " ")) + 1

-    logger.info(f"State: {state}")
+    # Make requests
     payload = {
         "model": model_name,
         "prompt": prompt,
-        "temperature": temperature,
-        "max_new_tokens": max_new_tokens,
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
+        "temperature": float(temperature),
+        "max_new_tokens": int(max_new_tokens),
+        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
     }
-    logger.info(f"Requert: \n{payload}")
+    logger.info(f"Request: \n {payload}")

     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

     try:
-        response = requests.post(
-            url=urljoin(vicuna_model_server, "worker_generate_stream"),
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=20,
-        )
+        # Stream output
+        response = requests.post(urljoin(vicuna_model_server, "worker_generate_stream"),
+            headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
+                logger.info(f"Response: {data}")
                 if data["error_code"] == 0:
-                    output = data["text"][skip_echo_len].strip()
+                    output = data["text"][skip_echo_len:].strip()
                     output = post_process_code(output)
                     state.messages[-1][-1] = output + "▌"
                     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                 else:
-                    output = data["text"] + f" (error_code): {data['error_code']}"
+                    output = data["text"] + f" (error_code: {data['error_code']})"
                     state.messages[-1][-1] = output
-                    yield (state, state.to_gradio_chatbot()) + (
-                        disable_btn,
-                        disable_btn,
-                        disable_btn,
-                        enable_btn,
-                        enable_btn
-                    )
+                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
                 time.sleep(0.02)
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-        yield (state, state.to_gradio_chatbot()) + (
-            disable_btn,
-            disable_btn,
-            disable_btn,
-            enable_btn,
-            enable_btn
-        )
+        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

     state.messages[-1][-1] = state.messages[-1][-1][:-1]
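The streaming loop above defines a small wire protocol: the worker sends one JSON object per partial result, separated by NUL bytes, each carrying the full text so far plus an error_code. A minimal standalone client sketch (the localhost URL and payload values are assumptions for illustration):

import json
import requests

payload = {
    "model": "vicuna-13b",
    "prompt": "A chat.###Human: Hi###Assistant:",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "stop": "###",
}
response = requests.post("http://localhost:8000/worker_generate_stream",  # assumed URL
                         json=payload, stream=True, timeout=20)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"])  # each chunk repeats the full text generated so far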
@@ -219,21 +184,17 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
     finish_tstamp = time.time()
     logger.info(f"{output}")

-    with open(get_conv_log_filename(), "a") as flog:
+    with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(finish_tstamp, 4),
             "type": "chat",
             "model": model_name,
-            "gen_params": {
-                "temperature": temperature,
-                "max_new_tokens": max_new_tokens,
-            },
             "start": round(start_tstamp, 4),
             "finish": round(finish_tstamp, 4),
             "state": state.dict(),
             "ip": request.client.host,
         }
-        flog.write(json.dumps(data), + "\n")
+        fout.write(json.dumps(data) + "\n")

 block_css = (
     code_highlight_css
@ -382,8 +343,6 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
logger.info(f"args: {args}") logger.info(f"args: {args}")
set_global_vars(args.moderate)
logger.info(args) logger.info(args)
demo = build_webdemo() demo = build_webdemo()
demo.queue( demo.queue(
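The records written in the logging hunk above are JSON Lines (one JSON object per line), so they can be read back with a few lines of Python. Sketch only; the filename here is hypothetical and the real one comes from get_conv_log_filename():

import json

with open("2023-04-30-conv.json") as f:  # hypothetical log filename
    records = [json.loads(line) for line in f if line.strip()]
print(records[-1]["model"], records[-1]["tstamp"])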

View File

@@ -50,3 +50,4 @@ notebook
 gradio==3.24.1
 gradio-client==0.0.8
 wandb
+fschat==0.1.10
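To confirm the fschat==0.1.10 pin actually resolves in the environment, a quick check (sketch; assumes the package is already installed):

import importlib.metadata

assert importlib.metadata.version("fschat") == "0.1.10"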