Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix gemini unwrap
* fix bugs
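The common thread of the hunks below: after this change the checkpoint IO expects the boosted wrapper objects and unwraps them itself, instead of relying on ad-hoc helpers or per-plugin unwrap overrides. A hedged usage sketch of that flow, assuming the usual Booster / LowLevelZeroPlugin entry points; the model, paths, and plugin settings are placeholders, not taken from this commit:

    import colossalai
    import torch
    import torch.nn as nn
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    colossalai.launch_from_torch(config={})  # assumes a torchrun-style launch

    plugin = LowLevelZeroPlugin()            # placeholder settings
    booster = Booster(plugin=plugin)

    model = nn.Linear(16, 16)
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer, *_ = booster.boost(model, optimizer)

    # model / optimizer are now LowLevelZeroModel / LowLevelZeroOptimizer wrappers;
    # the checkpoint IO below asserts on exactly these types ("Please boost ... !").
    booster.save_optimizer(optimizer, "optim_ckpt", shard=True)  # hypothetical path
    booster.save_model(model, "model_ckpt.pt")
    booster.load_model(model, "model_ckpt.pt")  # ends with model.update_master_params() per this diff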
@@ -22,7 +22,6 @@ from colossalai.checkpoint_io.utils import (
     save_param_groups,
     save_state_dict,
     sharded_optimizer_loading_epilogue,
-    unwrap_optimizer,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -65,10 +64,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         kwargs = tree_map(self.convert_fn, kwargs)
         return super().forward(*args, **kwargs)
 
-    def unwrap(self):
-        # TODO(ver217): this is a workaround for loading model
-        return self
-
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
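The override deleted above made LowLevelZeroModel.unwrap() return the wrapper itself, a workaround so that loading went through the wrapper; the rewrite drops it in favour of the shared unwrapping logic, with master-parameter syncing handled explicitly at load time (see the last hunk). A toy, self-contained sketch of that shared pattern; WrapperSketch is an illustrative stand-in, not the real ModelWrapper:

    import torch.nn as nn


    class WrapperSketch(nn.Module):
        """Stand-in for ModelWrapper-style classes (plugin wrapper, DDP/FSDP wrapper, ...)."""

        def __init__(self, module: nn.Module):
            super().__init__()
            self.module = module

        def unwrap(self) -> nn.Module:
            # One shared rule: keep peeling nested wrappers until a plain nn.Module remains.
            return self.module.unwrap() if isinstance(self.module, WrapperSketch) else self.module


    inner = nn.Linear(8, 8)
    wrapped = WrapperSketch(WrapperSketch(inner))  # e.g. a plugin wrapper around a DDP-style wrapper
    assert wrapped.unwrap() is inner               # checkpoint IO saves/loads the bare module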
@@ -79,7 +74,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             checkpoint (str): Path to save checkpoint
             gather_dtensor (bool): Whether to gather_dtensor, not used
         """
-
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
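The comments kept in this hunk state the constraint that shapes this method: assembling the ZeRO optimizer state involves collectives (each rank holds only a shard), so every rank must take part even though a single rank writes the file. A minimal sketch of that pattern with plain torch.distributed (a generic all_gather, not the actual LowLevelZeroOptimizer internals; an initialized process group is assumed):

    import torch
    import torch.distributed as dist


    def save_full_tensor(shard: torch.Tensor, path: str) -> None:
        # Every rank calls the collective; skipping it on non-master ranks would
        # leave the participating ranks waiting forever (mismatched communication).
        gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, shard)

        # Only the coordinator rank touches the filesystem.
        if dist.get_rank() == 0:
            torch.save(torch.cat(gathered), path)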
@@ -109,6 +104,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file that store state tensors
         """
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -160,9 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             index_file_path (str): Path to the index file
             prefix (str): Not used.
         """
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!"
+        optimizer = optimizer.unwrap()
 
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
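Here the standalone unwrap_optimizer helper (whose import was dropped in the first hunk) is replaced by calling unwrap() on the wrapper itself. A toy stand-in for that shape; the class and the attribute name optim are illustrative, not the actual OptimizerWrapper implementation:

    import torch


    class OptimizerWrapperSketch:
        """Stand-in: hold an inner torch optimizer and expose it via unwrap()."""

        def __init__(self, optim: torch.optim.Optimizer):
            self.optim = optim

        def unwrap(self) -> torch.optim.Optimizer:
            return self.optim


    base = torch.optim.SGD([torch.nn.Parameter(torch.zeros(2))], lr=0.1)
    wrapped = OptimizerWrapperSketch(base)
    assert wrapped.unwrap() is base  # state entries are then loaded into the inner optimizer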
@@ -194,44 +189,23 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                 v_list = v.split(v.numel() // self.coordinator.world_size)
                 state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
         load_states_into_optimizer(optimizer, state_dict, id_map)
 
         sharded_optimizer_loading_epilogue(optimizer)
 
-    def save_unsharded_model(
-        self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
-    ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
-
-    def save_sharded_model(
-        self,
-        model: nn.Module,
-        checkpoint_path: str,
-        gather_dtensor: bool = True,
-        prefix: Optional[str] = None,
-        max_shard_size: int = 1024,
-        use_safetensors: bool = False,
-    ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().save_sharded_model(
-            model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
-        )
-
-    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
-        assert isinstance(model, LowLevelZeroModel)
-        super().load_unsharded_model(model.module, checkpoint, strict)
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        super().load_unsharded_model(model, checkpoint, strict)
+        model.update_master_params()
 
     def load_sharded_model(
         self,
-        model: LowLevelZeroModel,
+        model: ModelWrapper,
         checkpoint_index_file: Path,
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
     ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()
 
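The added update_master_params() call is the "master param sync" fix from the commit title: loading a checkpoint rewrites the (possibly fp16) working parameters, and the fp32 master copies kept for mixed-precision ZeRO training must then be refreshed, otherwise the next optimizer step would push stale master values back into the model. A toy sketch of that sync on a single parameter; the real implementation works on partitioned flat buffers:

    import torch

    # Low-precision working parameter (what the model computes with) and its fp32 master copy.
    working = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
    master = working.detach().float().clone()

    # Loading a checkpoint only rewrites the working parameter.
    with torch.no_grad():
        working.copy_(torch.ones(4, dtype=torch.float16))

    # update_master_params-style sync: refresh the fp32 master from the loaded weights.
    with torch.no_grad():
        master.copy_(working.detach().float())

    assert torch.equal(master, torch.ones(4))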