[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00 (committed by GitHub)
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -11,7 +11,7 @@ from colossalai.legacy.core import global_context as gpc
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
-    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
 
 from .common import is_using_pp
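For context on what the shim guards: on newer PyTorch (roughly 1.10+), a module that overrides get_extra_state()/set_extra_state() has that state saved in its state_dict under a key ending in this suffix, and the constant is importable from torch.nn.modules.module; older releases need the literal fallback. A minimal, self-contained sketch (not part of this diff):

import torch

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:  # older torch releases lack the constant
    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"

class WithExtraState(torch.nn.Module):
    # Extra state is stored under "<prefix><suffix>" in the state_dict.
    def get_extra_state(self):
        return {"version": 1}

    def set_extra_state(self, state):
        assert state["version"] == 1

m = WithExtraState()
assert _EXTRA_STATE_KEY_SUFFIX in next(iter(m.state_dict()))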
@@ -25,10 +25,9 @@ def broadcast_state_dict(state_dict, parallel_mode):
     return state_dict[0]
 
 
-def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
-                                         parallel_mode: ParallelMode,
-                                         dims: dict = dict(),
-                                         partition_states: dict = dict()):
+def partition_tensor_parallel_state_dict(
+    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
+):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
     group = gpc.get_cpu_group(parallel_mode)
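For orientation, a hypothetical call site for the reformatted signature. This is a sketch only: it assumes the legacy gpc context has been initialized (e.g. via colossalai.legacy.launch) with tensor parallelism active, assumes the ParallelMode import path shown, and the key names and dims/partition_states semantics are read off the signature, not confirmed by this diff.

from collections import OrderedDict

import torch

from colossalai.legacy.context import ParallelMode  # import path assumed

# Full (unsharded) states, built on the first rank of the tensor-parallel group.
full_sd = OrderedDict()
full_sd["linear.weight"] = torch.randn(1024, 1024)
full_sd["linear.bias"] = torch.randn(1024)

local_sd = partition_tensor_parallel_state_dict(
    full_sd,
    ParallelMode.TENSOR,
    dims={"linear.weight": 0, "linear.bias": 0},  # dim each value is chunked along
    partition_states={"linear.weight": True, "linear.bias": True},  # False would broadcast the value whole
)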
@@ -65,11 +64,11 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
 def gather_tensor_parallel_state_dict(
-        state_dict: OrderedDict,
-        parallel_mode: ParallelMode,
-        dims: dict = dict(),
-        partition_states: dict = dict(),
-        keep_vars: bool = False,
+    state_dict: OrderedDict,
+    parallel_mode: ParallelMode,
+    dims: dict = dict(),
+    partition_states: dict = dict(),
+    keep_vars: bool = False,
 ):
     dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
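The gather side mirrors the partition call. Continuing the hypothetical sketch above (same assumptions): each tensor-parallel rank contributes its shard and the first rank of the group receives the reassembled tensors.

full_again = gather_tensor_parallel_state_dict(
    local_sd,
    ParallelMode.TENSOR,
    dims={"linear.weight": 0, "linear.bias": 0},
    partition_states={"linear.weight": True, "linear.bias": True},
    keep_vars=False,  # detach tensors from autograd before gathering
)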
@@ -138,8 +137,11 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
 def gather_pipeline_parallel_state_dict(state_dict):
-    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
-                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
+    gathered_states = (
+        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
+        else None
+    )
     dist.gather_object(
         state_dict,
         gathered_states,
@@ -147,18 +149,23 @@ def gather_pipeline_parallel_state_dict(state_dict):
         group=gpc.get_cpu_group(ParallelMode.PIPELINE),
     )
 
-    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
-                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
+    state_dict = (
+        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
+        else OrderedDict()
+    )
 
     return state_dict
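The None-slots idiom in this function is exactly what torch.distributed.gather_object expects: only the destination rank allocates the receive list, every other rank passes None. A standalone sketch of the same pattern, assuming a process group initialized with a CPU-capable backend such as gloo (e.g. launched under torchrun):

import torch.distributed as dist

obj = {"rank": dist.get_rank()}
slots = [None] * dist.get_world_size() if dist.get_rank() == 0 else None
dist.gather_object(obj, slots, dst=0)  # non-destination ranks pass None
if dist.get_rank() == 0:
    print(slots)  # one entry per rank, ordered by rank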
-def save_checkpoint(file,
-                    epoch: int,
-                    model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
-                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-                    **kwargs):
+def save_checkpoint(
+    file,
+    epoch: int,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer = None,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    **kwargs,
+):
     """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
     lr_scheduler etc. into a checkpoint dictionary.
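A minimal usage sketch for the reformatted signature. The model, optimizer, and file name are hypothetical, and since the function gathers parallel states internally, the assumption is that in real use it is called collectively on the relevant ranks:

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

save_checkpoint("ckpt.pt", epoch=3, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)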
@@ -196,8 +203,11 @@ def broadcast_model(model: torch.nn.Module):
     src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
     for p in model.parameters():
         if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
-            group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
-                ParallelMode.TENSOR)
+            group = (
+                gpc.get_group(ParallelMode.TENSOR)
+                if p.device.type == "cuda"
+                else gpc.get_cpu_group(ParallelMode.TENSOR)
+            )
             dist.broadcast(p, src_rank, group=group)
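The device check above exists because NCCL groups only transport CUDA tensors, so CPU-resident parameters must travel through the CPU (gloo) group instead. The broadcast itself is the stock collective; a standalone sketch, assuming an initialized process group:

import torch
import torch.distributed as dist

t = torch.zeros(4)
if dist.get_rank() == 0:
    t += 42.0
dist.broadcast(t, src=0)  # in-place: every rank now holds rank 0's values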
@@ -226,8 +236,9 @@ def load_checkpoint(
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (torch.load(file, map_location=torch.device("cpu"))
-                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
+    state_dict = (
+        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
+    )
 
     # model states
     model_state = state_dict.pop("model") if state_dict is not None else dict()
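Loading on a single rank and fanning the states out keeps every other process off the filesystem. The same idiom in plain torch.distributed, as a standalone sketch (assumes an initialized gloo-capable process group and an existing ckpt.pt):

import torch
import torch.distributed as dist

state = torch.load("ckpt.pt", map_location="cpu") if dist.get_rank() == 0 else None
holder = [state]
dist.broadcast_object_list(holder, src=0)  # ships rank 0's object to all ranks
state = holder[0]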
@@ -246,8 +257,11 @@ def load_checkpoint(
             dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
             if gpc.get_global_rank() == 0:
                 all_error_msgs = list(chain.from_iterable(all_error_msgs))
-                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
-                    model.__class__.__name__, "\n\t".join(all_error_msgs)))
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        model.__class__.__name__, "\n\t".join(all_error_msgs)
+                    )
+                )
             else:
                 raise e
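This block funnels each rank's error_msgs to global rank 0 so that a failed load surfaces as one aggregated RuntimeError instead of a separate crash per process. A standalone sketch of that funnel (assumes an initialized process group; the message is fabricated for illustration):

from itertools import chain

import torch.distributed as dist

error_msgs = [f"rank {dist.get_rank()}: size mismatch for fc.weight"]
all_msgs = [None] * dist.get_world_size() if dist.get_rank() == 0 else None
dist.gather_object(error_msgs, all_msgs, dst=0)
if dist.get_rank() == 0:
    raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(chain.from_iterable(all_msgs)))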