[zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api
This commit is contained in:
Hongxin Liu
2024-07-11 18:59:59 +08:00
committed by GitHub
parent dd9e1cdafe
commit c068ef0fa0
7 changed files with 119 additions and 25 deletions

View File

@@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc():
zero1_optimizer.step()
zero2_optimizer.step()
zero1_optimizer._force_wait_all_gather()
zero2_optimizer._force_wait_all_gather()
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert not hasattr(z1p, "_all_gather_handle")
assert torch.equal(z1p.data, z2p.data)