[hotfix] add deconstructor for stateful tensor (#848)

* add deconstructor for stateful tensor

* fix colo init context
This commit is contained in:
ver217
2022-04-24 15:03:04 +08:00
committed by GitHub
parent 0f7ed8c192
commit 0dea140760
3 changed files with 20 additions and 12 deletions

View File

@@ -12,7 +12,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
super().__init__()
self._lazy_memory_allocate = lazy_memory_allocate
def _post_init_method(self, module: torch.nn.Module):
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
The function to call at the end of the constructor of each module.
FIXME(fjr) The module may be passed to this function multiple times?