[zero] improve adaptability for not-shard parameters (#708)

* adapt post grad hooks for not-shard parameters
* adapt optimizer for not-shard parameters
* offload gradients for not-replicated parameters
This commit is contained in:
HELSON
2022-04-11 13:38:51 +08:00
committed by GitHub
parent ab8c6b4a0e
commit a9b8300d54
9 changed files with 114 additions and 111 deletions

View File

@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
# zero_grad = zero_p.grad.clone().to(p.device)
if zero_p.colo_attr.param_is_sharded:
if zero_p.colo_attr.is_replicated:
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
@@ -102,8 +102,9 @@ def check_grads_padding(model, zero_model, loose=False):
if zero_grad.size(0) > grad.size(0):
zero_grad = zero_grad[:grad.size(0)]
else:
grad = p.grad
zero_grad = zero_p.colo_attr.saved_grad.payload
grad = p.grad.to(zero_grad.dtype)
assert grad.dtype == zero_grad.dtype
assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'
@@ -134,7 +135,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
else:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'