mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
[Tensor] add hybrid device demo and fix bugs (#1059)
This commit is contained in:
@@ -51,11 +51,17 @@ class ColoDDP(torch.nn.Module):
|
||||
free_storage(empty_grad)
|
||||
if self.dp_world_size > 1:
|
||||
grad = grad / self.dp_world_size
|
||||
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.comm_stream):
|
||||
dist.all_reduce(grad, group=gpc.get_group(ParallelMode.DATA))
|
||||
if grad.device.type != "cpu":
|
||||
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.comm_stream):
|
||||
group = gpc.get_group(ParallelMode.DATA)
|
||||
dist.all_reduce(grad, group=group)
|
||||
ColoDDP._save_grad(p, grad)
|
||||
grad.record_stream(self.comm_stream)
|
||||
else:
|
||||
group = gpc.get_cpu_group(ParallelMode.DATA)
|
||||
dist.all_reduce(grad, group=group)
|
||||
ColoDDP._save_grad(p, grad)
|
||||
grad.record_stream(self.comm_stream)
|
||||
else:
|
||||
ColoDDP._save_grad(p, grad)
|
||||
return empty_grad
|
||||
|
@@ -12,13 +12,17 @@ def register_colo_module(module_type: type, colo_module: ColoModule):
|
||||
|
||||
def is_colo_module(module: torch.nn.Module):
    """Return True if ``module`` is an instance of any registered module type.

    The registry ``_COLOSSAL_MODULES`` is keyed by module class. Matching uses
    ``isinstance`` on the module itself, so instances of subclasses of a
    registered type are recognised as well.

    Args:
        module: the module instance to look up.

    Returns:
        bool: True when some registered type matches ``module``, else False.
    """
    # NOTE: the check must be made on the instance, not on ``type(module)``;
    # ``isinstance(type(module), module_type)`` asks whether the *class object*
    # is an instance of ``module_type`` and is False for ordinary classes.
    return any(isinstance(module, module_type) for module_type in _COLOSSAL_MODULES)
|
||||
|
||||
def get_colo_module(module: torch.nn.Module):
    """Return the ColoModule registered for ``module``'s type, if any.

    The registry ``_COLOSSAL_MODULES`` is scanned in insertion order and the
    entry of the first registered type that ``module`` is an instance of
    (subclasses included) is returned.

    Args:
        module: the module instance to look up.

    Returns:
        The registered colo module for the first matching type, or ``None``
        when no registered type matches.
    """
    for module_type, colo_module in _COLOSSAL_MODULES.items():
        # Check the instance itself: ``isinstance(type(module), module_type)``
        # would test the class object and never match ordinary classes.
        if isinstance(module, module_type):
            return colo_module
    return None
|
||||
|
||||
|
@@ -92,4 +92,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
setattr(submodule, param_name, colo_param)
|
||||
colo_param.shared_param_modules.append(submodule)
|
||||
|
||||
module.to(self._device)
|
||||
ColoModulize(module)
|
||||
|
Reference in New Issue
Block a user