mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-13 22:15:35 +00:00
fix problem
This commit is contained in:
parent
3f13fe3a7f
commit
4c7ba5021a
@ -60,3 +60,4 @@ dependencies:
|
|||||||
- gradio==3.24.1
|
- gradio==3.24.1
|
||||||
- gradio-client==0.0.8
|
- gradio-client==0.0.8
|
||||||
- wandb
|
- wandb
|
||||||
|
- fschat=0.1.10
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from typing import List, Tuple, Any
|
from typing import List, Any
|
||||||
|
|
||||||
|
|
||||||
class SeparatorStyle(Enum):
|
class SeparatorStyle(Enum):
|
||||||
@ -29,12 +29,12 @@ class Conversation:
|
|||||||
|
|
||||||
def get_prompt(self):
|
def get_prompt(self):
|
||||||
if self.sep_style == SeparatorStyle.SINGLE:
|
if self.sep_style == SeparatorStyle.SINGLE:
|
||||||
ret = self.system
|
ret = self.system + self.sep
|
||||||
for role, message in self.messages:
|
for role, message in self.messages:
|
||||||
if message:
|
if message:
|
||||||
ret += self.sep + " " + role + ": " + message
|
ret += role + ": " + message + self.sep
|
||||||
else:
|
else:
|
||||||
ret += self.sep + " " + role + ":"
|
ret += role + ":"
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
elif self.sep_style == SeparatorStyle.TWO:
|
elif self.sep_style == SeparatorStyle.TWO:
|
||||||
@ -56,7 +56,7 @@ class Conversation:
|
|||||||
|
|
||||||
def to_gradio_chatbot(self):
|
def to_gradio_chatbot(self):
|
||||||
ret = []
|
ret = []
|
||||||
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
||||||
if i % 2 == 0:
|
if i % 2 == 0:
|
||||||
ret.append([msg, None])
|
ret.append([msg, None])
|
||||||
else:
|
else:
|
||||||
@ -133,19 +133,9 @@ conv_vicuna_v1 = Conversation(
|
|||||||
sep2="</s>",
|
sep2="</s>",
|
||||||
)
|
)
|
||||||
|
|
||||||
conv_template = {
|
default_conversation = conv_one_shot
|
||||||
|
|
||||||
|
conv_templates = {
|
||||||
"conv_one_shot": conv_one_shot,
|
"conv_one_shot": conv_one_shot,
|
||||||
"vicuna_v1": conv_vicuna_v1
|
"vicuna_v1": conv_vicuna_v1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_default_conv_template(model_name: str = "vicuna-13b"):
|
|
||||||
model_name = model_name.lower()
|
|
||||||
if "vicuna" in model_name:
|
|
||||||
return conv_vicuna_v1
|
|
||||||
return conv_one_shot
|
|
||||||
|
|
||||||
|
|
||||||
def compute_skip_echo_len(prompt):
|
|
||||||
skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
|
|
||||||
return skip_echo_len
|
|
@ -82,6 +82,7 @@ async def api_generate_stream(request: Request):
|
|||||||
global_counter += 1
|
global_counter += 1
|
||||||
params = await request.json()
|
params = await request.json()
|
||||||
print(model, tokenizer, params, DEVICE)
|
print(model, tokenizer, params, DEVICE)
|
||||||
|
|
||||||
if model_semaphore is None:
|
if model_semaphore is None:
|
||||||
model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
|
model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
|
||||||
await model_semaphore.acquire()
|
await model_semaphore.acquire()
|
||||||
|
@ -14,8 +14,8 @@ from urllib.parse import urljoin
|
|||||||
from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
|
from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
|
||||||
|
|
||||||
from pilot.conversation import (
|
from pilot.conversation import (
|
||||||
get_default_conv_template,
|
default_conversation,
|
||||||
compute_skip_echo_len,
|
conv_templates,
|
||||||
SeparatorStyle
|
SeparatorStyle
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -43,29 +43,6 @@ priority = {
|
|||||||
"vicuna-13b": "aaa"
|
"vicuna-13b": "aaa"
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_global_vars(enable_moderation_):
|
|
||||||
global enable_moderation, models
|
|
||||||
enable_moderation = enable_moderation_
|
|
||||||
|
|
||||||
def load_demo_single(url_params):
|
|
||||||
dropdown_update = gr.Dropdown.update(visible=True)
|
|
||||||
if "model" in url_params:
|
|
||||||
model = url_params["model"]
|
|
||||||
if model in models:
|
|
||||||
dropdown_update = gr.Dropdown.update(value=model, visible=True)
|
|
||||||
|
|
||||||
state = None
|
|
||||||
return (
|
|
||||||
state,
|
|
||||||
dropdown_update,
|
|
||||||
gr.Chatbot.update(visible=True),
|
|
||||||
gr.Textbox.update(visible=True),
|
|
||||||
gr.Button.update(visible=True),
|
|
||||||
gr.Row.update(visible=True),
|
|
||||||
gr.Accordion.update(visible=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
get_window_url_params = """
|
get_window_url_params = """
|
||||||
function() {
|
function() {
|
||||||
const params = new URLSearchParams(window.location.search);
|
const params = new URLSearchParams(window.location.search);
|
||||||
@ -78,10 +55,24 @@ function() {
|
|||||||
return url_params;
|
return url_params;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load_demo(url_params, request: gr.Request):
|
def load_demo(url_params, request: gr.Request):
|
||||||
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
||||||
return load_demo_single(url_params)
|
|
||||||
|
dropdown_update = gr.Dropdown.update(visible=True)
|
||||||
|
if "model" in url_params:
|
||||||
|
model = url_params["model"]
|
||||||
|
if model in models:
|
||||||
|
dropdown_update = gr.Dropdown.update(
|
||||||
|
value=model, visible=True)
|
||||||
|
|
||||||
|
state = default_conversation.copy()
|
||||||
|
return (state,
|
||||||
|
dropdown_update,
|
||||||
|
gr.Chatbot.update(visible=True),
|
||||||
|
gr.Textbox.update(visible=True),
|
||||||
|
gr.Button.update(visible=True),
|
||||||
|
gr.Row.update(visible=True),
|
||||||
|
gr.Accordion.update(visible=True))
|
||||||
|
|
||||||
def get_conv_log_filename():
|
def get_conv_log_filename():
|
||||||
t = datetime.datetime.now()
|
t = datetime.datetime.now()
|
||||||
@ -100,29 +91,23 @@ def clear_history(request: gr.Request):
|
|||||||
state = None
|
state = None
|
||||||
return (state, [], "") + (disable_btn,) * 5
|
return (state, [], "") + (disable_btn,) * 5
|
||||||
|
|
||||||
|
|
||||||
def add_text(state, text, request: gr.Request):
|
def add_text(state, text, request: gr.Request):
|
||||||
logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
|
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
||||||
|
|
||||||
if state is None:
|
|
||||||
state = get_default_conv_template("vicuna").copy()
|
|
||||||
|
|
||||||
if len(text) <= 0:
|
if len(text) <= 0:
|
||||||
state.skip_next = True
|
state.skip_next = True
|
||||||
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
|
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
|
||||||
|
if args.moderate:
|
||||||
# TODO 根据ChatGPT来进行模型评估
|
|
||||||
if enable_moderation:
|
|
||||||
flagged = violates_moderation(text)
|
flagged = violates_moderation(text)
|
||||||
if flagged:
|
if flagged:
|
||||||
logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
|
|
||||||
state.skip_next = True
|
state.skip_next = True
|
||||||
return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
|
return (state, state.to_gradio_chatbot(), moderation_msg) + (
|
||||||
|
no_change_btn,) * 5
|
||||||
|
|
||||||
text = text[:1536] # ?
|
text = text[:1536] # Hard cut-off
|
||||||
state.append_message(state.roles[0], text)
|
state.append_message(state.roles[0], text)
|
||||||
state.append_message(state.roles[1], None)
|
state.append_message(state.roles[1], None)
|
||||||
state.skip_next = False
|
state.skip_next = False
|
||||||
|
|
||||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||||
|
|
||||||
def post_process_code(code):
|
def post_process_code(code):
|
||||||
@ -136,81 +121,61 @@ def post_process_code(code):
|
|||||||
return code
|
return code
|
||||||
|
|
||||||
def http_bot(state, temperature, max_new_tokens, request: gr.Request):
|
def http_bot(state, temperature, max_new_tokens, request: gr.Request):
|
||||||
logger.info(f"http_bot. ip: {request.client.host}")
|
|
||||||
start_tstamp = time.time()
|
start_tstamp = time.time()
|
||||||
|
|
||||||
model_name = LLM_MODEL
|
model_name = LLM_MODEL
|
||||||
temperature = float(temperature)
|
|
||||||
max_new_tokens = int(max_new_tokens)
|
|
||||||
|
|
||||||
if state.skip_next:
|
if state.skip_next:
|
||||||
|
# This generate call is skipped due to invalid inputs
|
||||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(state.messages) == state.offset + 2:
|
if len(state.messages) == state.offset + 2:
|
||||||
# 第一轮对话
|
# First round of conversation
|
||||||
new_state = get_default_conv_template(model_name).copy()
|
|
||||||
|
template_name = "conv_one_shot"
|
||||||
|
new_state = conv_templates[template_name].copy()
|
||||||
new_state.conv_id = uuid.uuid4().hex
|
new_state.conv_id = uuid.uuid4().hex
|
||||||
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
||||||
new_state.append_message(new_state.roles[1], None)
|
new_state.append_message(new_state.roles[1], None)
|
||||||
state = new_state
|
state = new_state
|
||||||
|
|
||||||
|
|
||||||
prompt = state.get_prompt()
|
prompt = state.get_prompt()
|
||||||
skip_echo_len = compute_skip_echo_len(prompt)
|
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
||||||
|
|
||||||
logger.info(f"State: {state}")
|
# Make requests
|
||||||
payload = {
|
payload = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"temperature": temperature,
|
"temperature": float(temperature),
|
||||||
"max_new_tokens": max_new_tokens,
|
"max_new_tokens": int(max_new_tokens),
|
||||||
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
|
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
|
||||||
}
|
}
|
||||||
|
logger.info(f"Requert: \n{payload}")
|
||||||
|
|
||||||
logger.info(f"Request: \n {payload}")
|
|
||||||
state.messages[-1][-1] = "▌"
|
state.messages[-1][-1] = "▌"
|
||||||
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
# Stream output
|
||||||
url=urljoin(vicuna_model_server, "worker_generate_stream"),
|
response = requests.post(urljoin(vicuna_model_server, "worker_generate_stream"),
|
||||||
headers=headers,
|
headers=headers, json=payload, stream=True, timeout=20)
|
||||||
json=payload,
|
|
||||||
stream=True,
|
|
||||||
timeout=20,
|
|
||||||
)
|
|
||||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||||
if chunk:
|
if chunk:
|
||||||
data = json.loads(chunk.decode())
|
data = json.loads(chunk.decode())
|
||||||
logger.info(f"Response: {data}")
|
|
||||||
if data["error_code"] == 0:
|
if data["error_code"] == 0:
|
||||||
output = data["text"][skip_echo_len].strip()
|
output = data["text"][skip_echo_len:].strip()
|
||||||
output = post_process_code(output)
|
output = post_process_code(output)
|
||||||
state.messages[-1][-1] = output + "▌"
|
state.messages[-1][-1] = output + "▌"
|
||||||
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
||||||
else:
|
else:
|
||||||
output = data["text"] + f" (error_code): {data['error_code']}"
|
output = data["text"] + f" (error_code: {data['error_code']})"
|
||||||
state.messages[-1][-1] = output
|
state.messages[-1][-1] = output
|
||||||
yield (state, state.to_gradio_chatbot()) + (
|
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
||||||
disable_btn,
|
|
||||||
disable_btn,
|
|
||||||
disable_btn,
|
|
||||||
enable_btn,
|
|
||||||
enable_btn
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
return
|
||||||
time.sleep(0.02)
|
time.sleep(0.02)
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
|
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
|
||||||
yield (state, state.to_gradio_chatbot()) + (
|
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
||||||
disable_btn,
|
|
||||||
disable_btn,
|
|
||||||
disable_btn,
|
|
||||||
enable_btn,
|
|
||||||
enable_btn
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
||||||
@ -219,21 +184,17 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
|
|||||||
finish_tstamp = time.time()
|
finish_tstamp = time.time()
|
||||||
logger.info(f"{output}")
|
logger.info(f"{output}")
|
||||||
|
|
||||||
with open(get_conv_log_filename(), "a") as flog:
|
with open(get_conv_log_filename(), "a") as fout:
|
||||||
data = {
|
data = {
|
||||||
"tstamp": round(finish_tstamp, 4),
|
"tstamp": round(finish_tstamp, 4),
|
||||||
"type": "chat",
|
"type": "chat",
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"gen_params": {
|
|
||||||
"temperature": temperature,
|
|
||||||
"max_new_tokens": max_new_tokens,
|
|
||||||
},
|
|
||||||
"start": round(start_tstamp, 4),
|
"start": round(start_tstamp, 4),
|
||||||
"finish": round(finish_tstamp, 4),
|
"finish": round(start_tstamp, 4),
|
||||||
"state": state.dict(),
|
"state": state.dict(),
|
||||||
"ip": request.client.host,
|
"ip": request.client.host,
|
||||||
}
|
}
|
||||||
flog.write(json.dumps(data), + "\n")
|
fout.write(json.dumps(data) + "\n")
|
||||||
|
|
||||||
block_css = (
|
block_css = (
|
||||||
code_highlight_css
|
code_highlight_css
|
||||||
@ -382,8 +343,6 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info(f"args: {args}")
|
logger.info(f"args: {args}")
|
||||||
|
|
||||||
set_global_vars(args.moderate)
|
|
||||||
|
|
||||||
logger.info(args)
|
logger.info(args)
|
||||||
demo = build_webdemo()
|
demo = build_webdemo()
|
||||||
demo.queue(
|
demo.queue(
|
||||||
|
@ -50,3 +50,4 @@ notebook
|
|||||||
gradio==3.24.1
|
gradio==3.24.1
|
||||||
gradio-client==0.0.8
|
gradio-client==0.0.8
|
||||||
wandb
|
wandb
|
||||||
|
fschat=0.1.10
|
Loading…
Reference in New Issue
Block a user