From 73ad05fc8ccf4551f475420a44656ba91268abc5 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 20 Jun 2022 11:24:27 +0800 Subject: [PATCH] [zero] added error message to handle on-the-fly import of torch Module class (#1135) * [zero] added error message to handle on-the-fly import of torch Module class * polish code --- colossalai/utils/model/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py index 0b0b73820..ecc0cdb5a 100644 --- a/colossalai/utils/model/utils.py +++ b/colossalai/utils/model/utils.py @@ -80,6 +80,10 @@ class InsertPostInitMethodToModuleSubClasses(object): torch.set_default_dtype(self._old_default_dtype) def _disable_class(cls): + if not hasattr(cls, '_old_init'): + raise AttributeError( + f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context." + ) cls.__init__ = cls._old_init # Replace .__init__() for all existing subclasses of torch.nn.Module