csunny 2023-04-29 01:44:33 +08:00
parent 75181d6f2f
commit c7d3dd2ef2


@@ -7,7 +7,6 @@ import torch
 import gradio as gr
 from fastchat.serve.inference import generate_stream, compress_module
 from transformers import AutoTokenizer, AutoModelForCausalLM
-
 device = "cuda" if torch.cuda.is_available() else "cpu"
 BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
@@ -26,18 +25,19 @@ def generate(prompt):
     print(model, tokenizer)
     params = {
         "model": "vicuna-13b",
-        "prompt": prompt,
+        "prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
         "temperature": 0.7,
         "max_new_tokens": 512,
         "stop": "###"
     }
-    output = generate_stream(
+    for output in generate_stream(
         model, tokenizer, params, device, context_len=2048, stream_interval=2):
-        for chunk in output.iter_lines(decode_unicode=False, delimiter=b"\0"):
-            if chunk:
-                yield chunk
+        ret = {
+            "text": output,
+            "error_code": 0
+        }
+        yield json.dumps(ret).decode() + b"\0"

 if __name__ == "__main__":
     with gr.Blocks() as demo:
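The new prompt prefix translates roughly as: "This is a conversation between a user and an assistant. The assistant is an expert in the database domain and can give very professional answers to questions about databases. The user's question follows:". One caveat in the added streaming body: json.dumps(ret) returns a str, which has no .decode(), and a str cannot be concatenated with the bytes terminator b"\0", so the yield line fails at runtime. A minimal sketch of a working generator, reusing the names from this diff (model, tokenizer, device, generate_stream) and assuming the FastChat worker convention of null-terminated, UTF-8-encoded JSON chunks:

    import json

    def generate(prompt):
        # Same sampling parameters as in the diff; the Chinese instruction
        # prefix is elided here for brevity.
        params = {
            "model": "vicuna-13b",
            "prompt": prompt,
            "temperature": 0.7,
            "max_new_tokens": 512,
            "stop": "###",
        }
        # generate_stream yields progressively longer output strings.
        for output in generate_stream(
                model, tokenizer, params, device,
                context_len=2048, stream_interval=2):
            ret = {"text": output, "error_code": 0}
            # Encode the JSON envelope instead of calling .decode() on a str.
            yield json.dumps(ret).encode() + b"\0"

If the generator feeds the gradio Textbox directly (as text_button.click(generate, ...) suggests), yielding the plain output string is probably the simpler choice; the JSON/null-byte framing only makes sense for an HTTP streaming endpoint.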
@@ -50,7 +50,7 @@ if __name__ == "__main__":
     text_button.click(generate, inputs=text_input, outputs=text_output)
-    demo.queue(concurrency_count=3).launch()
+    demo.queue(concurrency_count=3).launch(host="0.0.0.0")
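The intent of the launch() change is to expose the demo on all network interfaces, but gradio exposes the bind address through the server_name parameter rather than host, so this keyword is likely to be rejected. A one-line sketch using the documented parameter name:

    # Bind the gradio app to all interfaces (gradio's parameter is server_name, not host).
    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")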