From f0aa191f51704806e65ac849da137069bb35a6d5 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 13 Feb 2023 17:53:15 +0800
Subject: [PATCH] [gemini] fix colo_init_context (#2683)

---
 colossalai/utils/model/colo_init_context.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index ab354ea70..87ae413a2 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -32,7 +32,7 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
                           default_pg: Optional[ProcessGroup] = None,
                           default_dist_spec: Optional[Any] = None) -> ColoParameter:
 
-    if isinstance(param, ColoParameter):
+    if type(param) is ColoParameter:
         return param
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
@@ -102,7 +102,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         name_list = []
         for name, param in _named_params_with_replica(module):
-            if isinstance(param, ColoTensor):
+            if type(param) is ColoParameter:
                 continue
 
             split = name.rfind('.')