[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix gemini unwrap
* fix bugs
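The core of this change is the rewritten unwrapping logic: checkpoint IO should be able to ask any wrapped model for its innermost user module through a single call. A minimal sketch of that idea, with assumed names rather than ColossalAI's actual ModelWrapper implementation:

import torch.nn as nn

class ToyWrapper(nn.Module):
    # Hypothetical stand-in for a booster-produced wrapper; not ColossalAI API.
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        # Peel nested wrappers until the raw user module is reached.
        inner = self.module
        while isinstance(inner, ToyWrapper):
            inner = inner.module
        return inner

wrapped = ToyWrapper(ToyWrapper(nn.Linear(4, 4)))
assert isinstance(wrapped.unwrap(), nn.Linear)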
@@ -44,6 +44,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
         As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors)
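Note on the docstring above: model.state_dict() can trigger collective communication (Gemini gathers sharded parameters), so every rank must call it, while only the master rank writes the result. A bare-bones sketch of that pattern with plain torch.distributed, using illustrative names rather than the plugin's own code:

import torch
import torch.distributed as dist

def save_on_master(model: torch.nn.Module, path: str) -> None:
    # All ranks must enter state_dict(); skipping it on non-master ranks would
    # leave the collective gather waiting and deadlock the job.
    state_dict = model.state_dict()
    # Only rank 0 touches the filesystem.
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(state_dict, path)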
@@ -53,24 +54,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Load model from checkpoint with automatic unwrapping.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)

-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
         As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
         The saving process will only be executed by master rank.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
         state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors=False)

-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
         """
         Loading unsharded optimizer from checkpoint file.
         For each process, only loading optimizer states of parameters it controls.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         super().load_unsharded_optimizer(optimizer, checkpoint)

     def save_sharded_model(
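The assertions added in this hunk expect checkpoint IO to receive already-boosted objects, i.e. a GeminiDDP model and a GeminiOptimizer. A hedged usage sketch of that workflow (signatures and paths are illustrative; check them against the ColossalAI docs for your version, and note that boosting assumes the distributed environment has already been initialized, e.g. via colossalai.launch_from_torch):

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters())

booster = Booster(plugin=GeminiPlugin())
# After boost(), model is a GeminiDDP and optimizer is a GeminiOptimizer,
# so the isinstance checks in GeminiCheckpointIO pass.
model, optimizer, *_ = booster.boost(model, optimizer)

booster.save_model(model, "model.bin")            # file names are illustrative
booster.save_optimizer(optimizer, "optimizer.bin")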
@@ -86,6 +90,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Save sharded model.
         As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
@@ -111,7 +116,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            save_config_file(model.module, checkpoint_path)
+            save_config_file(model.unwrap(), checkpoint_path)
             logging.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
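Besides switching to model.unwrap(), this hunk sits next to the shard-index bookkeeping (append_meta_data("total_size", ...), write_index_file(...)). As a rough illustration of what such an index boils down to, here is a Hugging Face style weight map written with plain json; the exact layout ColossalAI emits may differ:

import json

# Map each parameter name to the shard file holding it, plus total size metadata.
index = {
    "metadata": {"total_size": 123456},  # illustrative byte count
    "weight_map": {
        "embed.weight": "pytorch_model-00001-of-00002.bin",
        "lm_head.weight": "pytorch_model-00002-of-00002.bin",
    },
}
with open("pytorch_model.bin.index.json", "w") as f:
    json.dump(index, f, indent=2)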
@@ -124,17 +129,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Load shard model, load model from multiple files.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.
         As there is communication when getting state dict, this must be called on all processes.
         """
-
-        assert isinstance(optimizer, GeminiOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"

         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -176,12 +181,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             f"index located at {save_index_file}."
         )

-    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
+    def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
         """
         Loading sharded optimizer from checkpoint folder, with index file given.
         For each process, only loading optimizer states of parameters it controls.
         """
-
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         if not os.path.isfile(checkpoint_index_file):
             logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
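The docstring above states that each process loads only the optimizer states of the parameters it controls. A toy illustration of that filtering step with plain dictionaries (data layout assumed, not GeminiOptimizer's internals):

# Full, gathered optimizer state keyed by parameter id, as it might appear in a checkpoint.
full_state = {
    0: {"exp_avg": 0.1},
    1: {"exp_avg": 0.2},
    2: {"exp_avg": 0.3},
}
local_param_ids = {0, 2}  # parameters owned by this rank

# Keep only the entries this rank is responsible for.
local_state = {pid: s for pid, s in full_state.items() if pid in local_param_ids}
assert set(local_state) == {0, 2}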
@@ -383,7 +388,7 @@ class GeminiPlugin(DPPluginBase):

         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
-                optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+                optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
             )

         return model, optimizer, criterion, dataloader, lr_scheduler
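Finally, the "master param sync" in the title refers to keeping the optimizer's fp32 master copies consistent with the working (typically fp16) parameters the model computes with, for example after a checkpoint load. A toy sketch of that sync direction, not the hybrid plugin's actual implementation:

import torch

# fp32 master copies held by the optimizer vs. fp16 working copies used in forward/backward.
master = {"w": torch.full((4,), 1.5, dtype=torch.float32)}
working = {"w": torch.zeros(4, dtype=torch.float16)}

# Copy master -> working so the model computes with the freshly loaded/updated values.
for name, w in working.items():
    w.copy_(master[name].to(w.dtype))

assert torch.allclose(working["w"].float(), master["w"])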