[Tensor] fix optimizer for CPU parallel (#1069)

Ziyue Jiang authored on 2022-06-06 17:36:11 +08:00, committed by GitHub
parent 49832b2344
commit 4fc748f69b
2 changed files with 16 additions and 12 deletions


@@ -42,14 +42,15 @@ class ColoDDP(torch.nn.Module):
         loss.backward()
         torch.cuda.current_stream().wait_stream(self.comm_stream)
         for p in self.module.parameters():
-            p.grad = p._saved_grad
+            if p.grad.device.type != "cpu":
+                p.grad = p._saved_grad

     def grad_handle(self, p, grad):
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
-        if self.dp_world_size > 1:
-            grad = grad / self.dp_world_size
-            if grad.device.type != "cpu":
+        if grad.device.type != "cpu":
+            empty_grad = torch.empty_like(grad)
+            free_storage(empty_grad)
+            if self.dp_world_size > 1:
+                grad = grad / self.dp_world_size
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
                     group = gpc.get_group(ParallelMode.DATA)
@@ -57,12 +58,13 @@ class ColoDDP(torch.nn.Module):
                     ColoDDP._save_grad(p, grad)
                 grad.record_stream(self.comm_stream)
             else:
-                group = gpc.get_cpu_group(ParallelMode.DATA)
-                dist.all_reduce(grad, group=group)
                 ColoDDP._save_grad(p, grad)
+            return empty_grad
         else:
-            ColoDDP._save_grad(p, grad)
-        return empty_grad
+            group = gpc.get_cpu_group(ParallelMode.DATA)
+            dist.all_reduce(grad, group=group)
+            return grad

     @staticmethod
     def _save_grad(p, grad):
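
ColoDDP registers grad_handle as a per-parameter gradient hook (the registration itself is outside this diff), and whatever the hook returns becomes the gradient the autograd engine keeps for that parameter. That is why the return value changes here: after this fix, a gradient that already lives on the CPU is all-reduced over the CPU group (gpc.get_cpu_group(ParallelMode.DATA)) and returned from the hook directly, while the CUDA path keeps the save-to-_saved_grad / restore-in-backward() flow behind the new device check. The snippet below is a minimal single-process sketch of that hook mechanism only, using a plain CPU parameter and a stand-in handler instead of ColossalAI's process groups; names like grad_handle and saved are illustrative, not library code.

# Minimal single-process sketch of the gradient-hook mechanism relied on above.
# A hook registered with torch.Tensor.register_hook may return a tensor, and
# that returned tensor replaces the gradient the autograd engine stores in
# p.grad -- so returning a freed placeholder hides the real gradient from the
# optimizer, while returning the (all-reduced) gradient exposes it.
# grad_handle and saved below are stand-ins, not ColossalAI code.
import torch

p = torch.nn.Parameter(torch.ones(4))      # a CPU parameter, as in the fixed path

saved = {}

def grad_handle(grad):
    # pretend this gradient was just all-reduced over the CPU process group
    saved["grad"] = grad.clone()
    # post-fix CPU path: return the real gradient so the engine keeps it
    return saved["grad"]

p.register_hook(grad_handle)

loss = (p * 2).sum()
loss.backward()

# The optimizer would now read p.grad, which matches what the hook returned.
print(torch.equal(p.grad, saved["grad"]))  # True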