From 20ab1f55204ee18aca204f6faf60ca27b509b509 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 11 Apr 2022 22:00:27 +0800 Subject: [PATCH] [bug] fixed broken test_found_inf (#725) --- tests/test_zero_data_parallel/test_found_inf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_zero_data_parallel/test_found_inf.py b/tests/test_zero_data_parallel/test_found_inf.py index 7caa552e8..22c3a80fd 100644 --- a/tests/test_zero_data_parallel/test_found_inf.py +++ b/tests/test_zero_data_parallel/test_found_inf.py @@ -31,7 +31,7 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio) model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() with ZeroInitContext( - target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'), + target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(get_current_device()), shard_strategy=shard_strategy, shard_param=True): zero_model = model_builder(checkpoint=True)