mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
[checkpointio]support asyncio for 3d (#6152)
* fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -26,9 +26,11 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
|
||||
from colossalai.checkpoint_io.utils import async_save_state_dict_shards, create_pinned_state_dict
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.safetensors import load_flat
|
||||
|
||||
from .dp_plugin_base import DPPluginBase
|
||||
|
||||
@@ -49,8 +51,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
checkpoint = utils.load_state_dict(checkpoint)
|
||||
if checkpoint.endswith(".safetensors"):
|
||||
checkpoint = load_flat(checkpoint, seperator=".")
|
||||
else:
|
||||
checkpoint = utils.load_state_dict(checkpoint)
|
||||
|
||||
fsdp_model = optimizer.unwrap_model()
|
||||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
|
||||
start_index = 0
|
||||
id2name = {}
|
||||
|
||||
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
nonlocal start_index
|
||||
start_num = len(id2name)
|
||||
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
|
||||
end_num = len(id2name)
|
||||
start_index += end_num - start_num
|
||||
|
||||
for g in full_optimizer_state["param_groups"]:
|
||||
get_index_mapping(g)
|
||||
|
||||
new_state = {}
|
||||
for key, value in checkpoint["state"].items():
|
||||
new_state[id2name[int(key)]] = value
|
||||
checkpoint["state"] = new_state
|
||||
for g in checkpoint["param_groups"]:
|
||||
new_group = []
|
||||
for param_id in g["params"]:
|
||||
new_group.append(id2name[param_id])
|
||||
g["params"] = new_group
|
||||
|
||||
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
|
||||
optimizer.load_state_dict(sharded_osd)
|
||||
|
||||
@@ -65,7 +95,21 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
|
||||
full_model_state = model.state_dict()
|
||||
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
|
||||
if self.coordinator.is_master():
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
|
||||
for k, v in full_model_state.items():
|
||||
self.pinned_state_dicts[id(model)][k].copy_(v)
|
||||
full_model_state[k] = self.pinned_state_dicts[id(model)][k]
|
||||
writer = save(checkpoint, full_model_state)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
utils.save_state_dict(
|
||||
full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors
|
||||
)
|
||||
|
||||
def save_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
|
||||
@@ -75,8 +119,43 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
"""
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
fsdp_model = optimizer.unwrap_model()
|
||||
|
||||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
|
||||
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
|
||||
|
||||
if self.coordinator.is_master():
|
||||
|
||||
# Save order indices instead of Tensors
|
||||
name2id: Dict[str, int] = {}
|
||||
start_index = 0
|
||||
|
||||
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
nonlocal start_index
|
||||
packed = {k: v for k, v in group.items() if k != "params"}
|
||||
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
|
||||
packed["params"] = [name2id[p] for p in group["params"]]
|
||||
start_index += len(packed["params"])
|
||||
return packed
|
||||
|
||||
param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
|
||||
full_optimizer_state["param_groups"] = param_groups
|
||||
new_state = {}
|
||||
for key, value in full_optimizer_state["state"].items():
|
||||
new_state[name2id[key]] = value
|
||||
full_optimizer_state["state"] = new_state
|
||||
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
|
||||
|
||||
flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
|
||||
for k, v in flatten_state_dict.items():
|
||||
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
|
||||
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
|
||||
writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
@@ -102,20 +181,38 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
):
|
||||
state_dict = model.unwrap().state_dict()
|
||||
|
||||
state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
|
||||
if use_async and self.coordinator.is_master():
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(model)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
state_dict_shard = utils.shard_model_checkpoint(
|
||||
state_dict, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
|
||||
)
|
||||
|
||||
weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
total_size = utils.save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
if use_async:
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
)
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = utils.save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
|
||||
# only save the index file on the master rank
|
||||
if self.coordinator.is_master():
|
||||
@@ -188,26 +285,66 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
)
|
||||
|
||||
if self.coordinator.is_master():
|
||||
|
||||
# Save order indices instead of Tensors
|
||||
name2id: Dict[str, int] = {}
|
||||
start_index = 0
|
||||
|
||||
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
nonlocal start_index
|
||||
packed = {k: v for k, v in group.items() if k != "params"}
|
||||
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
|
||||
packed["params"] = [name2id[p] for p in group["params"]]
|
||||
start_index += len(packed["params"])
|
||||
return packed
|
||||
|
||||
param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
|
||||
fsdp_optim_state["param_groups"] = param_groups
|
||||
new_state = {}
|
||||
for key, value in fsdp_optim_state["state"].items():
|
||||
new_state[name2id[key]] = value
|
||||
fsdp_optim_state["state"] = new_state
|
||||
|
||||
# Preparing file paths and index file.
|
||||
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
|
||||
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(
|
||||
prefix, use_safetensors=use_async
|
||||
)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
utils.save_param_groups(fsdp_optim_state, group_file_path)
|
||||
|
||||
sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
|
||||
|
||||
if use_async:
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
sharded_state = utils.shard_optimizer_checkpoint(
|
||||
fsdp_optim_state, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
|
||||
)
|
||||
# Save shards of optimizer states.
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
total_size = utils.save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
use_safetensors=False,
|
||||
)
|
||||
if use_async:
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
state_preprocess=True,
|
||||
)
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = utils.save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
use_safetensors=False,
|
||||
)
|
||||
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
@@ -239,11 +376,39 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
fsdp_optim_state = {}
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
for shard_file in checkpoint_files:
|
||||
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
if shard_file.endswith(".safetensors"):
|
||||
state_dict_shard = load_flat(shard_file, seperator=".")
|
||||
else:
|
||||
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
fsdp_optim_state.update(state_dict_shard)
|
||||
|
||||
fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
|
||||
|
||||
fsdp_model = optimizer.unwrap_model()
|
||||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
|
||||
start_index = 0
|
||||
id2name = {}
|
||||
|
||||
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
nonlocal start_index
|
||||
start_num = len(id2name)
|
||||
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
|
||||
end_num = len(id2name)
|
||||
start_index += end_num - start_num
|
||||
|
||||
for g in full_optimizer_state["param_groups"]:
|
||||
get_index_mapping(g)
|
||||
|
||||
new_state = {}
|
||||
for key, value in fsdp_optim_dict["state"].items():
|
||||
new_state[id2name[int(key)]] = value
|
||||
fsdp_optim_dict["state"] = new_state
|
||||
for g in fsdp_optim_dict["param_groups"]:
|
||||
new_group = []
|
||||
for param_id in g["params"]:
|
||||
new_group.append(id2name[param_id])
|
||||
g["params"] = new_group
|
||||
|
||||
with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
|
||||
fsdp_state = FSDP.optim_state_dict_to_load(
|
||||
model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
|
||||
|
Reference in New Issue
Block a user