[zero] update sharded optim v2 (#334)

Author: ver217
Date: 2022-03-09 16:09:36 +08:00
Committed by: Frank Lee
Parent: 2b8cddd40e
Commit: d0ae0f2215
5 changed files with 115 additions and 68 deletions


@@ -102,6 +102,11 @@ class ShardedModelV2(nn.Module):
         # Wait for the non-blocking GPU -> CPU grad transfers to finish.
         torch.cuda.current_stream().synchronize()
         self.reducer.free()
+        # In case some post-backward hook was not fired
+        if self.shard_param:
+            for p in self.module.parameters():
+                if not p.col_attr.param_is_sharded:
+                    self.shard_strategy.shard([p.col_attr.data])
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
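The block added above is a safety net: a parameter that never takes part in the backward pass never fires its post-backward hook, so it is still unsharded when the rest of the step expects shards. Below is a minimal standalone sketch of the same idea, assuming hypothetical `_ColAttr` and `_FlattenShardStrategy` stand-ins (not the ColossalAI API).

import torch
import torch.nn as nn

class _ColAttr:
    # Hypothetical stand-in for the per-parameter col_attr used above.
    def __init__(self, data: torch.Tensor):
        self.data = data
        self.param_is_sharded = False   # hook never fired for this param

class _FlattenShardStrategy:
    # Hypothetical strategy: "sharding" here just flattens the tensor.
    def shard(self, tensors):
        for t in tensors:
            t.data = t.data.flatten()

def reshard_leftover_params(module: nn.Module, strategy) -> None:
    # Mirror of the hunk above: a parameter whose post-backward hook never
    # fired (e.g. it was unused this step) is still unsharded, so shard it
    # before gradients are written back.
    for p in module.parameters():
        if not p.col_attr.param_is_sharded:
            strategy.shard([p.col_attr.data])
            p.col_attr.param_is_sharded = True

# Usage: attach the attribute, then run the safety net after backward.
linear = nn.Linear(4, 4)
for p in linear.parameters():
    p.col_attr = _ColAttr(p.data)
reshard_leftover_params(linear, _FlattenShardStrategy())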
@@ -113,13 +118,12 @@ class ShardedModelV2(nn.Module):
             if not self._require_backward_grad_sync:
                 continue
             # Write grad back to p.grad and set p.col_attr.grad to None
-            p.grad.data = p.col_attr.grad
+            # We must make sure grad and param have the same shape.
+            # If world size > 1 and the param is sharded, `.view()` may not be needed.
+            # If world size == 1 and the param is sharded, `data` is a flattened
+            # tensor, but the shape of `grad` is that of the unsharded param.
+            p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape)
             p.col_attr.grad = None
-        # In case some post bwd hook is not fired
-        if self.shard_param:
-            for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.data])
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
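Why the `.view()` matters: the stored grad buffer can be kept flattened, while optimizers expect `p.grad` to carry the parameter's shape. A simplified plain-PyTorch sketch of the mismatch, with hypothetical tensors standing in for `p.col_attr.grad` and the target shape:

import torch

# Hypothetical stand-ins: a 2-D parameter and the flattened grad buffer
# kept on the side when world size == 1.
param = torch.nn.Parameter(torch.randn(4, 8))
flat_grad = torch.randn(param.numel())        # shape (32,)

# Writing the flat tensor into p.grad directly would leave p.grad with a
# shape different from the parameter; viewing it first restores the match
# that optimizers such as torch.optim.SGD rely on.
param.grad = flat_grad.view(param.shape)
assert param.grad.shape == param.shape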
@@ -180,7 +184,11 @@ class ShardedModelV2(nn.Module):
         if param.col_attr.grad is None:
             param.col_attr.grad = reduced_grad.data
         else:
-            param.col_attr.grad.add_(reduced_grad.data)
+            # When dp size = 1, param.col_attr.grad is the locally
+            # accumulated grad shard (full but flattened), while
+            # reduced_grad here is the full grad.
+            # We must call `view_as` so the shapes match.
+            param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
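The `view_as` fix in the last hunk is the same shape issue from the other direction: with dp size = 1 the accumulator `param.col_attr.grad` stays flattened while the incoming `reduced_grad` keeps the full parameter shape, and `add_` does not broadcast between the two. A standalone sketch with hypothetical tensors:

import torch

# Hypothetical stand-ins for the two tensors in the hunk above.
accum_grad = torch.zeros(32)        # flat local accumulator, shape (32,)
reduced_grad = torch.randn(4, 8)    # freshly reduced full grad, shape (4, 8)

# accum_grad.add_(reduced_grad) would raise: (32,) and (4, 8) do not
# broadcast. Viewing the full grad as the accumulator's flat shape first
# makes the in-place accumulation legal.
accum_grad.add_(reduced_grad.view_as(accum_grad))
assert accum_grad.shape == (32,)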