mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code * [testing] remove code * [gemini] make more robust
This commit is contained in:
@@ -27,8 +27,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
|
||||
for key, value in torch_dict.items():
|
||||
# key is 'module.model.PARAMETER', so we truncate it
|
||||
key = key[7:]
|
||||
if key == 'model.lm_head.weight':
|
||||
continue
|
||||
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
|
||||
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
|
||||
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
|
||||
|
Reference in New Issue
Block a user