[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)

* fix master param sync for hybrid plugin

* rewrite unwrap for ddp/fsdp

* rewrite unwrap for zero/gemini

* rewrite unwrap for hybrid plugin

* fix gemini unwrap

* fix bugs
Baizhou Zhang
2023-09-20 18:29:37 +08:00
committed by GitHub
parent 7b9b86441f
commit c0a033700c
14 changed files with 141 additions and 171 deletions
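
The commit message above describes moving wrapper-stripping out of the generic checkpoint IO and into the plugin-specific code paths. As a rough, illustrative sketch of the kind of unwrapping logic being rewritten (not ColossalAI's exact implementation): DDP-style wrappers expose the inner model via `.module`, and `OptimizerWrapper` holds the inner optimizer on `.optim`, so a helper can peel wrappers off until the raw objects remain.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper


def unwrap_model_sketch(model: nn.Module) -> nn.Module:
    # Illustrative only: DDP (and similarly FSDP) keeps the wrapped network on `.module`.
    while isinstance(model, DDP):
        model = model.module
    return model


def unwrap_optimizer_sketch(optimizer) -> Optimizer:
    # Illustrative stand-in for the removed `unwrap_optimizer` helper:
    # `OptimizerWrapper` stores the underlying torch optimizer on `.optim`.
    while isinstance(optimizer, OptimizerWrapper):
        optimizer = optimizer.optim
    return optimizer
```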

@@ -8,8 +8,6 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer
-from colossalai.interface import OptimizerWrapper
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -28,7 +26,6 @@ from .utils import (
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,
-    unwrap_optimizer,
 )
 __all__ = ["GeneralCheckpointIO"]
@@ -58,10 +55,6 @@ class GeneralCheckpointIO(CheckpointIO):
         Load sharded optimizer with the given path to index file.
         """
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@@ -98,10 +91,6 @@
         - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
         """
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
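
With the guards above removed, `GeneralCheckpointIO` assumes it is handed the bare torch optimizer; whatever wraps the optimizer is expected to unwrap it before delegating here. A hedged caller-side sketch (the index-file path is hypothetical, and the `load_sharded_optimizer` argument names follow the diff above; the exact signature may differ):

```python
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import OptimizerWrapper


def load_optimizer_shards(optimizer, index_file_path: str) -> None:
    # Sketch of the caller-side responsibility after this change:
    # unwrap before handing the optimizer to the generic checkpoint IO.
    if isinstance(optimizer, OptimizerWrapper):
        optimizer = optimizer.optim  # assumed attribute holding the inner optimizer

    ckpt_io = GeneralCheckpointIO()
    # load_sharded_optimizer reads the index file and then loads each state shard
    # (argument names taken from the diff above; exact signature may differ).
    ckpt_io.load_sharded_optimizer(optimizer, index_file_path, prefix="")
```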