mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[hotfix] fix init context (#1543)
* fix init context * fix lazy init ctx
This commit is contained in:
@@ -3,10 +3,12 @@ import functools
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def substitute_init_recursively(cls, func):
|
||||
def substitute_init_recursively(cls, func, visited: set):
|
||||
for subcls in cls.__subclasses__():
|
||||
substitute_init_recursively(subcls, func)
|
||||
func(subcls)
|
||||
substitute_init_recursively(subcls, func, visited)
|
||||
if subcls not in visited:
|
||||
func(subcls)
|
||||
visited.add(subcls)
|
||||
|
||||
|
||||
def call_to_str(base, *args, **kwargs):
|
||||
@@ -64,7 +66,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
||||
|
||||
# Replace .__init__() for all existing subclasses of torch.nn.Module
|
||||
# Excution self._post_init_method after the default init function.
|
||||
substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
|
||||
substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set())
|
||||
|
||||
# holding on to the current __init__subclass__ for exit
|
||||
torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
|
||||
@@ -87,7 +89,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
||||
cls.__init__ = cls._old_init
|
||||
|
||||
# Replace .__init__() for all existing subclasses of torch.nn.Module
|
||||
substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
|
||||
substitute_init_recursively(torch.nn.modules.module.Module, _disable_class, set())
|
||||
|
||||
# Replace .__init__() for future subclasses of torch.nn.Module
|
||||
torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)
|
||||
|
Reference in New Issue
Block a user