From 0f7ed8c1925d51511c19943a51c0eedc6e1ff73a Mon Sep 17 00:00:00 2001 From: ver217 Date: Sun, 24 Apr 2022 14:16:50 +0800 Subject: [PATCH] fix _post_init_method of zero init ctx (#847) --- colossalai/zero/init_ctx/init_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index 22418f33d..8c125db29 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -155,7 +155,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): torch.set_rng_state(self.cpu_rng_state) torch.cuda.set_rng_state(self.cuda_rng_state) - def _post_init_method(self, module: torch.nn.Module): + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): """ The function to call at the end of the constructor of each module. NOTE() The module may be passed to this function multiple times.