mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 02:51:07 +00:00
update
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
name: db-pgt
|
name: db_pgt
|
||||||
channels:
|
channels:
|
||||||
- pytorch
|
- pytorch
|
||||||
- defaults
|
- defaults
|
||||||
|
@@ -12,18 +12,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
|
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
BASE_MODE,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
|
||||||
def generate(prompt):
|
def generate(prompt):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODE, use_fast=False)
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
BASE_MODE,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
device_map="auto",
|
|
||||||
)
|
|
||||||
# compress_module(model, device)
|
# compress_module(model, device)
|
||||||
# model.to(device)
|
# model.to(device)
|
||||||
print(model, tokenizer)
|
print(model, tokenizer)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"model": "vicuna-13b",
|
"model": "vicuna-13b",
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
@@ -32,9 +32,12 @@ def generate(prompt):
|
|||||||
"stop": "###"
|
"stop": "###"
|
||||||
}
|
}
|
||||||
output = generate_stream(
|
output = generate_stream(
|
||||||
model, tokenizer, params, device, context_len=2048, stream_interval=2)
|
model, tokenizer, params, device, context_len=2048, stream_interval=2):
|
||||||
|
|
||||||
yield output
|
|
||||||
|
for chunk in output.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||||
|
if chunk:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
|
Reference in New Issue
Block a user