mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[zero] allow passing process group to zero12 (#4153)
* allow passing process group to zero12 * union tp-zero and normal-zero * polish code
This commit is contained in:
@@ -33,10 +33,9 @@ def exam_zero_init():
|
||||
|
||||
assert optimizer1._local_rank == optimizer2._local_rank
|
||||
assert optimizer1._world_size == optimizer2._world_size
|
||||
assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks
|
||||
|
||||
mp_group1 = optimizer1._mp_torch_group
|
||||
mp_group2 = optimizer2._mp_torch_group
|
||||
mp_group1 = optimizer1.tp_pg
|
||||
mp_group2 = optimizer2.tp_pg
|
||||
assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
|
||||
assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
|
||||
|
||||
|
@@ -57,7 +57,9 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
|
||||
initial_scale=2,
|
||||
clip_grad_norm=1.0,
|
||||
overlap_communication=overlap_flag,
|
||||
partition_grad=partition_flag)
|
||||
partition_grad=partition_flag,
|
||||
dp_process_group=tp_pg.dp_process_group(),
|
||||
tp_process_group=tp_pg.tp_process_group())
|
||||
|
||||
dp_local_rank = tp_pg.dp_local_rank()
|
||||
set_seed(255 + dp_local_rank)
|
||||
|
Reference in New Issue
Block a user