[optim] hotfix adam load (#6146)

* [optim] hotfix adam load

* [checkpointio] fix optimizer async io

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [checkpointio] update test

* [checkpointio] update test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-11-20 16:36:37 +08:00
committed by GitHub
parent 5caad13055
commit cf519dac6a
5 changed files with 139 additions and 76 deletions

View File

@@ -1,4 +1,4 @@
from typing import Any, List, OrderedDict, Tuple
from typing import Any, List, OrderedDict
import torch
import torch.distributed as dist
@@ -78,9 +78,7 @@ def check_state_dict_equal(
v1 = v1.to(v2.dtype)
assert_close_loose(v1, v2)
else:
if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
v2 = tuple(v2)
assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
assert v1 == v2, f"{v1} not equals to {v2}"
def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):