From eca14bc0389438e9d60eb8b36a05ed929a7430c3 Mon Sep 17 00:00:00 2001 From: csunny Date: Sat, 29 Apr 2023 23:02:13 +0800 Subject: [PATCH] fix load model gpu oom --- pilot/model/loader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 7b78ebe8c..5f18a023c 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -7,6 +7,8 @@ from transformers import ( AutoModelForCausalLM, ) +from fastchat.serve.compression import compress_module + class ModerLoader: kwargs = {} @@ -29,6 +31,9 @@ class ModerLoader: if debug: print(model) + if load_8bit: + compress_module(model, self.device) + # if self.device == "cuda": # model.to(self.device)