[checkpointio]support asyncio for 3d (#6152)

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-23 10:24:22 +08:00
committed by GitHub
parent aaafb38851
commit 130229fdcb
17 changed files with 776 additions and 188 deletions

View File

@@ -17,6 +17,8 @@ from torch.utils.data.distributed import DistributedSampler
from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
async_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
@@ -28,6 +30,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils.safetensors import load_flat
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
@@ -82,7 +85,15 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
for k, v in state_dict.items():
self.pinned_state_dicts[id(model)][k].copy_(v)
state_dict[k] = self.pinned_state_dicts[id(model)][k]
writer = save(checkpoint, state_dict)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
@@ -106,7 +117,19 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
writer = save(checkpoint, flatten_state_dict, metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
"""
@@ -137,17 +160,29 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
if use_async and self.coordinator.is_master():
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = model.state_dict_shard(
max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
if use_async:
super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
@@ -158,17 +193,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
use_safetensors=use_safetensors,
)
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.",
ranks=[0],
)
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.",
ranks=[0],
)
def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
@@ -201,7 +236,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)
@@ -212,17 +247,36 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
torch.save(param_groups, group_file_path)
# States are broken into shards within max_shard_size.
state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
if use_async and self.coordinator.is_master():
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict_shard = optimizer.state_shard(
prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
)
# Save shards of optimizer states.
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
use_safetensors=False,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
use_safetensors=False,
)
# Wrap up index file. Only save it on master rank.
if self.coordinator.is_master():
@@ -264,7 +318,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Load optimizer states from shard files under checkpoint path.
# For each file, only load the states managed by current process.
for shard_file in checkpoint_files:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if shard_file.endswith(".safetensors"):
state_dict_shard = load_flat(shard_file)
else:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
optimizer.load_param_states(state_dict_shard)
del state_dict_shard
gc.collect()