Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-22 13:41:43 +00:00
fix
commit 0009a4dbbe (parent 603d06ad56)
@@ -9,6 +9,7 @@ from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from tensornvme.async_file_io import AsyncFileWriter
from torch.distributed import ProcessGroup
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -22,6 +23,7 @@ from colossalai.tensor.padded_tensor import (
    to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import move_and_save

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
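These two imports, AsyncFileWriter (from tensornvme) and move_and_save (from colossalai.utils.safetensors), are the building blocks of the asynchronous save path used throughout this commit. The sketch below shows how they fit together, inferred from the call sites later in this diff; the file name, the pinned-buffer construction, and the final synchronize() call are illustrative assumptions rather than documented API.

import torch
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import move_and_save

# Illustrative state dict; in the real code this comes from model.state_dict().
state_dict = {"weight": torch.randn(1024, 1024)}

# Page-locked (pinned) CPU buffers make device-to-host copies asynchronous.
pinned_state_dict = {
    name: torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    for name, t in state_dict.items()
}

# n_write_entries bounds the writer's queue; 191 mirrors the constant previously
# hard-coded in this file (now replaced by self.N_WRITE_ENTRIES).
writer = AsyncFileWriter(open("model.safetensors", "wb"), 191, backend="pthread")

# Copies the tensors into the pinned buffers and queues the safetensors write;
# the call returns before the data is fully on disk.
move_and_save(writer, state_dict, pinned_state_dict)

# The checkpoint IO class keeps such writers in self.async_writers and drains them
# later (see _sync_d2h / _sync_io in the test hunk below); writer.synchronize() is
# assumed here as the direct way to wait for the write to finish.
writer.synchronize()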
@@ -199,7 +201,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
            use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
            use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
        """

        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
        model._force_wait_all_gather()
        model = model.unwrap()
@@ -224,7 +225,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
        if self.pp_size == 1:
            # When pipeline is not used, save the model shards as in general checkpointIO
            if use_async:
                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
                pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
                total_size, pinned_state_dict, writers = async_save_state_dict_shards(
                    sharded_state_dict=state_dict_shard,
                    checkpoint=checkpoint,
                    index_file=index_file,
                    base_filename=weights_name,
                    is_master=control_saving,
                    use_safetensors=use_safetensors,
                    n_write_entries=self.N_WRITE_ENTRIES,
                )
                self.pinned_state_dicts[id(model)] = pinned_state_dict
                self.async_writers.extend(writers)
            else:
                total_size = save_state_dict_shards(
                    sharded_state_dict=state_dict_shard,
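The async branch above keeps one pinned CPU copy of the model state per model instance (self.pinned_state_dicts, keyed by id(model)) and collects the returned writers in self.async_writers so they can be synchronized later. A minimal, self-contained sketch of that buffer-reuse idea in plain PyTorch (the real async_save_state_dict_shards helper manages this internally, and its (total_size, pinned_state_dict, writers) return convention is taken from the call above):

import torch

# Hypothetical illustration of the pinned-buffer cache: one page-locked CPU copy
# per model, reused on every save so repeated async checkpoints do not reallocate.
_pinned_cache: dict = {}

def stage_to_pinned(model: torch.nn.Module) -> dict:
    state_dict = model.state_dict()
    pinned = _pinned_cache.get(id(model))
    if pinned is None:
        pinned = {
            name: torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
            for name, t in state_dict.items()
        }
        _pinned_cache[id(model)] = pinned
    for name, t in state_dict.items():
        # non_blocking copies overlap with compute only because the target is pinned
        pinned[name].copy_(t, non_blocking=True)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # ensure the D2H copies finished before writing to disk
    return pinned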
@@ -234,16 +246,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                    is_master=control_saving,
                    use_safetensors=use_safetensors,
                )
                if control_saving:
                    index_file.append_meta_data("total_size", total_size)
                    index_file.write_index_file(save_index_file)
                    save_config_file(model, checkpoint)
                    if self.verbose and self.coordinator.is_master():
                        logging.info(
                            f"The model is split into checkpoint shards. "
                            f"You can find where each parameters has been saved in the "
                            f"index located at {save_index_file}."
                        )
            if control_saving:
                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
                save_config_file(model, checkpoint)
                if self.verbose and self.coordinator.is_master():
                    logging.info(
                        f"The model is split into checkpoint shards. "
                        f"You can find where each parameters has been saved in the "
                        f"index located at {save_index_file}."
                    )

        else:
            # When pipeline is used, each stage produces its own shard files and index files.
@@ -259,24 +271,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
            save_index_file = os.path.join("tmp_index_files", save_index_file)
            if use_async:
                total_size, returned_state_dict, writers = async_save_state_dict_shards(
                pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
                total_size, pinned_state_dict, writers = async_save_state_dict_shards(
                    sharded_state_dict=state_dict_shard,
                    checkpoint=checkpoint,
                    index_file=index_file,
                    base_filename=weights_name,
                    is_master=control_saving,
                    use_pp_format=True,
                    n_write_entries=191,
                    use_safetensors=use_safetensors,
                    n_write_entries=self.N_WRITE_ENTRIES,
                )
                self.pinned_state_dicts[id(model)] = pinned_state_dict
                self.async_writers.extend(writers)
            else:
                total_size = save_state_dict_shards(
                    sharded_state_dict=state_dict_shard,
                    checkpoint=checkpoint,
                    index_file=index_file,
                    base_filename=weights_name,
                    is_master=control_saving,
                    use_safetensors=use_safetensors,
                    use_pp_format=True,
                )
                total_size = save_state_dict_shards(
                    sharded_state_dict=state_dict_shard,
                    checkpoint=checkpoint,
                    index_file=index_file,
                    base_filename=weights_name,
                    is_master=control_saving,
                    use_safetensors=use_safetensors,
                    use_pp_format=True,
                )

            if control_saving:
                assert (
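In the pipeline branch above, every stage writes its own shards plus a temporary, stage-suffixed index file under tmp_index_files/, presumably merged later into the final index by the rank with control_saving. A small illustration of the naming scheme used above (the base index file name is assumed for the example):

import os

save_index_file = "model.safetensors.index.json"  # assumed base name for illustration
pp_rank = 2  # third pipeline stage, 0-based

# Same two transformations as in the hunk above.
save_index_file = save_index_file.replace(".json", f"-stage-{pp_rank + 1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)

print(save_index_file)  # tmp_index_files/model.safetensors.index-stage-00003.json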
@@ -664,7 +680,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
        model = model.unwrap()
        if self.dp_rank != 0:
            return

        # The logic of collecting parameter shards along tp degree
        # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
        state_dict = model.state_dict()
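Only dp_rank == 0 proceeds here, and tensor-parallel shards are already merged inside state_dict() by Shardformer's ParallelModule, so what remains is to combine the per-pipeline-stage dictionaries; the next hunk does that by folding state_dict_list into complete_state_dict. A minimal sketch of how such a list could be gathered across pipeline stages with torch.distributed (the actual collection code in ColossalAI is not shown in this diff and may differ):

import torch.distributed as dist

def gather_stage_state_dicts(local_state_dict: dict, pp_group) -> list:
    """Collect every pipeline stage's partial state_dict onto all ranks of pp_group.

    Illustrative only: assumes the tensors have been moved to CPU so that
    all_gather_object can pickle them cheaply.
    """
    state_dict_list = [None] * dist.get_world_size(group=pp_group)
    dist.all_gather_object(state_dict_list, local_state_dict, group=pp_group)
    return state_dict_list

# The hunk below then merges the stages:
#     complete_state_dict = {}
#     for _state_dict in state_dict_list:
#         complete_state_dict.update(_state_dict)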
@@ -686,15 +701,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
            for _state_dict in state_dict_list:
                complete_state_dict.update(_state_dict)
        if use_async:
            from tensornvme.async_file_io import AsyncFileWriter

            from colossalai.utils.safetensors import move_and_save

            writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
            if id(model) not in self.pinned_state_dicts:
                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
            self.async_writers.append(writer)
            move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
            move_and_save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
        else:
            save_state_dict(complete_state_dict, checkpoint, use_safetensors)
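Taken together with the import hunks at the top, the change here appears to be twofold: the AsyncFileWriter and move_and_save imports move from this local scope to module level, and both the pinned buffer and the asynchronous write now use complete_state_dict (the dictionary merged from all stages) instead of the stage-local state_dict, which under pipeline parallelism would contain only the current stage's parameters.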
@@ -34,6 +34,7 @@ if Version(torch.__version__) < Version("2.0.0"):
else:
    TEST_CONFIGS = [
        # TODO(ver217): other configs lead to hang
        {"tp_size": 1, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
        {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
    ]
@@ -42,8 +43,9 @@ else:
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@parameterize("use_async", [True, False])
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
    (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
        iter(model_zoo.get_sub_registry(model_name).values())
    )
@@ -85,8 +87,14 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optimizer_ckpt_path = f"{tempdir}/optimizer"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        if not use_async:
            model_ckpt_path = f"{model_ckpt_path}.pt"
        if use_async:
            model_ckpt_path = f"{model_ckpt_path}.safetensors"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        booster.checkpoint_io._sync_d2h()
        booster.checkpoint_io._sync_io()
        dist.barrier()

        new_model = model_fn().cuda()
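The test encodes the synchronization contract for asynchronous saves: after save_model(..., use_async=True), the device-to-host copies and the file writes both have to be drained before any rank reads the checkpoint back. A condensed sketch of that sequence, reusing names from the test above (booster, model, new_model, tempdir, and use_async are assumed from the surrounding test body; shard=False is used here for simplicity):

import torch.distributed as dist

model_ckpt_path = f"{tempdir}/model"
model_ckpt_path += ".safetensors" if use_async else ".pt"

booster.save_model(model, model_ckpt_path, shard=False, use_async=use_async)
booster.checkpoint_io._sync_d2h()  # wait for GPU -> pinned-host copies
booster.checkpoint_io._sync_io()   # wait for pinned-host -> disk writes
dist.barrier()                     # every rank must finish before anyone loads

booster.load_model(new_model, model_ckpt_path)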