diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index e98f3c18b..de4f460c0 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -7,15 +7,19 @@ from torch.distributed import ProcessGroup def assert_equal(a: Tensor, b: Tensor): assert torch.all(a == b), f'expected a and b to be equal but they are not, {a} vs {b}' + def assert_not_equal(a: Tensor, b: Tensor): assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}' + def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8): assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}' + def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): assert_close(a, b, rtol, atol) + def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): # all gather tensors from different ranks world_size = dist.get_world_size(process_group) @@ -25,5 +29,5 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): # check if they are equal one by one for i in range(world_size - 1): a = tensor_list[i] - b = tensor_list[i+1] + b = tensor_list[i + 1] assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}'