mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[chore] solve moe ckpt test failure and some other arg pass failure
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
import torch
|
||||
|
||||
|
||||
def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
|
||||
def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
|
||||
assert loose_close(a, b, dtype), f"{name} not close {a.mean()} {b.mean()}"
|
||||
|
||||
|
||||
def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
||||
rtol = None
|
||||
atol = None
|
||||
if dtype is torch.float16:
|
||||
@@ -12,10 +16,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
|
||||
atol = 4e-3
|
||||
else:
|
||||
assert dtype is torch.float32
|
||||
rtol = 1e-5
|
||||
atol = 1e-5
|
||||
rtol = 1e-05
|
||||
atol = 1e-08
|
||||
|
||||
a = a.detach().to(dtype)
|
||||
b = b.detach().to(dtype).to(a.device)
|
||||
|
||||
assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close {a.mean()} {b.mean()}"
|
||||
return torch.allclose(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def check_model_equal(model1, model2):
|
||||
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
|
||||
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
|
||||
assert_loose_close(p1, p2, p1.dtype)
|
||||
|
Reference in New Issue
Block a user