DB-GPT/pilot/server/sqlgpt.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import gradio as gr
from fastchat.serve.inference import generate_stream
from transformers import AutoTokenizer, AutoModelForCausalLM
# Prefer the GPU when available; generate_stream needs an explicit device string.
device = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODE,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map="auto",  # let accelerate place the weights on available devices
)
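
# If the 13B fp16 weights do not fit on the available GPU, transformers can
# also load the model in 8-bit via bitsandbytes. A hedged alternative, not
# what this script uses:
#   model = AutoModelForCausalLM.from_pretrained(
#       BASE_MODE, load_in_8bit=True, device_map="auto")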


def generate(prompt):
    # device_map="auto" already placed the model during loading, so there is
    # no per-call model.to(device) here; calling .to() on an
    # accelerate-dispatched model can raise an error.
    params = {
        "model": "vicuna-13b",
        "prompt": "This is a conversation between a user and an assistant. "
        "The assistant is proficient in database knowledge and can give "
        "highly professional answers to questions in the database domain. "
        "The user's question follows: " + prompt,
        "temperature": 0.7,      # sampling temperature
        "max_new_tokens": 512,   # cap on generated tokens
        "stop": "###",           # stop sequence for Vicuna-style prompts
    }
    # generate_stream yields decoded output incrementally, so the Gradio
    # textbox can update as tokens arrive.
    output = generate_stream(
        model, tokenizer, params, device, context_len=2048, stream_interval=2)
    for chunk in output:
        yield chunk


if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("Database SQL Generation Assistant")
        with gr.Tab("SQL Generation"):
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button("Submit")
            # generate() is a generator, so Gradio streams partial output
            # into text_output as chunks arrive.
            text_button.click(generate, inputs=text_input, outputs=text_output)
    # queue() enables generator streaming; concurrency_count=3 allows up to
    # three requests to be processed concurrently (Gradio 3.x API).
    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
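
For a quick smoke test without the web UI, the generator can be driven directly. A minimal sketch, assuming this file is importable as sqlgpt, the weights exist at the BASE_MODE path, and generate_stream yields the accumulated text so far (as in the fastchat version this script targets); the question string is illustrative:

# smoke_test.py -- hypothetical helper, not part of the original file.
# Importing sqlgpt loads the model; the Gradio UI stays off because the
# launch code is guarded by __main__.
from sqlgpt import generate

question = "Write a SQL statement returning the top 10 users by order count."
output = ""
for chunk in generate(question):
    output = chunk  # each chunk is the text decoded so far
print(output)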