[hotfix] fix a running error in test_colo_checkpoint.py (#1387)

HELSON
2022-07-29 15:58:06 +08:00
committed by GitHub
parent f792507ff3
commit 527758b2ae
4 changed files with 13 additions and 4 deletions

@@ -146,6 +146,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         data = data.to(get_current_device())
         label = label.to(get_current_device())
+        dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
+        dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())
         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
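
The two added dist.broadcast calls simply replicate rank 0's batch to every rank in the tensor-parallel group, so all TP ranks run the sharded forward pass on identical inputs. Below is a minimal, self-contained sketch of that pattern, not part of the test itself: it uses plain torch.distributed with the gloo backend, and the default group with src=0 stands in for ColossalAI's pg.tp_process_group() and pg.tp_rank_list()[0].

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    # Minimal single-node rendezvous; gloo keeps the sketch CPU-only.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank starts with different tensors; after the broadcasts every rank
    # holds rank 0's data and label, mirroring the replication the fix adds.
    data = torch.full((4, 8), float(rank))
    label = torch.full((4,), float(rank))
    dist.broadcast(data, src=0)
    dist.broadcast(label, src=0)

    assert torch.all(data == 0) and torch.all(label == 0)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
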
@@ -183,9 +186,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    # TODO(haichen) add BERT in the test
-    # the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
-    for model_name in ['simple_net']:
+    for model_name in ['bert']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,