diff --git a/environment.yml b/environment.yml
index 81872a557..3ec4dfd98 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,4 +1,4 @@
-name: db-pgt
+name: db_pgt
 channels:
   - pytorch
   - defaults
diff --git a/pilot/server/sqlgpt.py b/pilot/server/sqlgpt.py
index edd2baf84..81f9b22dd 100644
--- a/pilot/server/sqlgpt.py
+++ b/pilot/server/sqlgpt.py
@@ -12,18 +12,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
+model = AutoModelForCausalLM.from_pretrained(
+    BASE_MODE,
+    low_cpu_mem_usage=True,
+    torch_dtype=torch.float16,
+    device_map="auto",
+)
+
 def generate(prompt):
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
-    model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODE,
-        low_cpu_mem_usage=True,
-        torch_dtype=torch.float16,
-        device_map="auto",
-    )
     # compress_module(model, device)
     # model.to(device)
     print(model, tokenizer)
-
     params = {
         "model": "vicuna-13b",
         "prompt": prompt,
@@ -32,9 +32,12 @@ def generate(prompt):
         "stop": "###"
     }
     output = generate_stream(
         model, tokenizer, params, device, context_len=2048, stream_interval=2)
-    yield output
+    # Stream partial results from generate_stream as they are produced
+    for chunk in output:
+        if chunk:
+            yield chunk
 
 
 if __name__ == "__main__":
     with gr.Blocks() as demo: