Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
Commit: [checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)

* [checkpointio] unsharded optimizer checkpoint for Gemini plugin
* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
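
As context for how this checkpoint path is exercised, here is a minimal training-script sketch using the ColossalAI Booster API. This is a hypothetical fragment: the toy model, file names, and hyperparameters are made up, so check the Booster/GeminiPlugin signatures of your version before copying.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(32, 32)                      # toy placeholder model
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(initial_scale=2**16)           # matches the new default below
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training steps ...

# Unsharded (single-file) checkpoints: every rank must call these, because
# collecting the state dicts involves collective communication, but only the
# master rank writes the files.
booster.save_model(model, "model.pt", shard=False)
booster.save_optimizer(optimizer, "optimizer.pt", shard=False)

# On restart:
booster.load_model(model, "model.pt")
booster.load_optimizer(optimizer, "optimizer.pt")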
@@ -33,44 +33,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
-        """
-        Load model from checkpoint with automatic unwrapping.
-        """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        return super().load_unsharded_model(model, checkpoint, strict=strict)
-
     def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
-        Save model to checkpoint but only on master process.
+        Save sharded model to checkpoint but only on master process.
+        The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
+        As there is communication when getting state dict, this must be called on all processes.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        # as there is communication when get state dict, this must be called on all processes
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors)
 
+    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
+        """
+        Load model from checkpoint with automatic unwrapping.
+        The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
+        """
+        super().load_unsharded_model(model, checkpoint, strict=strict)
+
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
-        Save optimizer to checkpoint but only on master process.
-        """
-        # TODO(ver217): optimizer state dict is sharded
-        warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
-
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
-        warnings.warn(
-            'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        super().load_optimizer(optimizer, checkpoint)
-
-    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
-        """
-        Save model to checkpoint but only on master process.
+        Save unsharded optimizer state dict to checkpoint.
+        After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
+        As there is communication when getting state dict, this must be called on all processes.
+        The saving process will only be executed by master rank.
         """
+        state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
-            super().save_lr_scheduler(lr_scheduler, checkpoint)
+            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        """
+        Loading unsharded optimizer from checkpoint file.
+        For each process, only loading optimizer states of parameters it controls.
+        """
+        super().load_unsharded_optimizer(optimizer, checkpoint)
 
     def save_sharded_model(self,
                            model: GeminiDDP,
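
The key point in the new save_unsharded_optimizer above is that Gemini keeps optimizer state sharded across ranks, so assembling a full state dict requires collective communication (hence the note that the call must happen on all processes), while only the master rank writes the file. A stripped-down illustration of that pattern with plain torch.distributed follows; it is not the actual Gemini/ZeroOptimizer code and assumes equal-sized 1-D shards on every rank.

import torch
import torch.distributed as dist

def gather_full_tensor(local_shard: torch.Tensor) -> torch.Tensor:
    # Collective call: every rank contributes its shard and gets all shards back.
    world_size = dist.get_world_size()
    buffers = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(buffers, local_shard)
    return torch.cat(buffers, dim=0)

def save_unsharded_state(local_state: dict, path: str) -> None:
    # Must be entered by ALL ranks (all_gather would otherwise hang),
    # but only rank 0 writes the resulting checkpoint file.
    full_state = {name: gather_full_tensor(shard) for name, shard in local_state.items()}
    if dist.get_rank() == 0:
        torch.save(full_state, path)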
@@ -82,6 +78,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Save sharded model
         """
+        if os.path.isfile(checkpoint_path):
+            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            return
+
+        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+
         state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0
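
For readers unfamiliar with the sharded layout that save_sharded_model writes into the checkpoint directory, the general idea is one file per state-dict shard plus an index file that maps each parameter to its file and records the total size. The following is a simplified sketch of that layout with invented helper names, not ColossalAI's actual utilities.

import json
import os
import torch

def save_sharded(shards, checkpoint_dir, weights_name="model-{idx:05d}.bin"):
    # `shards` is an iterable of dicts mapping parameter names to tensors.
    os.makedirs(checkpoint_dir, exist_ok=True)
    weight_map, total_size = {}, 0
    for idx, shard in enumerate(shards):
        filename = weights_name.format(idx=idx)
        torch.save(shard, os.path.join(checkpoint_dir, filename))
        for name, tensor in shard.items():
            weight_map[name] = filename            # which file holds this parameter
            total_size += tensor.numel() * tensor.element_size()
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    with open(os.path.join(checkpoint_dir, "model.index.json"), "w") as f:
        json.dump(index, f, indent=2)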
@@ -117,6 +119,23 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
+        """
+        Save sharded optimizer state dict to checkpoint folder.
+        As there is communication when getting state dict, this must be called on all processes.
+        """
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+        super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+
+    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
+        """
+        Loading sharded optimizer from checkpoint folder, with index file given.
+        For each process, only loading optimizer states of parameters it controls.
+        """
+        # TODO(Baizhou): To be implemented.
+        pass
+
 
 class GeminiModel(ModelWrapper):
 
@@ -193,7 +212,7 @@ class GeminiPlugin(DPPluginBase):
             which will be used when using hybrid CPU optimizer.
             This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
             Defaults to 0.0.
-        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
+        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
         min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
         growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
         backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
@@ -219,7 +238,7 @@ class GeminiPlugin(DPPluginBase):
                  min_chunk_size_m: float = 32,
                  memstats: Optional[MemStats] = None,
                  gpu_margin_mem_ratio: float = 0.0,
-                 initial_scale: float = 2**32,
+                 initial_scale: float = 2**16,
                  min_scale: float = 1,
                  growth_factor: float = 2,
                  backoff_factor: float = 0.5,
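
The arguments changed in these last two hunks configure dynamic loss scaling for fp16 training; the commit lowers the default initial_scale from 2**32 to 2**16. As a rough illustration of what these knobs control (a toy sketch, not ColossalAI's DynamicGradScaler), a dynamic scaler backs off when gradients overflow and grows the scale again after a run of clean steps:

class SimpleDynamicScaler:
    def __init__(self, initial_scale=2**16, min_scale=1.0, max_scale=2**32,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=1000):
        self.scale = initial_scale
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # Overflow: shrink the loss scale and restart the growth counter.
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                # Long run of finite gradients: safe to grow the scale again.
                self.scale = min(self.scale * self.growth_factor, self.max_scale)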