diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 88c734903..99700e1af 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -45,8 +45,9 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): with torch.no_grad(): non_fx_out = model(node, pair) fx_out = gm(node, pair) - assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-6), "fx_out doesn't comply with original output" - assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-6), "fx_out doesn't comply with original output" + + assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output" + assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output" # test barckward # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()