[Gemini] mapping of preop timestep and param (#2124)

This commit is contained in:
Jiarui Fang
2022-12-13 12:50:24 +08:00
committed by GitHub
parent 764bc16f3e
commit 05bb28aacf
3 changed files with 49 additions and 6 deletions

View File

@@ -45,7 +45,8 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
memstats = runtime_mem_tracer.memstats()
runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
print('runtime tracer: ', runtime_tracer_non_model_data)
print([memstats.param_used_timestep(p) for p in model.parameters()])
model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats)
zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)