[zero] solve hang

This commit is contained in:
botbw
2024-07-09 08:14:00 +00:00
committed by Hongxin Liu
parent b5bfeb2efd
commit 13b48ac0aa
8 changed files with 218 additions and 335 deletions

View File

@@ -137,7 +137,7 @@ def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) ->
local_param.data.copy_(all_param.data)
def loose_close(a, b, dtype: torch.dtype = torch.float32):
def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
rtol = None
atol = None
if dtype is torch.float16:
@@ -150,4 +150,4 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
assert_close(a, b, rtol=rtol, atol=atol)
assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close {a.mean()} {b.mean()}"