Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-21 14:49:24 +00:00
[zero] update sharded optim v2 (#334)
@@ -102,6 +102,11 @@ class ShardedModelV2(nn.Module):
         # Wait for the non-blocking GPU -> CPU grad transfers to finish.
         torch.cuda.current_stream().synchronize()
         self.reducer.free()
+        # In case some post bwd hook is not fired
+        if self.shard_param:
+            for p in self.module.parameters():
+                if not p.col_attr.param_is_sharded:
+                    self.shard_strategy.shard([p.col_attr.data])
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
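Note on the first hunk: the synchronize-before-free ordering matters because the grad offload copies are issued with non_blocking=True. Below is a minimal standalone sketch of that pattern in plain PyTorch; offload_grad_to_cpu is a hypothetical helper for illustration, not ColossalAI's actual offload code.

import torch

def offload_grad_to_cpu(grad: torch.Tensor) -> torch.Tensor:
    # Pinned host memory is what makes the device-to-host copy truly asynchronous.
    cpu_buf = torch.empty(grad.shape, dtype=grad.dtype, device='cpu', pin_memory=True)
    cpu_buf.copy_(grad, non_blocking=True)
    return cpu_buf

if torch.cuda.is_available():
    grad = torch.randn(1 << 20, device='cuda')
    cpu_grad = offload_grad_to_cpu(grad)
    # Without this synchronize, a consumer could read `cpu_grad` while the async
    # copy is still in flight on the current stream and observe stale data.
    torch.cuda.current_stream().synchronize()
    print(cpu_grad.sum())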
@@ -113,13 +118,12 @@ class ShardedModelV2(nn.Module):
             if not self._require_backward_grad_sync:
                 continue
             # Write grad back to p.grad and set p.col_attr.grad to None
-            p.grad.data = p.col_attr.grad
+            # We have to make sure grad and param have the same shape
+            # If world size > 1 and the param is sharded, `.view()` may not be needed
+            # If world size == 1 and the param is sharded, `data` is a flattened tensor,
+            # but the shape of `grad` is the same as the unsharded param
+            p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape)
             p.col_attr.grad = None
-        # In case some post bwd hook is not fired
-        if self.shard_param:
-            for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.data])
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
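The shape reasoning in the comments above can be reproduced in isolation. A minimal sketch under assumed shapes, with flat_data and accumulated_grad as stand-ins for p.col_attr.data and p.col_attr.grad (not the real sharded-tensor attributes):

import torch

param_shape = (4, 8)
flat_data = torch.randn(param_shape).flatten()    # stand-in for p.col_attr.data (world size == 1, kept flattened)
accumulated_grad = torch.randn(param_shape)       # stand-in for p.col_attr.grad (unsharded param shape)

# Writing the grad back for the optimizer requires matching shapes,
# hence the `.view(p.col_attr.data.shape)` in the hunk above.
grad_for_optimizer = accumulated_grad.view(flat_data.shape)
assert grad_for_optimizer.shape == flat_data.shape == (32,)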
@@ -180,7 +184,11 @@ class ShardedModelV2(nn.Module):
         if param.col_attr.grad is None:
             param.col_attr.grad = reduced_grad.data
         else:
-            param.col_attr.grad.add_(reduced_grad.data)
+            # When dp size = 1,
+            # param.col_attr.grad is the local accumulated grad shard (full but flattened),
+            # but reduced_grad here is the full-shape grad,
+            # so we have to call `view_as`
+            param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
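The view_as change follows the same logic: with dp size 1, the accumulated grad is kept as a flattened full shard while the freshly reduced grad still has the parameter's original shape, so add_ only works after a reshape. A minimal sketch with hypothetical tensors standing in for param.col_attr.grad and the reduced grad:

import torch

param_shape = (4, 8)
accumulated = torch.zeros(param_shape).flatten()  # stand-in for param.col_attr.grad: flattened full shard
reduced_grad = torch.ones(param_shape)            # stand-in for the reduced grad: original param shape

# A direct accumulated.add_(reduced_grad) would fail with a shape/broadcast error;
# view_as lines the shapes up first, matching the change above.
accumulated.add_(reduced_grad.view_as(accumulated))
assert torch.equal(accumulated, torch.ones(32))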