fix load model gpu oom

commit eca14bc038
parent ca29dacc37
Author: csunny
Date: 2023-04-29 23:02:13 +08:00


@@ -7,6 +7,8 @@ from transformers import (
     AutoModelForCausalLM,
 )
+from fastchat.serve.compression import compress_module
+
 
 class ModerLoader:
 
     kwargs = {}
@@ -29,6 +31,9 @@ class ModerLoader:
 
         if debug:
             print(model)
 
+        if load_8bit:
+            compress_module(model, self.device)
+
         # if self.device == "cuda":
         #     model.to(self.device)
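
For context, a minimal usage sketch of the patched loader path. The class name ModerLoader and the loader(load_8bit, debug) parameters come from the diff; the module path pilot.model.loader, the constructor signature, and the checkpoint path are assumptions for illustration, not part of this commit.

# Hypothetical usage -- module path and model path are placeholders.
from pilot.model.loader import ModerLoader

loader = ModerLoader("/path/to/model")  # placeholder checkpoint path

# With load_8bit=True, fastchat's compress_module rewrites the model's
# Linear layers to 8-bit in place, so the full fp16 weights never have
# to be moved onto the GPU at once -- this is what avoids the OOM that
# the now-commented-out model.to(self.device) call could trigger.
tokenizer, model = loader.loader(load_8bit=True, debug=False)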