Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix gemini unwrap
* fix bugs
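For context, the `assert isinstance(..., ModelWrapper)` / "Please boost ..." checks added below assume the standard Booster workflow: `boost()` hands back wrapped objects, and the checkpoint IO unwraps them itself. A minimal sketch of that flow, written against the 0.3.x-era API (the `config` argument to `launch_from_torch` and the exact return shape of `boost()` may differ in other releases):

    import colossalai
    import torch.nn as nn
    from torch.optim import Adam
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin

    colossalai.launch_from_torch(config={})  # API of this era; newer releases drop `config`

    booster = Booster(plugin=TorchDDPPlugin())
    model = nn.Linear(8, 2)
    optimizer = Adam(model.parameters(), lr=1e-3)

    # boost() returns wrapped objects: the model as a ModelWrapper subclass
    # (TorchDDPModel below) and the optimizer as an OptimizerWrapper.
    model, optimizer, *_ = booster.boost(model, optimizer)

    # The checkpoint IO methods in this diff expect exactly these wrappers.
    booster.save_model(model, "model.pt")
    booster.save_optimizer(optimizer, "optimizer.pt")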
@@ -20,24 +20,33 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         """
-        Load model from checkpoint with automatic unwrapping.
+        Load model from checkpoint.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        return super().load_unsharded_model(model, checkpoint, strict=strict)
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+            super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        """
+        Load optimizer from checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_unsharded_optimizer(optimizer, checkpoint)
+
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
             super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
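The rewrite replaces the old "unwrap happens somewhere upstream" comment with an explicit contract: the caller passes the boosted wrapper, and the IO class calls `unwrap()` itself. An illustrative sketch of that contract (not ColossalAI's actual implementation, which lives in `colossalai.interface`):

    import torch.nn as nn

    class ModelWrapper(nn.Module):
        """Illustrative stand-in for colossalai.interface.ModelWrapper."""

        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            self.module = module

        def unwrap(self) -> nn.Module:
            # Peel nested wrappers (e.g. an AMP wrapper around a DDP wrapper)
            # until a plain nn.Module remains.
            if isinstance(self.module, ModelWrapper):
                return self.module.unwrap()
            return self.module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)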
@@ -50,7 +59,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
 
     def save_sharded_model(
         self,
-        model: nn.Module,
+        model: ModelWrapper,
         checkpoint_path: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
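Both hunks so far keep file writes behind `self.coordinator.is_master()`: under DDP every rank holds a full replica, so only rank 0 needs to touch disk. The same guard expressed in plain `torch.distributed` terms (`save_on_master` is a hypothetical helper for illustration, not part of this diff):

    import torch
    import torch.distributed as dist

    def save_on_master(state_dict: dict, path: str) -> None:
        # Rank 0 writes; everyone else waits at the barrier so no rank
        # reads a half-written checkpoint afterwards.
        if not dist.is_initialized() or dist.get_rank() == 0:
            torch.save(state_dict, path)
        if dist.is_initialized():
            dist.barrier()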
@@ -60,22 +69,52 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+            super().save_sharded_model(
+                model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+            )
+
+    def load_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_index_file: str,
+        strict: bool = False,
+        use_safetensors: bool = False,
+        load_sub_module: bool = True,
+    ):
+        """
+        Load model from sharded checkpoint.
+        """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
 
     def save_sharded_optimizer(
         self,
-        optimizer: Optimizer,
+        optimizer: OptimizerWrapper,
         checkpoint: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
     ):
         """
-        Save optimizer to checkpoint but only on master process.
+        Save optimizer to sharded checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: Optional[str] = None,
+    ):
+        """
+        Load optimizer from sharded checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
 
 
 class TorchDDPModel(ModelWrapper):
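The optimizer-side asserts mirror the model side. A hedged sketch of the `OptimizerWrapper` contract the sharded methods rely on (illustrative only; the real class in `colossalai.interface` forwards many more optimizer methods):

    from torch.optim import Optimizer

    class OptimizerWrapper:
        """Illustrative stand-in for colossalai.interface.OptimizerWrapper."""

        def __init__(self, optim: Optimizer) -> None:
            self.optim = optim

        def unwrap(self) -> Optimizer:
            # Sharded checkpoint IO saves/loads the inner optimizer's state.
            return self.optim

        def step(self, *args, **kwargs):
            return self.optim.step(*args, **kwargs)

        def state_dict(self) -> dict:
            return self.optim.state_dict()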