[Tensor] add hybrid device demo and fix bugs (#1059)

This commit is contained in:
Ziyue Jiang
2022-06-03 12:09:49 +08:00
committed by GitHub
parent b167258b6a
commit df9dcbbff6
5 changed files with 94 additions and 8 deletions

View File

@@ -51,11 +51,17 @@ class ColoDDP(torch.nn.Module):
free_storage(empty_grad)
if self.dp_world_size > 1:
grad = grad / self.dp_world_size
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
dist.all_reduce(grad, group=gpc.get_group(ParallelMode.DATA))
if grad.device.type != "cpu":
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
group = gpc.get_group(ParallelMode.DATA)
dist.all_reduce(grad, group=group)
ColoDDP._save_grad(p, grad)
grad.record_stream(self.comm_stream)
else:
group = gpc.get_cpu_group(ParallelMode.DATA)
dist.all_reduce(grad, group=group)
ColoDDP._save_grad(p, grad)
grad.record_stream(self.comm_stream)
else:
ColoDDP._save_grad(p, grad)
return empty_grad