diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py index 079faaace..9f6c9c1cc 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py @@ -16,7 +16,10 @@ import torch def unwrap(model): - return model.unwrap().module + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model def neftune_post_forward_hook(module, input, output):