Add: multi model support
@@ -2,21 +2,19 @@
# -*- coding: utf-8 -*-

import torch
import warnings
from pilot.singleton import Singleton

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel
)

from pilot.model.compression import compress_module
from pilot.model.adapter import get_llm_model_adapter


class ModelLoader(metaclass=Singleton):
    """Model loader is a class for loading models.

    Args: model_path

    TODO: multi model support.
    """

    kwargs = {}
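The new `from pilot.model.adapter import get_llm_model_adapter` import carries the multi model support: instead of hard-coding one loading path per model family, the loader asks an adapter matched to the model path to build the tokenizer and model. The adapter module itself is not part of this diff, so the following is only a minimal sketch of an adapter registry consistent with the two calls that do appear, `get_llm_model_adapter(model_path)` and `adapter.loader(model_path, kwargs)` returning `(model, tokenizer)`; every other name below is an illustrative assumption.

# Sketch only: an adapter registry consistent with the calls in this diff.
# Only get_llm_model_adapter(model_path) and adapter.loader(model_path, kwargs)
# are taken from the diff; the remaining names are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer


class BaseLLMAdapter:
    """Fallback adapter: plain Hugging Face causal-LM loading."""

    def match(self, model_path: str) -> bool:
        # The fallback accepts every model path.
        return True

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return model, tokenizer


llm_model_adapters = []


def register_llm_model_adapter(cls):
    # Instantiate once and keep the instance around for path matching.
    llm_model_adapters.append(cls())


def get_llm_model_adapter(model_path: str) -> BaseLLMAdapter:
    # Return the first registered adapter whose match() accepts the path.
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")


register_llm_model_adapter(BaseLLMAdapter)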
@@ -31,9 +29,11 @@ class ModelLoader(metaclass=Singleton):
        "device_map": "auto",
    }

    # TODO multi gpu support
    def loader(self, num_gpus, load_8bit=False, debug=False):
        if self.device == "cpu":
            kwargs = {}

        elif self.device == "cuda":
            kwargs = {"torch_dtype": torch.float16}
            if num_gpus == "auto":
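The branch above selects the keyword arguments later passed to `from_pretrained`. For clarity, the same device selection can be written as a standalone helper; the helper name is illustrative, the `num_gpus == "auto"` branch is assumed to simply set `device_map="auto"` (the hunk boundary cuts it off), and the 13GiB per-GPU cap mirrors the value in the next hunk rather than a tuned setting.

import torch


def build_from_pretrained_kwargs(device: str, num_gpus):
    """Sketch of the kwargs selection performed inside ModelLoader.loader()."""
    if device == "cpu":
        return {}
    if device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if num_gpus == "auto":
            # Assumed: let accelerate spread layers over all visible GPUs.
            kwargs["device_map"] = "auto"
        else:
            num_gpus = int(num_gpus)
            if num_gpus != 1:
                kwargs.update({
                    "device_map": "auto",
                    # Example cap, mirroring the 13GiB value used in the diff.
                    "max_memory": {i: "13GiB" for i in range(num_gpus)},
                })
        return kwargs
    # mps is still a TODO in the original code.
    raise ValueError(f"Invalid device: {device}")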
@@ -46,18 +46,20 @@ class ModelLoader(metaclass=Singleton):
                    "max_memory": {i: "13GiB" for i in range(num_gpus)},
                })
        else:
            # TODO: support mps for practice
            raise ValueError(f"Invalid device: {self.device}")

        if "chatglm" in self.model_path:
            tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
        else:
            tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(self.model_path,
                                                         low_cpu_mem_usage=True, **kwargs)

        llm_adapter = get_llm_model_adapter(self.model_path)
        model, tokenizer = llm_adapter.loader(self.model_path, kwargs)

        if load_8bit:
            compress_module(model, self.device)
            if num_gpus != 1:
                warnings.warn(
                    "8-bit quantization is not supported for multi-gpu inference"
                )
            else:
                compress_module(model, self.device)

        if (self.device == "cuda" and num_gpus == 1):
            model.to(self.device)
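After the adapter call, the loader only keeps device placement and optional 8-bit compression for itself. Neither `ModelLoader.__init__` nor the method's return statement is shown in this diff, so the usage sketch below assumes a `model_path` constructor argument and a `(model, tokenizer)` return value.

# Usage sketch only. The constructor arguments and the return value are
# assumptions; the diff shows neither __init__ nor a return statement.
loader = ModelLoader(model_path="/data/models/vicuna-13b-v1.1")
model, tokenizer = loader.loader(num_gpus=1, load_8bit=True, debug=False)

# With load_8bit=True on a single GPU the model is compressed in place by
# compress_module(); with num_gpus != 1 the code only emits a warning.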