[Tensor] add hybrid device demo and fix bugs (#1059)

This commit is contained in:
Ziyue Jiang
2022-06-03 12:09:49 +08:00
committed by GitHub
parent b167258b6a
commit df9dcbbff6
5 changed files with 94 additions and 8 deletions

View File

@@ -12,13 +12,17 @@ def register_colo_module(module_type: type, colo_module: ColoModule):
def is_colo_module(module: torch.nn.Module):
global _COLOSSAL_MODULES
return type(module) in _COLOSSAL_MODULES
for module_type in _COLOSSAL_MODULES.keys():
if isinstance(type(module), module_type):
return True
return False
def get_colo_module(module: torch.nn.Module):
global _COLOSSAL_MODULES
if is_colo_module(module):
colo_module = _COLOSSAL_MODULES[type(module)]
return colo_module
for module_type, colo_module in _COLOSSAL_MODULES.items():
if isinstance(type(module), module_type):
return colo_module
else:
return None