Mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)
* Sharded optimizer checkpoint for Gemini plugin
* Modify test to reduce testing time
* Update doc
* Fix bug when keep_gathered is true under GeminiPlugin
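As a rough usage sketch of what the sharded optimizer checkpoint enables (the setup below, including HybridAdam, the checkpoint paths, and size_per_shard, is illustrative and assumes the Booster/GeminiPlugin API of this era; launch with torchrun or colossalai run):

import os

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training steps ...

os.makedirs("ckpt/optimizer", exist_ok=True)
os.makedirs("ckpt/model", exist_ok=True)

# With this commit, shard=True is supported for Gemini: optimizer states are
# written as multiple shard files plus an index file instead of one large file.
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=1024)
booster.save_model(model, "ckpt/model", shard=True)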
@@ -5,6 +5,7 @@ from functools import reduce
 from pathlib import Path
 from typing import Iterator, Optional, OrderedDict, Tuple

 import torch.distributed as dist
 import torch.nn as nn
 from torch.optim import Optimizer

@@ -16,7 +17,6 @@ from .utils import (
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    get_shard_filename,
     has_index_file,
     is_safetensors_available,
     load_param_groups_into_optimizer,
     load_shard_state_dict,
@@ -25,6 +25,7 @@ from .utils import (
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
+    save_state_dict_shards,
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,
@@ -122,15 +123,13 @@ class GeneralCheckpointIO(CheckpointIO):
         save_param_groups(state_dict, group_file_path)

         # Save shards of optimizer states.
-        total_size = 0
-        for idx, shard_pair in enumerate(sharded_state):
-            shard, current_size = shard_pair
-            shard_file = get_shard_filename(states_name, idx)
-            total_size = total_size + current_size
-            for key in shard.keys():
-                index_file.append_weight_map(key, shard_file)
-            checkpoint_file_path = os.path.join(checkpoint, shard_file)
-            save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+        # In general cases, is_master is set to True to get the right behavior.
+        total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
+                                            checkpoint=checkpoint,
+                                            index_file=index_file,
+                                            base_filename=states_name,
+                                            is_master=True,
+                                            use_safetensors=False)

         # Wrap up index file.
         index_file.append_meta_data("total_size", total_size)
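Both this hunk and the next replace a hand-written shard loop with a single call to save_state_dict_shards. As a minimal sketch only, not the actual ColossalAI utils implementation, such a helper could be assembled from the removed loop, reusing the get_shard_filename and save_state_dict helpers imported above; the is_master flag is what lets distributed plugins such as Gemini share the routine while only one rank writes to disk:

import os
from typing import Iterator, Tuple


def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[dict, int]],
                           checkpoint: str,
                           index_file,
                           base_filename: str,
                           is_master: bool,
                           use_safetensors: bool = False) -> int:
    """Hypothetical sketch: write each shard, record its keys, return the total size."""
    total_size = 0
    for idx, (shard, shard_size) in enumerate(sharded_state_dict):
        shard_file = get_shard_filename(base_filename, idx)
        total_size += shard_size
        # In this simple sketch every rank updates the index bookkeeping...
        for key in shard.keys():
            index_file.append_weight_map(key, shard_file)
        # ...but only the master rank actually writes the shard to disk.
        if is_master:
            checkpoint_file_path = os.path.join(checkpoint, shard_file)
            save_state_dict(shard, checkpoint_file_path, use_safetensors)
    return total_size

With a helper along these lines, the optimizer-state path above and the model-weight path below differ only in which shard iterator and base filename they pass.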
@@ -172,18 +171,17 @@ class GeneralCheckpointIO(CheckpointIO):
         # shard checkpoint
         state_dict = model.state_dict()
         state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)

         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
-        total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
-        for idx, shard_pair in enumerate(state_dict_shard):
-            shard = shard_pair[0]
-            shard_file = get_shard_filename(weights_name, idx)
-            total_size = total_size + shard_pair[1]
-            for key in shard.keys():
-                index_file.append_weight_map(key, shard_file)
-            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
-            save_state_dict(shard, checkpoint_file_path, use_safetensors)
+
+        # Save shards of optimizer states.
+        # In general cases, is_master is set to True to get the right behavior.
+        total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+                                            checkpoint=checkpoint_path,
+                                            index_file=index_file,
+                                            base_filename=weights_name,
+                                            is_master=True,
+                                            use_safetensors=use_safetensors)

         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
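For the plain single-process GeneralCheckpointIO path changed above, the refactored code can be exercised roughly as follows; the save_model keyword arguments follow the CheckpointIO interface as of this commit, and the output directory is illustrative:

import os

import torch
from colossalai.checkpoint_io import GeneralCheckpointIO

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512))

ckpt_io = GeneralCheckpointIO()
os.makedirs("ckpt/model", exist_ok=True)

# shard=True routes through the sharded save path, which now delegates shard
# writing and index bookkeeping to save_state_dict_shards with is_master=True.
ckpt_io.save_model(model, "ckpt/model", shard=True, size_per_shard=1024, use_safetensors=False)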