mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[optim] hotfix adam load (#6146)
* [optim] hotfix adam load
* [checkpointio] fix optimizer async io
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* [checkpointio] update test
* [checkpointio] update test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
|
||||
# if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
|
||||
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
||||
|
||||
def load_state_dict(self, state_dict):
    """Load optimizer state, then normalize tensor-valued ``step`` counters.

    Delegates the actual loading to the parent class, then walks every
    parameter's state entry and converts a ``step`` value stored as a
    ``torch.Tensor`` into a plain Python ``int`` — presumably because the
    CPU Adam path expects an integer step count (TODO confirm against the
    C++ ``cpu_adam`` op's expectations).

    Args:
        state_dict: Optimizer state as produced by ``state_dict()``.
    """
    super().load_state_dict(state_dict)
    # Flatten all parameters across groups; coerce each tensor "step" to int.
    all_params = (p for group in self.param_groups for p in group["params"])
    for param in all_params:
        param_state = self.state[param]
        step = param_state.get("step")
        if isinstance(step, torch.Tensor):
            param_state["step"] = int(step.item())
||||
|
||||
def torch_adam_update(
|
||||
self,
|
||||
data,
|
||||
|
Reference in New Issue
Block a user