diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index 3f3c316a8..359dcafac 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -25,7 +25,7 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic else: temp_t2 = t2 - assert torch.allclose(t1, temp_t2, atol=1e-3, rtol=1e-3) + assert torch.equal(t1, temp_t2) def init_ddp(module: torch.nn.Module) -> ColoDDP: