Mirror of https://github.com/hpcaitech/ColossalAI.git
[Tensor] fix optimizer for CPU parallel (#1069)
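With this change, grad_handle branches on grad.device.type up front: CUDA gradients keep the asynchronous comm-stream all-reduce and the storage-freed placeholder returned to autograd, while CPU gradients are all-reduced synchronously over the CPU process group and returned directly, so p.grad is already valid for the optimizer. backward() accordingly restores p._saved_grad only for non-CPU gradients.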
@@ -42,14 +42,15 @@ class ColoDDP(torch.nn.Module):
         loss.backward()
         torch.cuda.current_stream().wait_stream(self.comm_stream)
         for p in self.module.parameters():
-            p.grad = p._saved_grad
+            if p.grad.device.type != "cpu":
+                p.grad = p._saved_grad
 
     def grad_handle(self, p, grad):
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
-        if self.dp_world_size > 1:
-            grad = grad / self.dp_world_size
-            if grad.device.type != "cpu":
+        if grad.device.type != "cpu":
+            empty_grad = torch.empty_like(grad)
+            free_storage(empty_grad)
+            if self.dp_world_size > 1:
+                grad = grad / self.dp_world_size
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
                     group = gpc.get_group(ParallelMode.DATA)
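Moving the device guard to the top of grad_handle means the placeholder setup (empty_like plus free_storage) now runs only for CUDA gradients; the second hunk below reroutes CPU gradients through the CPU process group instead.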
@@ -57,12 +58,13 @@ class ColoDDP(torch.nn.Module):
                     ColoDDP._save_grad(p, grad)
                 grad.record_stream(self.comm_stream)
-            else:
-                group = gpc.get_cpu_group(ParallelMode.DATA)
-                dist.all_reduce(grad, group=group)
-                ColoDDP._save_grad(p, grad)
-        return empty_grad
+            else:
+                ColoDDP._save_grad(p, grad)
+            return empty_grad
+        group = gpc.get_cpu_group(ParallelMode.DATA)
+        dist.all_reduce(grad, group=group)
+        return grad
 
     @staticmethod
     def _save_grad(p, grad):
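For readers outside the codebase, here is a minimal self-contained sketch of the fixed hook logic using plain torch.distributed. The names make_grad_hook, cuda_group, cpu_group, and world_size, the simplified free_storage, and the direct p._saved_grad assignment are illustrative assumptions, not the library's API; in ColoDDP the groups come from gpc.get_group / gpc.get_cpu_group and saving goes through ColoDDP._save_grad.

    import torch
    import torch.distributed as dist

    def free_storage(t: torch.Tensor) -> None:
        # Mimics colossalai's helper: keep shape/dtype metadata, drop the payload.
        t.storage().resize_(0)

    def make_grad_hook(p: torch.nn.Parameter,
                       comm_stream: torch.cuda.Stream,
                       cuda_group: dist.ProcessGroup,
                       cpu_group: dist.ProcessGroup,
                       world_size: int):
        # Returns a hook suitable for p.register_hook(...); autograd replaces
        # the parameter's gradient with whatever the hook returns.
        def grad_handle(grad: torch.Tensor) -> torch.Tensor:
            if grad.device.type != "cpu":
                # CUDA path: reduce on a side stream, stash the real gradient
                # on the parameter, and give autograd a storage-freed placeholder.
                empty_grad = torch.empty_like(grad)
                free_storage(empty_grad)
                if world_size > 1:
                    grad = grad / world_size
                    comm_stream.wait_stream(torch.cuda.current_stream())
                    with torch.cuda.stream(comm_stream):
                        dist.all_reduce(grad, group=cuda_group)  # e.g. NCCL group
                        p._saved_grad = grad  # stands in for ColoDDP._save_grad
                    grad.record_stream(comm_stream)
                else:
                    p._saved_grad = grad
                return empty_grad
            # CPU path (the fix): reduce synchronously over the CPU (e.g. Gloo)
            # group and return the gradient itself, so p.grad is already
            # correct for the optimizer after backward().
            dist.all_reduce(grad, group=cpu_group)
            return grad
        return grad_handle

A hook like this is attached with p.register_hook(...); after backward(), the caller waits on comm_stream and copies p._saved_grad back into p.grad for non-CPU gradients, which is what the guarded loop in the first hunk does.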