llms: fix

This commit is contained in:
csunny 2023-05-21 14:54:16 +08:00
parent ce72820085
commit f52c7523b5

View File

@@ -59,8 +59,6 @@ class ModelLoader(metaclass=Singleton):
# TODO multi gpu support # TODO multi gpu support
def loader(self, num_gpus, load_8bit=False, debug=False, cpu_offloading=False, max_gpu_memory: Optional[str]=None): def loader(self, num_gpus, load_8bit=False, debug=False, cpu_offloading=False, max_gpu_memory: Optional[str]=None):
cpu_offloading(self.device, load_8bit, cpu_offloading)
if self.device == "cpu": if self.device == "cpu":
kwargs = {"torch_dtype": torch.float32} kwargs = {"torch_dtype": torch.float32}