diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index 22418f33d..8c125db29 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -155,7 +155,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): torch.set_rng_state(self.cpu_rng_state) torch.cuda.set_rng_state(self.cuda_rng_state) - def _post_init_method(self, module: torch.nn.Module): + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): """ The function to call at the end of the constructor of each module. NOTE() The module may be passed to this function multiple times.