mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 21:49:08 +00:00
fix
This commit is contained in:
parent
603d06ad56
commit
0009a4dbbe
@ -9,6 +9,7 @@ from typing import Dict, Iterator, Optional, OrderedDict, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from tensornvme.async_file_io import AsyncFileWriter
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
@ -22,6 +23,7 @@ from colossalai.tensor.padded_tensor import (
|
|||||||
to_unpadded_tensor,
|
to_unpadded_tensor,
|
||||||
)
|
)
|
||||||
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
|
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
|
||||||
|
from colossalai.utils.safetensors import move_and_save
|
||||||
|
|
||||||
from .general_checkpoint_io import GeneralCheckpointIO
|
from .general_checkpoint_io import GeneralCheckpointIO
|
||||||
from .index_file import CheckpointIndexFile
|
from .index_file import CheckpointIndexFile
|
||||||
@ -199,7 +201,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||||
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
|
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||||
model._force_wait_all_gather()
|
model._force_wait_all_gather()
|
||||||
model = model.unwrap()
|
model = model.unwrap()
|
||||||
@ -224,7 +225,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
if self.pp_size == 1:
|
if self.pp_size == 1:
|
||||||
# When pipeline is not used, save the model shards as in general checkpointIO
|
# When pipeline is not used, save the model shards as in general checkpointIO
|
||||||
if use_async:
|
if use_async:
|
||||||
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
|
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
|
||||||
|
total_size, pinned_state_dict, writers = async_save_state_dict_shards(
|
||||||
|
sharded_state_dict=state_dict_shard,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
index_file=index_file,
|
||||||
|
base_filename=weights_name,
|
||||||
|
is_master=control_saving,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
n_write_entries=self.N_WRITE_ENTRIES,
|
||||||
|
)
|
||||||
|
self.pinned_state_dicts[id(model)] = pinned_state_dict
|
||||||
|
self.async_writers.extend(writers)
|
||||||
else:
|
else:
|
||||||
total_size = save_state_dict_shards(
|
total_size = save_state_dict_shards(
|
||||||
sharded_state_dict=state_dict_shard,
|
sharded_state_dict=state_dict_shard,
|
||||||
@ -234,16 +246,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
is_master=control_saving,
|
is_master=control_saving,
|
||||||
use_safetensors=use_safetensors,
|
use_safetensors=use_safetensors,
|
||||||
)
|
)
|
||||||
if control_saving:
|
if control_saving:
|
||||||
index_file.append_meta_data("total_size", total_size)
|
index_file.append_meta_data("total_size", total_size)
|
||||||
index_file.write_index_file(save_index_file)
|
index_file.write_index_file(save_index_file)
|
||||||
save_config_file(model, checkpoint)
|
save_config_file(model, checkpoint)
|
||||||
if self.verbose and self.coordinator.is_master():
|
if self.verbose and self.coordinator.is_master():
|
||||||
logging.info(
|
logging.info(
|
||||||
f"The model is split into checkpoint shards. "
|
f"The model is split into checkpoint shards. "
|
||||||
f"You can find where each parameters has been saved in the "
|
f"You can find where each parameters has been saved in the "
|
||||||
f"index located at {save_index_file}."
|
f"index located at {save_index_file}."
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# When pipeline is used, each stage produces its own shard files and index files.
|
# When pipeline is used, each stage produces its own shard files and index files.
|
||||||
@ -259,24 +271,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||||
if use_async:
|
if use_async:
|
||||||
total_size, returned_state_dict, writers = async_save_state_dict_shards(
|
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
|
||||||
|
total_size, pinned_state_dict, writers = async_save_state_dict_shards(
|
||||||
sharded_state_dict=state_dict_shard,
|
sharded_state_dict=state_dict_shard,
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
index_file=index_file,
|
index_file=index_file,
|
||||||
base_filename=weights_name,
|
base_filename=weights_name,
|
||||||
is_master=control_saving,
|
is_master=control_saving,
|
||||||
use_pp_format=True,
|
use_safetensors=use_safetensors,
|
||||||
n_write_entries=191,
|
n_write_entries=self.N_WRITE_ENTRIES,
|
||||||
|
)
|
||||||
|
self.pinned_state_dicts[id(model)] = pinned_state_dict
|
||||||
|
self.async_writers.extend(writers)
|
||||||
|
else:
|
||||||
|
total_size = save_state_dict_shards(
|
||||||
|
sharded_state_dict=state_dict_shard,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
index_file=index_file,
|
||||||
|
base_filename=weights_name,
|
||||||
|
is_master=control_saving,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
use_pp_format=True,
|
||||||
)
|
)
|
||||||
total_size = save_state_dict_shards(
|
|
||||||
sharded_state_dict=state_dict_shard,
|
|
||||||
checkpoint=checkpoint,
|
|
||||||
index_file=index_file,
|
|
||||||
base_filename=weights_name,
|
|
||||||
is_master=control_saving,
|
|
||||||
use_safetensors=use_safetensors,
|
|
||||||
use_pp_format=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if control_saving:
|
if control_saving:
|
||||||
assert (
|
assert (
|
||||||
@ -664,7 +680,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
model = model.unwrap()
|
model = model.unwrap()
|
||||||
if self.dp_rank != 0:
|
if self.dp_rank != 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# The logic of collecting parameter shards along tp degree
|
# The logic of collecting parameter shards along tp degree
|
||||||
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
|
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
@ -686,15 +701,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
for _state_dict in state_dict_list:
|
for _state_dict in state_dict_list:
|
||||||
complete_state_dict.update(_state_dict)
|
complete_state_dict.update(_state_dict)
|
||||||
if use_async:
|
if use_async:
|
||||||
from tensornvme.async_file_io import AsyncFileWriter
|
|
||||||
|
|
||||||
from colossalai.utils.safetensors import move_and_save
|
|
||||||
|
|
||||||
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
|
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
|
||||||
if id(model) not in self.pinned_state_dicts:
|
if id(model) not in self.pinned_state_dicts:
|
||||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
|
||||||
self.async_writers.append(writer)
|
self.async_writers.append(writer)
|
||||||
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
|
move_and_save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
|
||||||
else:
|
else:
|
||||||
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ if Version(torch.__version__) < Version("2.0.0"):
|
|||||||
else:
|
else:
|
||||||
TEST_CONFIGS = [
|
TEST_CONFIGS = [
|
||||||
# TODO(ver217): other configs lead to hang
|
# TODO(ver217): other configs lead to hang
|
||||||
|
{"tp_size": 1, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
|
||||||
{"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
|
{"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -42,8 +43,9 @@ else:
|
|||||||
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
|
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
|
||||||
@parameterize("size_per_shard", [32])
|
@parameterize("size_per_shard", [32])
|
||||||
@parameterize("test_config", TEST_CONFIGS)
|
@parameterize("test_config", TEST_CONFIGS)
|
||||||
|
@parameterize("use_async", [True, False])
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
|
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
|
||||||
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
|
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
|
||||||
iter(model_zoo.get_sub_registry(model_name).values())
|
iter(model_zoo.get_sub_registry(model_name).values())
|
||||||
)
|
)
|
||||||
@ -85,8 +87,14 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
|
|||||||
with shared_tempdir() as tempdir:
|
with shared_tempdir() as tempdir:
|
||||||
model_ckpt_path = f"{tempdir}/model"
|
model_ckpt_path = f"{tempdir}/model"
|
||||||
optimizer_ckpt_path = f"{tempdir}/optimizer"
|
optimizer_ckpt_path = f"{tempdir}/optimizer"
|
||||||
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
|
if not use_async:
|
||||||
|
model_ckpt_path = f"{model_ckpt_path}.pt"
|
||||||
|
if use_async:
|
||||||
|
model_ckpt_path = f"{model_ckpt_path}.safetensors"
|
||||||
|
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
|
||||||
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
|
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
|
||||||
|
booster.checkpoint_io._sync_d2h()
|
||||||
|
booster.checkpoint_io._sync_io()
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
new_model = model_fn().cuda()
|
new_model = model_fn().cuda()
|
||||||
|
Loading…
Reference in New Issue
Block a user