Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
Commit: [checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)
* [checkpointio] unsharded optimizer checkpoint for Gemini plugin
* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
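Background on the approach: under the Gemini plugin each rank holds only a shard of the optimizer states, so producing a single unsharded checkpoint means collecting the shards onto one rank first. The snippet below is only an illustrative sketch of that all_gather step; the helper name and the assumption of equally sized 1-D shards are inventions for illustration, not code from this commit.

import torch
import torch.distributed as dist

def gather_full_state(local_shard: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: gather equally sized 1-D shards of one optimizer
    # state from every rank and concatenate them into the full tensor.
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(gathered, local_shard)
    return torch.cat(gathered, dim=0)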
@@ -10,6 +10,8 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
+from colossalai.interface import OptimizerWrapper
+from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor.d_tensor import is_distributed_tensor
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
@@ -88,6 +90,19 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # ======================================
 # Helper functions for saving shard file
 # ======================================
+def unwrap_optimizer(optimizer: OptimizerWrapper):
+    '''
+    Unwrap a wrapped optimizer.
+    This method should be used before saving/loading it to/from sharded checkpoints.
+    '''
+
+    # TODO(Baizhou): ColossalaiOptimizer will be replaced with OptimizerWrapper in the future
+    unwrapped_optim = optimizer.optim
+    if isinstance(unwrapped_optim, ColossalaiOptimizer):
+        unwrapped_optim = unwrapped_optim.optim
+    return unwrapped_optim
+
+
 def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
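As a usage illustration (not part of the diff), unwrap_optimizer would typically be called right before serializing optimizer state, so that the plain torch.optim.Optimizer underneath the wrapper is what gets saved. The saving routine below is hypothetical and assumes unwrap_optimizer is in scope:

import torch
from colossalai.interface import OptimizerWrapper

def save_unwrapped_optimizer_state(optimizer: OptimizerWrapper, checkpoint_path: str) -> None:
    # Hypothetical example: strip the wrapper and save the inner optimizer's
    # state dict instead of the wrapper object itself.
    inner_optim = unwrap_optimizer(optimizer)
    torch.save(inner_optim.state_dict(), checkpoint_path)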
@@ -103,7 +118,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
             weight_size = calculate_tensor_size(weight)
 
             # If this weight is going to tip up over the maximal size, we split.
-            if current_block_size + weight_size > max_shard_size:
+            if current_block_size + weight_size > max_shard_size and current_block_size > 0:
                 ret_block = current_block
                 ret_block_size = current_block_size
                 current_block = {}
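The added `and current_block_size > 0` guard matters when a single weight is already larger than max_shard_size: without it, the loop would flush an empty block before placing that weight. A stripped-down sketch of the same accumulation pattern (plain sizes instead of tensors; not ColossalAI code) shows the edge case:

def split_by_size(sizes, max_shard_size):
    # Group sizes into blocks of at most max_shard_size, except that an
    # oversized single entry is allowed to form its own block.
    blocks, current, current_size = [], [], 0
    for size in sizes:
        # Flush only a non-empty block, mirroring the `current_block_size > 0` guard.
        if current_size + size > max_shard_size and current_size > 0:
            blocks.append(current)
            current, current_size = [], 0
        current.append(size)
        current_size += size
    if current:
        blocks.append(current)
    return blocks

print(split_by_size([1500, 600, 300], max_shard_size=1024))
# [[1500], [600, 300]]; dropping the `current_size > 0` check would emit an
# empty leading block before the oversized 1500-unit entry.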
@@ -140,9 +155,10 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
         isDTensor = False
         for state_tensor in state.values():
 
-            # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
+            # When state_tensor is not of Tensor class,
+            # e.g., a SGD optimizer with momentum set to 0 can have None as state
             # The calculation of tensor size should be skipped to avoid error.
-            if state_tensor is None:
+            if not isinstance(state_tensor, torch.Tensor):
                 continue
 
             # If the states are stored as DTensors, mark isDTensor as true.
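The switch from checking `state_tensor is None` to `isinstance(state_tensor, torch.Tensor)` skips any non-tensor state value, not only None. A small sketch of that filtering, using a hand-built state dict rather than a real optimizer's:

import torch

def tensor_states_only(state: dict) -> dict:
    # Mirror the patched check: keep only values that are actual tensors,
    # skipping None and any other non-tensor bookkeeping entries.
    return {k: v for k, v in state.items() if isinstance(v, torch.Tensor)}

example_state = {"momentum_buffer": None, "exp_avg": torch.zeros(4)}
print(list(tensor_states_only(example_state)))  # ['exp_avg']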
@@ -152,7 +168,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
 
         if not isDTensor:
 
-            if current_block_size + state_size > max_shard_size:
+            if current_block_size + state_size > max_shard_size and current_block_size > 0:
                 ret_block = current_block
                 ret_block_size = current_block_size
                 current_block = {}