[zero] add L2 gradient clipping for ZeRO (#2112)

* [zero] add L2 gradient clipping

* [testing] add MlpModel

* [zero] add unit test for grad clipping

* fix atol
Author: HELSON
Date: 2022-12-09 18:09:17 +08:00 (committed via GitHub)
Parent: 70a8556946
Commit: 63fbba3c19
5 changed files with 194 additions and 11 deletions
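The headline change is the L2 gradient clipping added for ZeRO. As a rough illustration of the technique only (not the code introduced by this commit; the function name and the flat-shard representation are assumptions), clipping sharded gradients by their global L2 norm looks like this:

# A generic sketch of L2 gradient clipping over ZeRO-partitioned gradients
# (illustrative only). Each rank holds just a shard of the gradients, so the
# squared norms are accumulated locally and all-reduced before the clip
# coefficient is applied.
import torch
import torch.distributed as dist

def clip_grads_by_global_l2_norm(local_grads, max_norm: float, eps: float = 1e-6):
    """Scale local gradient shards so the global L2 norm is at most max_norm."""
    sq_sum = torch.zeros(1, device=local_grads[0].device)
    for g in local_grads:
        sq_sum += g.detach().float().pow(2).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM)    # global squared norm
    total_norm = sq_sum.sqrt()
    clip_coef = (max_norm / (total_norm + eps)).item()
    if clip_coef < 1.0:                                  # only shrink, never grow
        for g in local_grads:
            g.mul_(clip_coef)
    return total_norm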


@@ -42,7 +42,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
-        assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)
+        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
 
 
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
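The tightened atol above belongs to check_param, which the new unit test uses to compare a ZeRO-wrapped model against a plain torch model after a clipped optimizer step. A hypothetical sketch of that test shape, with an illustrative MlpModel and torch.nn.utils.clip_grad_norm_ standing in for the repository's own helpers:

# Hypothetical test shape (illustrative names, not the repo's exact API):
# train a torch reference model and a ZeRO-wrapped copy in lockstep, clip
# both to the same max norm, then compare parameters as check_param does.
import torch
import torch.nn as nn

class MlpModel(nn.Module):
    """A small MLP, enough to exercise sharded gradients in a test."""

    def __init__(self, hidden_dim: int = 128):
        super().__init__()
        self.proj1 = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(4 * hidden_dim, hidden_dim)

    def forward(self, x):
        return self.proj2(self.act(self.proj1(x)))

def step_with_clipping(model, optimizer, x, max_norm: float = 1.0):
    """One training step with L2 gradient clipping before the update."""
    optimizer.zero_grad()
    loss = model(x).sum()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss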