diff --git a/colossalai/nn/module_utils.py b/colossalai/nn/module_utils.py index 64d3f8075..158bd1062 100644 --- a/colossalai/nn/module_utils.py +++ b/colossalai/nn/module_utils.py @@ -14,7 +14,7 @@ def register_colo_module(module_type: type, colo_module: ColoModule): def is_colo_module(module: torch.nn.Module): global _COLOSSAL_MODULES for module_type in _COLOSSAL_MODULES.keys(): - if isinstance(type(module), module_type): + if isinstance(module, module_type): return True return False @@ -23,7 +23,7 @@ def get_colo_module(module: torch.nn.Module): global _COLOSSAL_MODULES if is_colo_module(module): for module_type, colo_module in _COLOSSAL_MODULES.items(): - if isinstance(type(module), module_type): + if isinstance(module, module_type): return colo_module else: return None