This commit is contained in:
csunny
2023-04-29 01:26:19 +08:00
parent 2ff4d71fdd
commit 75181d6f2f
2 changed files with 14 additions and 11 deletions

View File

@@ -1,4 +1,4 @@
name: db-pgt
name: db_pgt
channels:
- pytorch
- defaults

View File

@@ -12,18 +12,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
BASE_MODE,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
device_map="auto",
)
def generate(prompt):
tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
BASE_MODE,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
device_map="auto",
)
# compress_module(model, device)
# model.to(device)
print(model, tokenizer)
params = {
"model": "vicuna-13b",
"prompt": prompt,
@@ -32,9 +32,12 @@ def generate(prompt):
"stop": "###"
}
output = generate_stream(
model, tokenizer, params, device, context_len=2048, stream_interval=2)
model, tokenizer, params, device, context_len=2048, stream_interval=2):
yield output
for chunk in output.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield chunk
if __name__ == "__main__":
with gr.Blocks() as demo: