mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-15 06:53:12 +00:00)

commit c7d3dd2ef2
parent 75181d6f2f

    update
@@ -7,7 +7,6 @@ import torch
 import gradio as gr
 from fastchat.serve.inference import generate_stream, compress_module
 
-
 from transformers import AutoTokenizer, AutoModelForCausalLM
 device = "cuda" if torch.cuda.is_available() else "cpu"
 BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
@@ -26,18 +25,19 @@ def generate(prompt):
     print(model, tokenizer)
     params = {
         "model": "vicuna-13b",
-        "prompt": prompt,
+        "prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
         "temperature": 0.7,
         "max_new_tokens": 512,
         "stop": "###"
     }
-    output = generate_stream(
+    for output in generate_stream(
         model, tokenizer, params, device, context_len=2048, stream_interval=2):
+        ret = {
+            "text": output,
+            "error_code": 0
+        }
+
+        yield json.dumps(ret).decode() + b"\0"
-
-    for chunk in output.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            yield chunk
-
 
 if __name__ == "__main__":
     with gr.Blocks() as demo:
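This hunk prepends a Chinese instruction to the user question and reworks the streaming loop: instead of calling generate_stream() once and iterating a response object's byte chunks, the new code consumes the generator directly and wraps each partial output in a small JSON envelope. Below is a rough sketch of the resulting generate(), assembled from the new (+) side of the diff; the model, tokenizer, and device setup outside the hunk is assumed from the surrounding context lines, and the loading calls shown are illustrative rather than taken from this commit.

import json

import torch
from fastchat.serve.inference import generate_stream
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"

# Assumed setup, not part of this hunk: load Vicuna from the local checkpoint.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(BASE_MODE, low_cpu_mem_usage=True).to(device)


def generate(prompt):
    # The added prefix translates roughly to: "This is a conversation between a user
    # and an assistant. The assistant is an expert in the database domain and gives
    # very professional answers to database questions. The user's question follows:"
    params = {
        "model": "vicuna-13b",
        "prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###",
    }
    # generate_stream is now consumed directly as a generator, one partial output at a time.
    for output in generate_stream(
            model, tokenizer, params, device, context_len=2048, stream_interval=2):
        ret = {"text": output, "error_code": 0}
        # The committed line reads json.dumps(ret).decode() + b"\0", which would raise
        # because json.dumps() returns str; .encode() is presumably what was intended.
        yield json.dumps(ret).encode() + b"\0"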
@@ -50,7 +50,7 @@ if __name__ == "__main__":
 
         text_button.click(generate, inputs=text_input, outputs=text_output)
 
-    demo.queue(concurrency_count=3).launch()
+    demo.queue(concurrency_count=3).launch(host="0.0.0.0")
 
 
 
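The last hunk only changes the launch call so the demo listens on all network interfaces instead of localhost. A minimal sketch of the entry point follows, reusing generate() from the sketch above; the widget definitions are assumptions (the diff context only shows the click wiring), the bind address is normally passed to Gradio's launch() as server_name rather than host, and concurrency_count on queue() is a Gradio 3-era option taken verbatim from the commit.

import gradio as gr

if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Assumed layout; only the click() wiring appears in the diff context.
        text_input = gr.Textbox(label="Question")
        text_output = gr.Textbox(label="Answer")
        text_button = gr.Button("Submit")

        text_button.click(generate, inputs=text_input, outputs=text_output)

    # Bind to all interfaces so the demo is reachable from other machines.
    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")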