mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[devops] remove post commit ci (#5566)
* [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -27,7 +27,7 @@ from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
|
||||
@@ -93,9 +93,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
|
||||
with FSDP.state_dict_type(
|
||||
model.unwrap(),
|
||||
StateDictType.FULL_STATE_DICT,
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
):
|
||||
state_dict = model.unwrap().state_dict()
|
||||
|
||||
@@ -172,7 +170,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
with FSDP.state_dict_type(
|
||||
optimizer.unwrap_model().unwrap(),
|
||||
StateDictType.FULL_STATE_DICT,
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
|
||||
):
|
||||
fsdp_optim_state = FSDP.full_optim_state_dict(
|
||||
optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
|
||||
@@ -241,7 +239,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
)
|
||||
optimizer.load_state_dict(fsdp_state)
|
||||
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
"""
|
||||
Save lr scheduler to checkpoint but only on master process.
|
||||
|
Reference in New Issue
Block a user