[optim] hotfix adam load (#6146)

* [optim] hotfix adam load

* [checkpointio] fix optimizer async io

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [checkpointio] update test

* [checkpointio] update test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-11-20 16:36:37 +08:00
committed by GitHub
parent 5caad13055
commit cf519dac6a
5 changed files with 139 additions and 76 deletions

View File

@@ -142,7 +142,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
from colossalai.utils.safetensors import save_nested
f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
save_nested(f_writer, state_dict)
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)