mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -35,9 +34,9 @@ from .utils import (
|
||||
)
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||||
|
||||
|
||||
class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
@@ -52,12 +51,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dp_group = dp_group
|
||||
self.pp_group = pp_group
|
||||
@@ -68,17 +69,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
self.dp_size = dist.get_world_size(dp_group)
|
||||
self.pp_size = dist.get_world_size(pp_group)
|
||||
self.tp_size = dist.get_world_size(tp_group)
|
||||
self.use_zero = (zero_stage > 0)
|
||||
self.use_zero = zero_stage > 0
|
||||
self.verbose = verbose
|
||||
self.working_to_master_map = None
|
||||
self.master_to_working_map = None
|
||||
self.coordinator = DistCoordinator()
|
||||
|
||||
@staticmethod
|
||||
def _model_sharder(model: nn.Module,
|
||||
prefix: str = '',
|
||||
keep_vars: bool = False,
|
||||
size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
def _model_sharder(
|
||||
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
# An internel method that breaks state_dict of model into shards within limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
@@ -103,8 +103,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
# Save extra states.
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(model.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
if (
|
||||
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
extra_state = model.get_extra_state()
|
||||
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
@@ -114,20 +116,20 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
@staticmethod
|
||||
def _optimizer_sharder(optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
|
||||
size_per_shard: int = 1024):
|
||||
|
||||
def _optimizer_sharder(
|
||||
optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
# An internel method that breaks state_dict of optimizer into shards within limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
@@ -136,15 +138,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info['param2id'][id(working_param)]
|
||||
original_shape = param_info['param2shape'][id(working_param)]
|
||||
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False)
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
original_shape = param_info["param2shape"][id(working_param)]
|
||||
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
@@ -153,13 +157,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_model(self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False) -> None:
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Save sharded model checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
@@ -194,24 +200,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = (self.tp_rank == 0)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, save the model shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint)
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
@@ -228,15 +238,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
use_pp_format=True)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
use_pp_format=True,
|
||||
)
|
||||
if control_saving:
|
||||
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
|
||||
assert (
|
||||
self.dp_rank == 0 and self.tp_rank == 0
|
||||
), "The saving process should have both dp_rank and tp_rank as 0."
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
@@ -259,9 +273,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
save_config_file(model, checkpoint)
|
||||
rmtree(tmp_index_file_folder)
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}.")
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
|
||||
"""
|
||||
@@ -305,11 +321,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||
missing_keys = []
|
||||
|
||||
load_state_dict_into_model(model,
|
||||
state_dict,
|
||||
missing_keys=missing_keys,
|
||||
strict=strict,
|
||||
load_sub_module=True)
|
||||
load_state_dict_into_model(
|
||||
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
|
||||
)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Load parameters.
|
||||
@@ -319,15 +333,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
# Load buffers.
|
||||
non_persistent_buffers = set()
|
||||
for n, m in model.named_modules():
|
||||
non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
|
||||
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
|
||||
for name, buf in model.named_buffers():
|
||||
if buf is not None and name not in non_persistent_buffers:
|
||||
_load(name)
|
||||
|
||||
# Load extra states.
|
||||
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(model.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
if (
|
||||
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
_load(extra_state_key)
|
||||
|
||||
# Update master params if mixed-precision training is enabled.
|
||||
@@ -352,12 +368,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
if self.verbose:
|
||||
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def save_sharded_optimizer(self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024):
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
@@ -393,18 +411,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
master_to_working_map=self.master_to_working_map,
|
||||
size_per_shard=size_per_shard)
|
||||
size_per_shard=size_per_shard,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
|
||||
control_saving = self.dp_rank == 0 and self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
# Store param groups.
|
||||
@@ -415,9 +436,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
if self.verbose:
|
||||
logging.info(f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
@@ -433,15 +456,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
|
||||
assert (
|
||||
self.dp_rank == 0 and self.tp_rank == 0
|
||||
), "The saving process should have both dp_rank and tp_rank as 0."
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
@@ -451,7 +478,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
# The global master rank integrates the index files and clean the folder.
|
||||
if self.pp_rank == 0:
|
||||
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
@@ -470,9 +496,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}.")
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
"""
|
||||
@@ -484,20 +512,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
prefix (str): Not used.
|
||||
"""
|
||||
|
||||
def _get_param_id_from_optimizer_param(param: torch.Tensor,
|
||||
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info['param2id'][id(working_param)]
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
||||
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
|
||||
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
|
||||
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
|
||||
id_map = {}
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg['params']:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
|
||||
id_map[param_id] = param
|
||||
|
||||
@@ -505,28 +534,30 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory.')
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory."
|
||||
)
|
||||
saved_groups = torch.load(param_group_path)
|
||||
|
||||
updated_groups = []
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
# obtain updated param group
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change.
|
||||
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({'param_groups': updated_groups})
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg['params']:
|
||||
for param in pg["params"]:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
|
||||
@@ -550,12 +581,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
working_param = self.master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info['param2shape'][id(working_param)]
|
||||
sharded_state = self.shard_from_complete_optimizer_state(state,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True)
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.shard_from_complete_optimizer_state(
|
||||
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
@@ -585,8 +614,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
if self.coordinator.is_master():
|
||||
super().save_lr_scheduler(lr_scheduler, checkpoint)
|
||||
|
||||
def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
|
||||
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
|
||||
def link_master_and_working_param(
|
||||
self,
|
||||
working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
|
||||
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
|
||||
):
|
||||
"""
|
||||
Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
|
||||
This mapping can only be created when mixied precision is used.
|
||||
@@ -604,7 +636,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
self.working_to_master_map[k] = v
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
|
||||
)
|
||||
|
||||
self.master_to_working_map = dict()
|
||||
for k, v in master_to_working_map.items():
|
||||
@@ -614,12 +647,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
self.master_to_working_map[k] = v
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
|
||||
dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
|
||||
inplace: bool) -> OrderedDict:
|
||||
def gather_from_sharded_optimizer_state(
|
||||
state: OrderedDict,
|
||||
param: torch.Tensor,
|
||||
original_shape: torch.Size,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
use_zero: bool,
|
||||
inplace: bool,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
||||
@@ -641,14 +681,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# First gather Zero shards.
|
||||
if use_zero:
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
|
||||
# Then gather TP shards.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
|
||||
@@ -661,9 +700,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
return state_
|
||||
|
||||
def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
|
||||
original_shape: torch.Size, device: torch.device,
|
||||
inplace: bool) -> OrderedDict:
|
||||
def shard_from_complete_optimizer_state(
|
||||
self,
|
||||
state: OrderedDict,
|
||||
current_shape: torch.Size,
|
||||
original_shape: torch.Size,
|
||||
device: torch.device,
|
||||
inplace: bool,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With complete optimizer states of a specific parameter loaded from checkpoint,
|
||||
slice out the sharded optimizer states kept by current device.
|
||||
@@ -681,8 +725,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# Shard state along tensor parallel group.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
|
||||
if partition_dim is not None:
|
||||
|
Reference in New Issue
Block a user