[amp] add gradient clipping for unit tests (#2283)

* [amp] add gradient clipping in unit tests

* fix bugs
Author: HELSON
Date: 2023-01-04 11:59:56 +08:00
Committed by: GitHub
Parent: e00cedd181
Commit: 5d3a2be3af
5 changed files with 64 additions and 44 deletions
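Only the shared testing helpers are shown below; the gradient-clipping changes themselves live in the other changed files. As a rough sketch of the standard PyTorch AMP pattern such a unit test exercises (the model, optimizer, and use of torch.cuda.amp here are assumptions for illustration, not taken from this commit, which targets ColossalAI's own AMP wrappers):

import torch
from torch.cuda.amp import GradScaler, autocast

# Hypothetical fixture (not from this commit); requires a CUDA device.
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler()

with autocast():
    loss = model(torch.randn(4, 8, device='cuda')).sum()

scaler.scale(loss).backward()
# Unscale before clipping, otherwise the clip threshold is applied
# to the loss-scaled gradients instead of the true gradients.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()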


@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
 from torch.distributed import ProcessGroup
+from torch.testing import assert_close


 def assert_equal(a: Tensor, b: Tensor):
@@ -12,12 +13,8 @@ def assert_not_equal(a: Tensor, b: Tensor):
     assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}'


-def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
-    assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
-
-
 def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
-    assert_close(a, b, rtol, atol)
+    assert_close(a, b, rtol=rtol, atol=atol)


 def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
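The hand-rolled assert_close helper is dropped in favor of torch.testing.assert_close, imported in the hunk above. Note that torch.testing.assert_close takes rtol and atol as keyword-only arguments, so the positional call inside assert_close_loose had to become a keyword call. A minimal usage sketch (tensor values are illustrative, not from the tests):

import torch
from torch.testing import assert_close

a = torch.tensor([1.0, 2.0, 3.0])
assert_close(a, a + 1e-9)                        # within default float32 tolerances
assert_close(a, a + 5e-4, rtol=1e-3, atol=1e-3)  # looser, as assert_close_loose forwards
# assert_close(a, a, 1e-3, 1e-3) would raise TypeError: rtol/atol are keyword-only.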
@@ -30,4 +27,4 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
     for i in range(world_size - 1):
         a = tensor_list[i]
         b = tensor_list[i + 1]
-        assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}'
+        assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'
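Only the comparison loop of assert_equal_in_group appears in the hunk above. A plausible reconstruction of the whole helper, assuming the per-rank copies are collected with dist.all_gather (everything outside the diffed lines is an assumption):

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup


def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
    # Assumed reconstruction: gather every rank's copy of the tensor,
    # then compare neighbouring entries pairwise (only the loop below
    # is visible in the diff).
    world_size = dist.get_world_size(process_group)
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, group=process_group)

    for i in range(world_size - 1):
        a = tensor_list[i]
        b = tensor_list[i + 1]
        assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'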