[zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api
This commit is contained in:
Hongxin Liu
2024-07-11 18:59:59 +08:00
committed by GitHub
parent dd9e1cdafe
commit c068ef0fa0
7 changed files with 119 additions and 25 deletions

View File

@@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc():
zero1_optimizer.step()
zero2_optimizer.step()
zero1_optimizer._force_wait_all_gather()
zero2_optimizer._force_wait_all_gather()
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert not hasattr(z1p, "_all_gather_handle")
assert torch.equal(z1p.data, z2p.data)

View File

@@ -177,6 +177,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
# torch ddp step
torch_optimizer.step()
zero_optimizer._force_wait_all_gather()
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
loose_close(p, z1p, dtype=dtype)