[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)

* fix master param sync for hybrid plugin

* rewrite unwrap for ddp/fsdp

* rewrite unwrap for zero/gemini

* rewrite unwrap for hybrid plugin

* fix gemini unwrap

* fix bugs
commit c0a033700c (parent 7b9b86441f)
Author: Baizhou Zhang
Date:   2023-09-20 18:29:37 +08:00
Committed by: GitHub
14 changed files with 141 additions and 171 deletions
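In short: after this commit, every TorchDDPCheckpointIO method takes the boosted wrapper (ModelWrapper / OptimizerWrapper), asserts that it received one, and calls .unwrap() itself before delegating to GeneralCheckpointIO; callers no longer pass raw nn.Module or Optimizer objects. A minimal usage sketch of the resulting convention (the toy model and file names are illustrative, and it assumes a distributed environment already launched via colossalai.launch):

import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())

# Boost first: the rewritten checkpoint IO rejects un-boosted objects.
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# The wrappers go straight to TorchDDPCheckpointIO, which unwraps internally.
booster.save_model(model, "model.pt")
booster.save_optimizer(optimizer, "optimizer.pt")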


@@ -20,24 +20,33 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         """
-        Load model from checkpoint with automatic unwrapping.
+        Load model from checkpoint.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        return super().load_unsharded_model(model, checkpoint, strict=strict)
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+            super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        """
+        Load optimizer from checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_unsharded_optimizer(optimizer, checkpoint)
+
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
             super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
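This first hunk sets the pattern every method below repeats: assert that the caller passed the boosted wrapper, then unwrap before delegating, so a raw module now fails fast with a clear message instead of drifting into the generic IO path. A hedged illustration of the new guard (assumes torch.distributed is already initialized; the checkpoint path is hypothetical):

import torch.nn as nn
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPCheckpointIO

io = TorchDDPCheckpointIO()  # requires an initialized process group
try:
    # A plain nn.Module was never boosted, so the new assert rejects it.
    io.save_unsharded_model(nn.Linear(4, 4), "ckpt.pt", gather_dtensor=False, use_safetensors=False)
except AssertionError as err:
    print(err)  # Please boost the model before saving!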
@@ -50,7 +59,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def save_sharded_model(
         self,
-        model: nn.Module,
+        model: ModelWrapper,
         checkpoint_path: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
@@ -60,22 +69,52 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+            super().save_sharded_model(
+                model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+            )
+
+    def load_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_index_file: str,
+        strict: bool = False,
+        use_safetensors: bool = False,
+        load_sub_module: bool = True,
+    ):
+        """
+        Load model from sharded checkpoint.
+        """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
 
     def save_sharded_optimizer(
         self,
-        optimizer: Optimizer,
+        optimizer: OptimizerWrapper,
         checkpoint: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
     ):
         """
-        Save optimizer to checkpoint but only on master process.
+        Save optimizer to sharded checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: Optional[str] = None,
+    ):
+        """
+        Load optimizer from sharded checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
 
 
 class TorchDDPModel(ModelWrapper):
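The diff view truncates the class body here. Per the commit message ("rewrite unwrap for ddp/fsdp"), the wrapper itself is now responsible for peeling off the DDP layer when unwrap() is called; a sketch of that shape under this commit's conventions (not the verbatim class body):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.interface import ModelWrapper

class TorchDDPModel(ModelWrapper):
    def __init__(self, module: nn.Module, *args, **kwargs) -> None:
        super().__init__(module)
        # After this call, self.module is the DDP-wrapped module.
        self.module = DDP(module, *args, **kwargs)

    def unwrap(self) -> nn.Module:
        # Strip the DDP wrapper and return the user's original module,
        # which is what the checkpoint IO above serializes.
        return self.module.module

With unwrap() defined this way, the checkpoint IO always serializes a plain nn.Module state dict, free of DDP's "module." key prefixes.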