[optim] hotfix adam load (#6146)

* [optim] hotfix adam load

* [checkpointio] fix optimizer async io

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [checkpointio] update test

* [checkpointio] update test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-11-20 16:36:37 +08:00
committed by GitHub
parent 5caad13055
commit cf519dac6a
5 changed files with 139 additions and 76 deletions

View File

@@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
# if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
for group in self.param_groups:
for p in group["params"]:
state = self.state[p]
if "step" in state and isinstance(state["step"], torch.Tensor):
state["step"] = int(state["step"].item())
def torch_adam_update(
self,
data,