[chore] solve moe ckpt test failure and some other arg pass failure

This commit is contained in:
hxwang
2024-07-22 03:40:34 +00:00
committed by Hongxin Liu
parent 52d346f2a5
commit 70c9924d0d
12 changed files with 101 additions and 79 deletions

View File

@@ -1,7 +1,11 @@
import torch
def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
assert loose_close(a, b, dtype), f"{name} not close {a.mean()} {b.mean()}"
def loose_close(a, b, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if dtype is torch.float16:
@@ -12,10 +16,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
atol = 4e-3
else:
assert dtype is torch.float32
rtol = 1e-5
atol = 1e-5
rtol = 1e-05
atol = 1e-08
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close {a.mean()} {b.mean()}"
return torch.allclose(a, b, rtol=rtol, atol=atol)
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
assert_loose_close(p1, p2, p1.dtype)