From 5efda69735bda0280ab219145f9be51cb74dacb1 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Tue, 13 Dec 2022 14:14:55 +0800 Subject: [PATCH] [Gemini] hotfix the unittest bugs (#2125) --- .../memory_tracer/param_runtime_order.py | 2 +- .../test_gemini/update/test_gemini_use_rmt.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/colossalai/gemini/memory_tracer/param_runtime_order.py b/colossalai/gemini/memory_tracer/param_runtime_order.py index dc9226a53..638c0533c 100644 --- a/colossalai/gemini/memory_tracer/param_runtime_order.py +++ b/colossalai/gemini/memory_tracer/param_runtime_order.py @@ -36,7 +36,7 @@ class OrderedParamGenerator(ParamGenerator): del visited_set def is_empty(self): - return len(self.param_visited_order) > 0 + return len(self.param_visited_order) == 0 def clear(self): self.param_visited_order = [] diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py index 3e3247e39..926b61ef4 100644 --- a/tests/test_gemini/update/test_gemini_use_rmt.py +++ b/tests/test_gemini/update/test_gemini_use_rmt.py @@ -45,11 +45,15 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) memstats = runtime_mem_tracer.memstats() runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list - print('runtime tracer: ', runtime_tracer_non_model_data) - print([memstats.param_used_timestep(p) for p in model.parameters()]) + print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data)) - model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats) - zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1) + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gather + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + model = ZeroDDP(model, gemini_manager, pin_memory=True) pg = ProcessGroup() set_seed(pg.dp_local_rank()) @@ -61,12 +65,10 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ break input_ids, label = input_ids.cuda(), label.cuda() - zero_optim.zero_grad() set_seed(42) - loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - zero_optim.step() + loss = run_fwd_bwd(model, input_ids, label, criterion, model) - gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') # print('gemini non model data:', gemini_non_model_data)