[polish] polish code for get_static_torch_model (#2405)

* [gemini] polish code

* [testing] remove code

* [gemini] make more robust
This commit is contained in:
HELSON
2023-01-09 17:41:38 +08:00
committed by GitHub
parent 551cafec14
commit ea13a201bb
6 changed files with 15 additions and 28 deletions

View File

@@ -27,8 +27,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))