mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[zero] solve hang
This commit is contained in:
@@ -137,7 +137,7 @@ def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) ->
|
||||
local_param.data.copy_(all_param.data)
|
||||
|
||||
|
||||
def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
||||
def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
|
||||
rtol = None
|
||||
atol = None
|
||||
if dtype is torch.float16:
|
||||
@@ -150,4 +150,4 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
||||
a = a.detach().to(dtype)
|
||||
b = b.detach().to(dtype).to(a.device)
|
||||
|
||||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close {a.mean()} {b.mean()}"
|
||||
|
Reference in New Issue
Block a user