[zero] allow passing process group to zero12 (#4153)

* allow passing process group to zero12

* unify tp-zero and normal-zero

* polish code
This commit is contained in:
LuGY
2023-07-04 17:41:28 +08:00
committed by Hongxin Liu
parent 79cf1b5f33
commit c668801d36
4 changed files with 41 additions and 90 deletions

View File

@@ -33,10 +33,9 @@ def exam_zero_init():
assert optimizer1._local_rank == optimizer2._local_rank
assert optimizer1._world_size == optimizer2._world_size
assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks
mp_group1 = optimizer1._mp_torch_group
mp_group2 = optimizer2._mp_torch_group
mp_group1 = optimizer1.tp_pg
mp_group2 = optimizer2.tp_pg
assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)

View File

@@ -57,7 +57,9 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
initial_scale=2,
clip_grad_norm=1.0,
overlap_communication=overlap_flag,
partition_grad=partition_flag)
partition_grad=partition_flag,
dp_process_group=tp_pg.dp_process_group(),
tp_process_group=tp_pg.tp_process_group())
dp_local_rank = tp_pg.dp_local_rank()
set_seed(255 + dp_local_rank)