DB-GPT/pilot/model/loader.py
2023-04-29 23:02:13 +08:00

43 lines
994 B
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
)
from fastchat.serve.compression import compress_module
class ModerLoader:
    """Load a causal language model and its tokenizer from a path or hub id.

    NOTE(review): the class name looks like a typo for ``ModelLoader``, but it
    is the public name callers import, so it is kept unchanged.
    """

    def __init__(self, model_path) -> None:
        """Record the model location and build the ``from_pretrained`` kwargs.

        Args:
            model_path: local directory or Hugging Face hub id of the model.
        """
        # Prefer GPU when available; also passed to fastchat's
        # compress_module so 8-bit compression targets the right device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        # Half precision plus automatic device placement.
        # NOTE(review): device_map="auto" requires the `accelerate` package
        # at load time — confirm it is a declared dependency.
        self.kwargs = {
            "torch_dtype": torch.float16,
            "device_map": "auto",
        }

    def loader(self, load_8bit=False, debug=False):
        """Instantiate the model and tokenizer.

        Args:
            load_8bit: if True, compress the loaded weights in place with
                fastchat's ``compress_module`` to cut memory usage.
            debug: if True, print the model architecture to stdout.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_path, low_cpu_mem_usage=True, **self.kwargs
        )
        if debug:
            print(model)
        if load_8bit:
            compress_module(model, self.device)
        return model, tokenizer