diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 7b78ebe8c..5f18a023c 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -7,6 +7,8 @@ from transformers import ( AutoModelForCausalLM, ) +from fastchat.serve.compression import compress_module + class ModerLoader: kwargs = {} @@ -29,6 +31,9 @@ class ModerLoader: if debug: print(model) + if load_8bit: + compress_module(model, self.device) + # if self.device == "cuda": # model.to(self.device)