diff --git a/pilot/server/sqlgpt.py b/pilot/server/sqlgpt.py
index 81f9b22dd..52522e6bd 100644
--- a/pilot/server/sqlgpt.py
+++ b/pilot/server/sqlgpt.py
@@ -7,7 +7,6 @@ import torch
 import gradio as gr
 from fastchat.serve.inference import generate_stream, compress_module
-
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
@@ -26,18 +25,19 @@ def generate(prompt):
     print(model, tokenizer)
     params = {
         "model": "vicuna-13b",
-        "prompt": prompt,
+        "prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
         "temperature": 0.7,
         "max_new_tokens": 512,
         "stop": "###"
     }
-    output = generate_stream(
-        model, tokenizer, params, device, context_len=2048, stream_interval=2)
+    for output in generate_stream(
+        model, tokenizer, params, device, context_len=2048, stream_interval=2):
+        ret = {
+            "text": output,
+            "error_code": 0
+        }
-
-    for chunk in output.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            yield chunk
+        yield json.dumps(ret).encode() + b"\0"
 
 
 if __name__ == "__main__":
     with gr.Blocks() as demo:
@@ -50,7 +50,7 @@ if __name__ == "__main__":
 
         text_button.click(generate, inputs=text_input, outputs=text_output)
 
-    demo.queue(concurrency_count=3).launch()
+    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
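
For context, the Chinese prompt prefix added above roughly translates to: "This is a conversation between a user and an assistant; the assistant is well versed in the database domain and can give very professional answers to database questions. The user's question follows:". Below is a minimal sketch (not the committed code) of how the streaming callback could feed the gradio Textbox directly. It assumes that generate_stream, as imported from fastchat.serve.inference above, yields the decoded text produced so far, that model, tokenizer, and device are defined earlier in the file as the imports suggest, and that PROMPT_PREFIX is a hypothetical name holding an English rendering of the prefix; the component names text_input/text_output/text_button are taken from the click call in the diff.

# Hedged sketch, not the committed code. Relies on generate_stream, model,
# tokenizer, and device from the surrounding module, as shown in the diff.
PROMPT_PREFIX = (
    "This is a conversation between a user and an assistant; the assistant is "
    "well versed in the database domain and can give very professional answers "
    "to database questions. The user's question follows: "
)  # hypothetical English rendering of the Chinese prefix in the diff

def generate(prompt):
    params = {
        "model": "vicuna-13b",
        "prompt": PROMPT_PREFIX + prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###",
    }
    # generate_stream is a generator; yielding each partial string lets gradio
    # (with demo.queue() enabled) update the output Textbox incrementally,
    # instead of emitting the JSON-plus-null-byte framing used by model workers.
    for output in generate_stream(model, tokenizer, params, device,
                                  context_len=2048, stream_interval=2):
        yield output

With a generator callback like this, demo.queue(concurrency_count=3).launch(server_name="0.0.0.0") streams partial answers to the browser as they are produced; note that server_name, not host, is the gradio keyword for choosing the bind address.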