diff --git a/tests/test_pipeline/test_schedule/test_dx_dw.py b/tests/test_pipeline/test_schedule/test_dx_dw.py index 6da1434d8..1ade7d45a 100644 --- a/tests/test_pipeline/test_schedule/test_dx_dw.py +++ b/tests/test_pipeline/test_schedule/test_dx_dw.py @@ -1176,12 +1176,16 @@ def model_chunk_dx_dw_comm_interleaved( print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") +def run_fwd_bwd( + rank: int, + world_size: int, + port: int, +): + pass + + @rerun_if_address_is_in_use() def test_dx_dw_dist(): - # spawn( - # model_chunk_dx_dw_communication, - # nprocs=2, - # ) spawn( model_chunk_dx_dw_comm_interleaved,