[gemini] get the param visited order during runtime (#2108)

This commit is contained in:
Jiarui Fang
2022-12-09 16:13:03 +08:00
committed by GitHub
parent 61f31c3cf0
commit 70a8556946
6 changed files with 48 additions and 2 deletions

View File

@@ -38,6 +38,13 @@ def test_runtime_mem_tracer():
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
print(non_model_data_list)
cnt1 = 0
for p in runtime_mem_tracer.parameters_in_runtime_order():
cnt1 += 1
cnt2 = 0
for p in model.parameters():
cnt2 += 1
assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}'
del model