[zero] adapt zero for unsharded parameters (#561)

* support existing sharded and unsharded parameters in zero

* add unit test for moe-zero model init

* polish moe gradient handler
This commit is contained in:
HELSON
2022-03-31 18:34:11 +08:00
committed by GitHub
parent 13ed4b6441
commit e6d50ec107
11 changed files with 211 additions and 70 deletions

View File

@@ -91,15 +91,19 @@ def check_params(model, zero_model, loose=False):
def check_grads_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
# zero_grad = zero_p.grad.clone().to(p.device)
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
grad = chunks[rank].float()
if zero_grad.size(0) > grad.size(0):
zero_grad = zero_grad[:grad.size(0)]
if zero_p.colo_attr.param_is_sharded:
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
grad = chunks[rank].float()
if zero_grad.size(0) > grad.size(0):
zero_grad = zero_grad[:grad.size(0)]
else:
grad = p.grad
zero_grad = zero_p.colo_attr.saved_grad.payload
assert grad.dtype == zero_grad.dtype
assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'