[zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
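With overlap enabled, the ZeRO optimizer re-gathers updated parameter shards asynchronously after `step()` instead of blocking inside it, so the updated tests below must drain the in-flight collectives with `_force_wait_all_gather()` before reading parameter data; a sketch of the likely handle-based pattern follows the first hunk.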
@@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc():
     zero1_optimizer.step()
     zero2_optimizer.step()
 
+    zero1_optimizer._force_wait_all_gather()
+    zero2_optimizer._force_wait_all_gather()
+
     # check updated param
     for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
+        assert not hasattr(z1p, "_all_gather_handle")
         assert torch.equal(z1p.data, z2p.data)
 
 
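Why the explicit wait: with overlap enabled, the post-step all-gather runs in the background, so the test drains it before comparing parameters. Below is a minimal sketch of that handle-based pattern, assuming torch.distributed collectives launched with async_op=True. The class, `_working_params`, and `_gathered_flat` are hypothetical illustrations, not the actual ColossalAI implementation; `_all_gather_handle` is the attribute name the test itself checks.

# Hypothetical sketch of an overlapped all-gather, for illustration only.
import torch
import torch.distributed as dist

class _OverlapAllGatherSketch:
    def __init__(self, working_params):
        # Hypothetical bookkeeping list of full (unsharded) parameters.
        self._working_params = working_params

    def _launch_all_gather(self, param, shard):
        world_size = dist.get_world_size()
        out = torch.empty(shard.numel() * world_size,
                          dtype=shard.dtype, device=shard.device)
        # async_op=True returns immediately with a Work handle, letting the
        # collective overlap with subsequent compute (e.g. the next step).
        param._all_gather_handle = dist.all_gather_into_tensor(
            out, shard, async_op=True
        )
        param._gathered_flat = out

    def _force_wait_all_gather(self):
        # Drain every in-flight all-gather, copy the gathered data into the
        # full parameter, and drop the handle so that afterwards
        # hasattr(p, "_all_gather_handle") is False, as the test asserts.
        for p in self._working_params:
            handle = getattr(p, "_all_gather_handle", None)
            if handle is not None:
                handle.wait()
                p.data.copy_(p._gathered_flat[: p.numel()].view_as(p))
                del p._all_gather_handle
                del p._gathered_flat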
@@ -177,6 +177,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     # torch ddp step
     torch_optimizer.step()
 
+    zero_optimizer._force_wait_all_gather()
+
     # check updated param
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
         loose_close(p, z1p, dtype=dtype)
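The same barrier matters in this DDP-parity test: without `_force_wait_all_gather()`, `loose_close` could read parameter data whose asynchronous all-gather has not yet landed, turning a deterministic comparison against the torch DDP baseline into a flaky one.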