Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix gemini unwrap
* fix bugs
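
For context, the "rewrite unwrap" bullets refer to peeling wrapper layers off a model before checkpoint IO touches it, so that state dicts are read and written against the inner module. Below is a minimal illustrative sketch, not code from this commit: the helper name unwrap_model is made up here, and it assumes plain PyTorch DDP/FSDP wrappers, which both expose the wrapped network via `.module`.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap_model(model: nn.Module) -> nn.Module:
    """Illustrative only: strip DDP/FSDP wrappers to reach the user's module.

    Checkpoint IO typically saves and loads state dicts against the unwrapped
    module so that parameter keys do not carry wrapper prefixes.
    """
    while isinstance(model, (DDP, FSDP)):
        # Both wrapper classes keep the original network on `.module`.
        model = model.module
    return model
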
@@ -8,8 +8,6 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.interface import OptimizerWrapper
-
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -28,7 +26,6 @@ from .utils import (
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,
-    unwrap_optimizer,
 )
 
 __all__ = ["GeneralCheckpointIO"]
@@ -58,10 +55,6 @@ class GeneralCheckpointIO(CheckpointIO):
         Load sharded optimizer with the given path to index file.
         """
 
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
-
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
 
@@ -98,10 +91,6 @@ class GeneralCheckpointIO(CheckpointIO):
         - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
         """
 
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
-
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
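
Taken together, these hunks drop the ad-hoc optimizer unwrapping from GeneralCheckpointIO, so its sharded-optimizer paths now expect to receive the bare torch.optim.Optimizer. A minimal caller-side sketch of that step, assuming ColossalAI's OptimizerWrapper keeps the wrapped optimizer on an `optim` attribute (that attribute name is an assumption, not something shown in this diff):

from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper


def to_bare_optimizer(optimizer) -> Optimizer:
    # Assumption: OptimizerWrapper exposes the inner optimizer as `.optim`.
    # After this change, GeneralCheckpointIO no longer performs this step
    # itself, so callers (or plugin-specific CheckpointIO classes) are the
    # ones that unwrap before delegating.
    if isinstance(optimizer, OptimizerWrapper):
        return optimizer.optim
    return optimizer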