[Tensor] add hybrid device demo and fix bugs (#1059)

This commit is contained in:
Ziyue Jiang
2022-06-03 12:09:49 +08:00
committed by GitHub
parent b167258b6a
commit df9dcbbff6
5 changed files with 94 additions and 8 deletions

View File

@@ -51,11 +51,17 @@ class ColoDDP(torch.nn.Module):
free_storage(empty_grad)
if self.dp_world_size > 1:
grad = grad / self.dp_world_size
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
dist.all_reduce(grad, group=gpc.get_group(ParallelMode.DATA))
if grad.device.type != "cpu":
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
group = gpc.get_group(ParallelMode.DATA)
dist.all_reduce(grad, group=group)
ColoDDP._save_grad(p, grad)
grad.record_stream(self.comm_stream)
else:
group = gpc.get_cpu_group(ParallelMode.DATA)
dist.all_reduce(grad, group=group)
ColoDDP._save_grad(p, grad)
grad.record_stream(self.comm_stream)
else:
ColoDDP._save_grad(p, grad)
return empty_grad

View File

@@ -12,13 +12,17 @@ def register_colo_module(module_type: type, colo_module: ColoModule):
def is_colo_module(module: torch.nn.Module):
    """Return True if *module*'s class (or any of its base classes) has a
    registered ColoModule handler in the global registry.

    Uses ``issubclass`` rather than an exact-type lookup so subclasses of a
    registered module type are also recognized.
    """
    global _COLOSSAL_MODULES
    for module_type in _COLOSSAL_MODULES.keys():
        # NOTE: ``isinstance(type(module), module_type)`` (the original code)
        # checks whether the *type object* is an instance of module_type,
        # which only holds for metaclasses — it never matched ordinary
        # registered classes. Subclass matching is what is intended here.
        if issubclass(type(module), module_type):
            return True
    return False
def get_colo_module(module: torch.nn.Module):
    """Return the ColoModule handler registered for *module*'s class.

    The registry is scanned in insertion order; the first entry whose
    registered type is a base class of ``type(module)`` wins. Returns
    ``None`` when no registered type matches.
    """
    global _COLOSSAL_MODULES
    for module_type, colo_module in _COLOSSAL_MODULES.items():
        # issubclass (not ``isinstance(type(module), module_type)``) so that
        # subclasses of a registered module type resolve to the same handler.
        if issubclass(type(module), module_type):
            return colo_module
    return None

View File

@@ -92,4 +92,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
setattr(submodule, param_name, colo_param)
colo_param.shared_param_modules.append(submodule)
module.to(self._device)
ColoModulize(module)