Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-21 01:24:04 +00:00
[checkpointio] support asyncio for 3d (#6152)
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update utils.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
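In outline, this commit gives GeminiCheckpointIO a real asynchronous save path. Instead of delegating `use_async` saves to `GeneralCheckpointIO`, it now stages model and optimizer state into cached pinned host buffers, hands those buffers to background writers collected in `self.async_writers`, flattens optimizer state dicts so they can be written as safetensors, and on load dispatches on the `.safetensors` extension so both new and legacy shard files remain loadable.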
@@ -17,6 +17,8 @@ from torch.utils.data.distributed import DistributedSampler
 from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
+    async_save_state_dict_shards,
+    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     load_shard_state_dict,
@@ -28,6 +30,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.utils.safetensors import load_flat
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
 
@@ -82,7 +85,15 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             if use_async:
-                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+                from colossalai.utils.safetensors import save
+
+                if id(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                for k, v in state_dict.items():
+                    self.pinned_state_dicts[id(model)][k].copy_(v)
+                    state_dict[k] = self.pinned_state_dicts[id(model)][k]
+                writer = save(checkpoint, state_dict)
+                self.async_writers.append(writer)
             else:
                 save_state_dict(state_dict, checkpoint, use_safetensors)
 
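The hunk above is the core async trick: every tensor is staged into a pinned CPU buffer that is allocated once per model (keyed by `id(model)`) and reused across saves, and the safetensors `save` call returns a writer that is queued on `self.async_writers` rather than blocking. A minimal self-contained sketch of the same pattern, assuming a CUDA build of PyTorch; `async_save`, its thread pool, and the use of `torch.save` as the on-disk format are illustrative stand-ins, not ColossalAI's implementation:

```python
from concurrent.futures import Future, ThreadPoolExecutor

import torch

_executor = ThreadPoolExecutor(max_workers=1)
_pinned_cache: dict[int, dict[str, torch.Tensor]] = {}


def async_save(obj_id: int, state_dict: dict, path: str) -> Future:
    # Allocate page-locked (pinned) CPU buffers once per object and reuse them
    # on every save; pinned memory lets the device-to-host copy run
    # asynchronously with respect to the compute stream.
    if obj_id not in _pinned_cache:
        _pinned_cache[obj_id] = {
            k: torch.empty(v.shape, dtype=v.dtype, pin_memory=True)
            for k, v in state_dict.items()
        }
    pinned = _pinned_cache[obj_id]
    for k, v in state_dict.items():
        pinned[k].copy_(v, non_blocking=True)
    torch.cuda.synchronize()  # make sure all copies have landed before writing
    # The file write runs on a background thread; the caller keeps the
    # returned future (the "writer") and waits on it before shutdown.
    return _executor.submit(torch.save, dict(pinned), path)


# Usage: writer = async_save(id(model), model.state_dict(), "model.pt")
# ...training continues while the file is written...
# writer.result()  # drain before exiting or reusing the buffers
```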
@@ -106,7 +117,19 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
         state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from colossalai.utils.safetensors import _flatten_optim_state_dict, save
+
+                flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
+                if id(optimizer) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
+                for k, v in flatten_state_dict.items():
+                    self.pinned_state_dicts[id(optimizer)][k].copy_(v)
+                    flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
+                writer = save(checkpoint, flatten_state_dict, metadata)
+                self.async_writers.append(writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)
 
     def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
         """
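Optimizer state needs the extra `_flatten_optim_state_dict` step because safetensors stores only a flat name-to-tensor mapping, while `optimizer.state_dict()` nests tensors under parameter ids; the returned `metadata` is what lets the loader invert the flattening. A hedged sketch of the idea (ColossalAI's helper is internal, and its exact key scheme may differ):

```python
import torch


def flatten_optim_state_dict(state_dict: dict):
    # state_dict looks like {"state": {pid: {name: tensor, ...}}, "param_groups": [...]}.
    # Simplified sketch: non-tensor state entries are dropped here; the real
    # helper also has to preserve them.
    flat = {}
    metadata = {"keys": [], "param_groups": state_dict["param_groups"]}
    for pid, per_param in state_dict["state"].items():
        for name, value in per_param.items():
            if isinstance(value, torch.Tensor):
                key = f"state.{pid}.{name}"
                flat[key] = value
                metadata["keys"].append(key)
    return flat, metadata


def unflatten_optim_state_dict(flat: dict, metadata: dict):
    # Invert the mapping using the recorded keys.
    state: dict = {}
    for key in metadata["keys"]:
        _, pid, name = key.split(".", 2)
        state.setdefault(int(pid), {})[name] = flat[key]
    return {"state": state, "param_groups": metadata["param_groups"]}
```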
@@ -137,17 +160,29 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
 
-        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
+        if use_async and self.coordinator.is_master():
+            if id(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = model.state_dict_shard(
+            max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
+        )
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint_path)
 
         # Save shards of optimizer states.
         is_master = self.coordinator.is_master()
         if use_async:
-            super().save_sharded_model(
-                model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
-            )
+            total_size, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=is_master,
+            )
+            self.async_writers.extend(writers)
         else:
             total_size = save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
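For sharded saves the pinning moves down into the shard generator: a per-model dict of pinned buffers is passed into `model.state_dict_shard`, and `async_save_state_dict_shards` returns one background writer per shard file alongside the running `total_size`. Conceptually the helper has to do something like the following; this is a sketch reconstructed from the call sites in the diff, not the library's code, with `torch.save` and a plain `index_map` dict standing in for the safetensors writer and `CheckpointIndexFile`:

```python
from concurrent.futures import ThreadPoolExecutor

import torch

_executor = ThreadPoolExecutor(max_workers=2)


def async_save_shards_sketch(sharded_state_dict, checkpoint_dir, base_filename, index_map, is_master):
    # sharded_state_dict yields (shard_dict, shard_size) pairs; return the
    # accumulated size plus one background writer (future) per shard file.
    total_size, writers = 0, []
    for idx, (shard, shard_size) in enumerate(sharded_state_dict, start=1):
        total_size += shard_size
        if not is_master:
            continue  # non-master ranks only drive the generator
        shard_file = f"{base_filename}-{idx:05d}.safetensors"
        for name in shard:
            index_map[name] = shard_file  # index entry: tensor name -> shard file
        writers.append(_executor.submit(torch.save, dict(shard), f"{checkpoint_dir}/{shard_file}"))
    return total_size, writers
```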
@@ -158,17 +193,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                 use_safetensors=use_safetensors,
             )
 
-            # only save the index file on the master rank
-            if self.coordinator.is_master():
-                index_file.append_meta_data("total_size", total_size)
-                index_file.write_index_file(save_index_file)
-                save_config_file(model.unwrap(), checkpoint_path)
-                self.logger.info(
-                    f"The model is split into checkpoint shards. "
-                    f"You can find where each parameters has been saved in the "
-                    f"index located at {save_index_file}.",
-                    ranks=[0],
-                )
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            save_config_file(model.unwrap(), checkpoint_path)
+            self.logger.info(
+                f"The model is split into checkpoint shards. "
+                f"You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}.",
+                ranks=[0],
+            )
 
     def load_sharded_model(
         self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
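Note the indentation change in this hunk: the index-file block used to live inside the `else:` branch, because the async path delegated the whole job (index file included) to `super().save_sharded_model(...)`. Now that both branches compute `total_size` locally, the block is dedented one level so the index file, config file, and log line are written for asynchronous saves too.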
@@ -201,7 +236,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
 
         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)
 
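Passing `use_safetensors=use_async` here means asynchronous optimizer saves name their state files with the `.safetensors` extension rather than the legacy `.bin`; that extension is exactly what `load_sharded_optimizer` dispatches on in the last hunk below.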
@@ -212,17 +247,36 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             torch.save(param_groups, group_file_path)
 
         # States are broken into shards within max_shard_size.
-        state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
+        if use_async and self.coordinator.is_master():
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = optimizer.state_shard(
+            prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
+        )
 
         # Save shards of optimizer states.
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=states_name,
-            is_master=self.coordinator.is_master(),
-            use_safetensors=False,
-        )
+        if use_async:
+            total_size, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                state_preprocess=True,
+            )
+            self.async_writers.extend(writers)
+        else:
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                use_safetensors=False,
+            )
 
         # Wrap up index file. Only save it on master rank.
         if self.coordinator.is_master():
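`state_preprocess=True` appears to be the sharded counterpart of `_flatten_optim_state_dict` in the unsharded path: each optimizer-state shard is presumably flattened to a plain tensor dict before being handed to a writer, since safetensors cannot store the nested per-parameter structure directly. As in the model path, the writers are only queued; they must be drained before the process exits or before the next save reuses the pinned buffers. A generic draining pattern, matching the future-based writers from the sketches above rather than ColossalAI's actual API:

```python
def flush_writers(async_writers):
    # Block until every queued background write has finished; only after
    # this is it safe to reuse the pinned buffers or shut down.
    for writer in async_writers:
        writer.result()
    async_writers.clear()
```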
@@ -264,7 +318,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # Load optimizer states from shard files under checkpoint path.
         # For each file, only load the states managed by current process.
         for shard_file in checkpoint_files:
-            state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if shard_file.endswith(".safetensors"):
+                state_dict_shard = load_flat(shard_file)
+            else:
+                state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             optimizer.load_param_states(state_dict_shard)
             del state_dict_shard
             gc.collect()
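Dispatching on the file extension keeps loading backward compatible: new async checkpoints are flattened safetensors files and go through `load_flat`, which also restores the nested optimizer-state layout, while shards written by the old synchronous path still load via `load_shard_state_dict`.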