[zero] remove registered gradient hooks (#5687)

* remove registered hooks

Author: flybird11111
Date: 2024-05-07 12:01:38 +08:00 (committed by GitHub)
Parent: c25f83c85f
Commit: 77ec773388
7 changed files with 256 additions and 167 deletions


@@ -64,7 +64,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
         working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
         grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
-        grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank
+        grad_index = (
+            0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
+        )
         grad = grads[grad_index]
         sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
         assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
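For context on what this hunk tracks: after the refactor, the test reads `_partition_grads` from the optimizer's gradient store and the ZeRO-local rank from its bucket store, rather than from attributes on the optimizer object itself. Below is a minimal, hedged PyTorch sketch of the registered per-parameter gradient-hook pattern that the commit title refers to removing; `grad_store` and `_grad_handler` here are illustrative stand-ins, not ColossalAI's actual API.

```python
import torch

# Illustrative only: a stand-in for a ZeRO gradient store keyed by param id.
grad_store = {}

param = torch.nn.Parameter(torch.randn(4))

def _grad_handler(grad, p=param):
    # Divert the incoming gradient into the store (hypothetical pattern).
    grad_store[id(p)] = grad.detach().clone()
    return grad  # returning grad leaves autograd's result unchanged

# Registering the hook returns a handle that can later detach it.
handle = param.register_hook(_grad_handler)

(param * 2.0).sum().backward()
assert id(param) in grad_store  # hook fired during backward

# "Removing registered gradient hooks": after this call, backward()
# no longer routes gradients through _grad_handler.
handle.remove()
```

The `RemovableHandle` returned by `register_hook` is what makes such hooks detachable; dropping them entirely, as this commit does, means gradients are no longer intercepted per parameter and the optimizer's stores are populated through its own bookkeeping instead.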