[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)

* fix master param sync for hybrid plugin

* rewrite unwrap for ddp/fsdp

* rewrite unwrap for zero/gemini

* rewrite unwrap for hybrid plugin

* fix gemini unwrap

* fix bugs
Author: Baizhou Zhang
Committed: 2023-09-20 18:29:37 +08:00 (via GitHub)
Parent: 7b9b86441f
Commit: c0a033700c
14 changed files with 141 additions and 171 deletions
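
For readers skimming the diff below: the rewrite drops the standalone unwrap_optimizer helper and the per-plugin unwrap workarounds, and instead leans on the unwrap() methods of the ModelWrapper / OptimizerWrapper interfaces, with explicit asserts that callers pass boosted objects. A toy, self-contained sketch of that convention (stand-in classes, not ColossalAI's actual implementations):

# Toy re-implementation of the wrapper/unwrap convention this commit standardizes on.
import torch
import torch.nn as nn

class ToyModelWrapper(nn.Module):
    """Stand-in for ModelWrapper: holds the real module and knows how to unwrap it."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        # Peel nested wrappers until the raw module is reached.
        module = self.module
        while isinstance(module, ToyModelWrapper):
            module = module.module
        return module

class ToyOptimizerWrapper:
    """Stand-in for OptimizerWrapper: exposes the inner torch optimizer via unwrap()."""
    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    def unwrap(self) -> torch.optim.Optimizer:
        return self.optim

# Checkpoint IO now asserts that it received boosted (wrapped) objects and unwraps
# them itself, instead of calling a helper like the removed unwrap_optimizer.
model = ToyModelWrapper(nn.Linear(4, 4))
optimizer = ToyOptimizerWrapper(torch.optim.SGD(model.module.parameters(), lr=0.1))
assert isinstance(optimizer, ToyOptimizerWrapper), "Please boost the optimizer before saving!"
torch.save(optimizer.unwrap().state_dict(), "toy_optim.pt")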


@@ -22,7 +22,6 @@ from colossalai.checkpoint_io.utils import (
save_param_groups,
save_state_dict,
sharded_optimizer_loading_epilogue,
- unwrap_optimizer,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@@ -65,10 +64,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
- def unwrap(self):
- # TODO(ver217): this is a workaround for loading model
- return self
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
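
With the identity unwrap() workaround deleted, LowLevelZeroModel presumably falls back to the base ModelWrapper behavior, and the checkpoint IO methods later in this diff receive the wrapper itself. A one-line check against the toy wrapper sketched above (the base-class behavior is an assumption, not shown in this hunk):

wrapped = ToyModelWrapper(nn.Linear(4, 4))
assert wrapped.unwrap() is wrapped.module  # unwrap() now yields the inner module, not the wrapper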
@@ -79,7 +74,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
checkpoint (str): Path to save checkpoint
gather_dtensor (bool): Whether to gather_dtensor, not used
"""
+ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
@@ -109,6 +104,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
prefix (str): Perfix of file to save
size_per_shard (int): Max file size of each file that store state tensors
"""
+ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
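
The two new asserts make the contract explicit: save_unsharded_optimizer and save_sharded_optimizer only accept an optimizer that the plugin has already boosted into a LowLevelZeroOptimizer. A hedged usage sketch with the Booster API (names follow ColossalAI's public API around this release; launch/distributed setup is omitted and exact signatures may differ):

# Sketch: boost first, then hand the *wrapped* optimizer to checkpoint IO.
import torch.nn as nn
from torch.optim import Adam
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

booster = Booster(plugin=LowLevelZeroPlugin())
model = nn.Linear(4, 4)
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# `optimizer` is now a LowLevelZeroOptimizer, so the new asserts pass; passing the
# raw Adam instance would trip "Please boost the optimizer before saving!".
booster.save_optimizer(optimizer, "optim_ckpt", shard=True, size_per_shard=1024)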
@@ -160,9 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
- # If optimizer is wrapped, unwrap it.
- if isinstance(optimizer, OptimizerWrapper):
- optimizer = unwrap_optimizer(optimizer)
+ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!"
+ optimizer = optimizer.unwrap()
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
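
After unwrapping, the loader (continued in the hunk below) re-shards each gathered state tensor so that every rank keeps only its own slice of the flattened ZeRO state. A toy illustration of that v.split(...) logic with made-up sizes:

import torch

world_size, rank = 4, 1                      # made-up values for illustration
v = torch.arange(16, dtype=torch.float32)    # a full (gathered) state tensor from the checkpoint
v_list = v.split(v.numel() // world_size)    # same split as in the loader below
local_shard = v_list[rank].detach().clone()  # each rank keeps only its own slice
print(local_shard)                           # tensor([4., 5., 6., 7.])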
@@ -194,44 +189,23 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
v_list = v.split(v.numel() // self.coordinator.world_size)
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)
- def save_unsharded_model(
- self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
- ):
- assert isinstance(model, LowLevelZeroModel)
- super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
- def save_sharded_model(
- self,
- model: nn.Module,
- checkpoint_path: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- max_shard_size: int = 1024,
- use_safetensors: bool = False,
- ):
- assert isinstance(model, LowLevelZeroModel)
- super().save_sharded_model(
- model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
- )
- def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
- assert isinstance(model, LowLevelZeroModel)
- super().load_unsharded_model(model.module, checkpoint, strict)
+ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+ assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+ super().load_unsharded_model(model, checkpoint, strict)
+ model.update_master_params()
def load_sharded_model(
self,
- model: LowLevelZeroModel,
+ model: ModelWrapper,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
- assert isinstance(model, LowLevelZeroModel)
- super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+ assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+ super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+ model.update_master_params()
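
The headline fix sits in the model.update_master_params() calls: after the working weights are loaded through the wrapper, the optimizer's master copies are re-synced instead of silently keeping their pre-load values. Continuing the hedged Booster sketch from the save_* hunks above (behavior beyond what the diff shows is an assumption):

# Assumes `booster` and `model` from the earlier boost sketch.
booster.save_model(model, "model_ckpt.pt")  # write an unsharded model checkpoint
booster.load_model(model, "model_ckpt.pt")  # load_unsharded_model() above runs, then
                                            # model.update_master_params() re-syncs the
                                            # optimizer's master copies with the loaded weights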