mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[hotfix] fix a running error in test_colo_checkpoint.py (#1387)
This commit is contained in:
@@ -146,6 +146,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
|
||||
dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())
|
||||
|
||||
# Broadcast rank 0's data to all processes
|
||||
if criterion:
|
||||
output = model(data)
|
||||
@@ -183,9 +186,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
||||
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
# TODO(haichen): add BERT to the test
|
||||
|
||||
# the data loader of BERT is in DDP mode, so the input data is not replicated in the TP context
|
||||
for model_name in ['simple_net']:
|
||||
for model_name in ['bert']:
|
||||
_run_checkpoint(model_name,
|
||||
init_1d_row_for_linear_weight_spec,
|
||||
use_ddp,
|
||||
|
Reference in New Issue
Block a user