[checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)

* [checkpointio] unsharded optimizer checkpoint for Gemini plugin

* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
Baizhou Zhang
2023-07-07 16:33:06 +08:00
committed by GitHub
parent fee32a3b78
commit 58913441a1
9 changed files with 684 additions and 83 deletions
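
As a minimal end-to-end sketch of what this commit enables (the Booster front-end and GeminiPlugin are from this repo; the toy model, file name, and the shard flag of save_optimizer are illustrative assumptions):

    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    # launch with torchrun so the distributed environment is set up
    colossalai.launch_from_torch(config={})
    booster = Booster(plugin=GeminiPlugin())

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer, *_ = booster.boost(model, optimizer)

    # Every rank must reach save_optimizer: gathering the full optimizer state
    # involves communication, even though only the master rank writes the file.
    booster.save_optimizer(optimizer, "optimizer.pt", shard=False)
    booster.load_optimizer(optimizer, "optimizer.pt")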

@@ -33,44 +33,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
-        """
-        Load model from checkpoint with automatic unwrapping.
-        """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        return super().load_unsharded_model(model, checkpoint, strict=strict)
-
     def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
-        Save model to checkpoint but only on master process.
+        Save unsharded model to checkpoint but only on master process.
+        The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
+        As there is communication when getting state dict, this must be called on all processes.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        # as there is communication when get state dict, this must be called on all processes
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors)
 
+    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
+        """
+        Load model from checkpoint with automatic unwrapping.
+        The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
+        """
+        super().load_unsharded_model(model, checkpoint, strict=strict)
+
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
-        Save optimizer to checkpoint but only on master process.
-        """
-        # TODO(ver217): optimizer state dict is sharded
-        warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
-
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
-        warnings.warn(
-            'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        super().load_optimizer(optimizer, checkpoint)
-
-    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
-        """
-        Save model to checkpoint but only on master process.
+        Save unsharded optimizer state dict to checkpoint.
+        After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
+        As there is communication when getting state dict, this must be called on all processes.
+        The saving process will only be executed by master rank.
         """
-        if self.coordinator.is_master():
-            super().save_lr_scheduler(lr_scheduler, checkpoint)
+        state_dict = optimizer.state_dict()
+        if self.coordinator.is_master():
+            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        """
+        Loading unsharded optimizer from checkpoint file.
+        For each process, only loading optimizer states of parameters it controls.
+        """
+        super().load_unsharded_optimizer(optimizer, checkpoint)
 
     def save_sharded_model(self,
                            model: GeminiDDP,
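
The unsharded save relies on optimizer.state_dict() itself performing the collective that reassembles the sharded states (the commit message mentions all_gather) before rank 0 writes. A hedged sketch of that pattern in plain torch.distributed, not the library's actual internals:

    import torch
    import torch.distributed as dist

    def save_full_state(local_shard: torch.Tensor, path: str):
        # every rank contributes its shard; all ranks must call this together
        shards = [torch.empty_like(local_shard) for _ in range(dist.get_world_size())]
        dist.all_gather(shards, local_shard)
        # only the master rank performs file I/O
        if dist.get_rank() == 0:
            torch.save(torch.cat(shards), path)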
@@ -82,6 +78,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Save sharded model
         """
+        if os.path.isfile(checkpoint_path):
+            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            return
+
+        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+
         state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0
@@ -117,6 +119,23 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
+        """
+        Save sharded optimizer state dict to checkpoint folder.
+        As there is communication when getting state dict, this must be called on all processes.
+        """
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+        super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+
+    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
+        """
+        Loading sharded optimizer from checkpoint folder, with index file given.
+        For each process, only loading optimizer states of parameters it controls.
+        """
+        # TODO(Baizhou): To be implemented.
+        pass
+
 
 class GeminiModel(ModelWrapper):
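
At the Booster level the sharded path mirrors the unsharded one; a hedged usage sketch (directory name and shard size are illustrative; note that load_sharded_optimizer is still a TODO in this commit, so checkpoints saved this way cannot be loaded back yet):

    # produces shard files plus an index file under "optim_ckpt/"
    booster.save_optimizer(optimizer, "optim_ckpt", shard=True, size_per_shard=1024)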
@@ -193,7 +212,7 @@ class GeminiPlugin(DPPluginBase):
             which will be used when using hybrid CPU optimizer.
             This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
             Defaults to 0.0.
-        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
+        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
         min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
         growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
         backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
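
The lower default is cheaper at startup: fp16 gradients overflow far below 2**32, so the scaler would spend its first steps halving the scale instead of training. A quick sanity check, assuming the documented backoff_factor of 0.5:

    import math
    # overflow/backoff steps needed to fall from the old default to the new one
    wasted_steps = math.log2(2**32 / 2**16)   # 16.0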
@@ -219,7 +238,7 @@ class GeminiPlugin(DPPluginBase):
                  min_chunk_size_m: float = 32,
                  memstats: Optional[MemStats] = None,
                  gpu_margin_mem_ratio: float = 0.0,
-                 initial_scale: float = 2**32,
+                 initial_scale: float = 2**16,
                  min_scale: float = 1,
                  growth_factor: float = 2,
                  backoff_factor: float = 0.5,
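
Runs that need the previous behavior can still opt back in through the constructor argument shown above:

    plugin = GeminiPlugin(initial_scale=2**32)   # restore the pre-commit default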