Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 13:00:52 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format

@@ -11,7 +11,7 @@ from colossalai.legacy.core import global_context as gpc
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
-    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
 
 from .common import is_using_pp
 
@@ -25,10 +25,9 @@ def broadcast_state_dict(state_dict, parallel_mode):
     return state_dict[0]
 
 
-def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
-                                         parallel_mode: ParallelMode,
-                                         dims: dict = dict(),
-                                         partition_states: dict = dict()):
+def partition_tensor_parallel_state_dict(
+    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
+):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
     group = gpc.get_cpu_group(parallel_mode)
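
Note: for readers scanning the reformatted signature above, here is a minimal usage sketch of partition_tensor_parallel_state_dict. The import paths, the choice of ParallelMode.PARALLEL_1D, and the example keys, shapes, and dims are assumptions for illustration, not part of this commit, and the call requires an already-initialized ColossalAI legacy distributed context:

    import torch
    from collections import OrderedDict

    # Assumed import paths for the legacy module being reformatted in this diff.
    from colossalai.legacy.context import ParallelMode
    from colossalai.legacy.utils.checkpointing import partition_tensor_parallel_state_dict

    # Hypothetical call: keys, shapes, and split dimensions below are illustrative only.
    full_state = OrderedDict(weight=torch.randn(8, 8), bias=torch.zeros(8))
    local_state = partition_tensor_parallel_state_dict(
        full_state,
        ParallelMode.PARALLEL_1D,
        dims={"weight": 0, "bias": 0},                    # split dimension per key
        partition_states={"weight": True, "bias": True},  # which keys are actually sharded
    )
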
@@ -65,11 +64,11 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
 
 
 def gather_tensor_parallel_state_dict(
-        state_dict: OrderedDict,
-        parallel_mode: ParallelMode,
-        dims: dict = dict(),
-        partition_states: dict = dict(),
-        keep_vars: bool = False,
+    state_dict: OrderedDict,
+    parallel_mode: ParallelMode,
+    dims: dict = dict(),
+    partition_states: dict = dict(),
+    keep_vars: bool = False,
 ):
     dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
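
Note: gather_tensor_parallel_state_dict is the inverse of the partition call above. A hedged sketch of how the two might be paired when reassembling a full state dict before saving, under the same assumptions (imports and the argument values carry over from the previous sketch):

    # Reassemble the full tensors on the group's first rank before writing to disk.
    gathered_state = gather_tensor_parallel_state_dict(
        local_state,
        ParallelMode.PARALLEL_1D,
        dims={"weight": 0, "bias": 0},
        partition_states={"weight": True, "bias": True},
        keep_vars=False,  # illustrative; mirrors the keep_vars flag shown in the signature
    )
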
@@ -138,8 +137,11 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
 
 
 def gather_pipeline_parallel_state_dict(state_dict):
-    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
-                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
+    gathered_states = (
+        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
+        else None
+    )
     dist.gather_object(
         state_dict,
         gathered_states,
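
Note: the pipeline gather above follows the standard torch.distributed object-gather pattern, where only the destination rank allocates the placeholder list. A self-contained toy sketch of that pattern, independent of ColossalAI and assuming the default process group is already initialized:

    import torch.distributed as dist

    # Every rank contributes one picklable object; only dst rank 0 receives the list.
    world_size = dist.get_world_size()
    gathered = [None for _ in range(world_size)] if dist.get_rank() == 0 else None
    dist.gather_object({"rank": dist.get_rank()}, gathered, dst=0)
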
@@ -147,18 +149,23 @@ def gather_pipeline_parallel_state_dict(state_dict):
         group=gpc.get_cpu_group(ParallelMode.PIPELINE),
     )
 
-    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
-                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
+    state_dict = (
+        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
+        else OrderedDict()
+    )
 
     return state_dict
 
 
-def save_checkpoint(file,
-                    epoch: int,
-                    model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
-                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-                    **kwargs):
+def save_checkpoint(
+    file,
+    epoch: int,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer = None,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    **kwargs,
+):
     """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
     lr_scheduler etc. into a checkpoint dictionary.
 
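
Note: the merge that follows the pipeline gather simply concatenates the per-stage dictionaries in pipeline order. A runnable toy version of that pattern with made-up values standing in for what rank 0 would receive from dist.gather_object:

    from collections import OrderedDict
    from itertools import chain

    # Stand-ins for the per-stage state dicts collected on the first pipeline rank.
    gathered_states = [OrderedDict(layer0=1), OrderedDict(layer1=2)]
    merged = OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
    print(merged)  # OrderedDict([('layer0', 1), ('layer1', 2)])
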
@@ -196,8 +203,11 @@ def broadcast_model(model: torch.nn.Module):
     src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
     for p in model.parameters():
         if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
-            group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
-                ParallelMode.TENSOR)
+            group = (
+                gpc.get_group(ParallelMode.TENSOR)
+                if p.device.type == "cuda"
+                else gpc.get_cpu_group(ParallelMode.TENSOR)
+            )
             dist.broadcast(p, src_rank, group=group)
 
 
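
Note: the conditional above routes CUDA parameters through the tensor-parallel device group and CPU parameters through its CPU counterpart, since a backend can only carry tensors on the devices it supports. A standalone sketch of that device-dependent choice, with hypothetical group arguments rather than ColossalAI's gpc:

    import torch
    import torch.distributed as dist

    def broadcast_param(p: torch.Tensor, src_rank: int, cuda_group, cpu_group) -> None:
        # Pick the process group whose backend matches this tensor's device.
        group = cuda_group if p.device.type == "cuda" else cpu_group
        dist.broadcast(p, src_rank, group=group)
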
@@ -226,8 +236,9 @@ def load_checkpoint(
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (torch.load(file, map_location=torch.device("cpu"))
-                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
+    state_dict = (
+        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
+    )
 
     # model states
     model_state = state_dict.pop("model") if state_dict is not None else dict()
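
Note: for reference, a hypothetical save/load round trip with the two helpers touched in this file. The file name, model, and optimizer are placeholders, load_checkpoint's positional arguments are assumed from its usual legacy signature rather than this diff, and both calls need an initialized legacy parallel context:

    import torch

    # Assumed module path, as in the earlier sketches.
    from colossalai.legacy.utils.checkpointing import save_checkpoint, load_checkpoint

    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Save all components into one checkpoint dictionary on disk.
    save_checkpoint("checkpoint.pt", epoch=10, model=model, optimizer=optimizer)

    # Later (or in another run), restore them in place.
    load_checkpoint("checkpoint.pt", model, optimizer=optimizer)
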
@@ -246,8 +257,11 @@ def load_checkpoint(
             dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
             if gpc.get_global_rank() == 0:
                 all_error_msgs = list(chain.from_iterable(all_error_msgs))
-                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
-                    model.__class__.__name__, "\n\t".join(all_error_msgs)))
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        model.__class__.__name__, "\n\t".join(all_error_msgs)
+                    )
+                )
             else:
                 raise e
 
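
Note: the reformatted raise at the end reproduces load_state_dict's usual error format after pooling messages from every rank. A toy, single-process illustration of that aggregation, using made-up messages for illustration only:

    from itertools import chain

    # Per-rank lists of messages, as rank 0 would see them after dist.gather_object.
    all_error_msgs = [["size mismatch for fc.weight"], ["missing key(s) in state_dict: fc.bias"]]
    flat = list(chain.from_iterable(all_error_msgs))
    message = "Error(s) in loading state_dict for {}:\n\t{}".format("MyModel", "\n\t".join(flat))
    print(message)
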