[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
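The last bullet of the commit message says CUDA files are ignored for clang-format. A minimal sketch of how such an exclusion can be written as a hook entry in .pre-commit-config.yaml is shown below; the mirror repository, the pinned rev, and the exclude pattern are illustrative assumptions, not the configuration actually committed here:

    - repo: https://github.com/pre-commit/mirrors-clang-format
      rev: v13.0.1              # assumed pin, not taken from this commit
      hooks:
        - id: clang-format
          exclude: \.cu$        # assumption: leave CUDA sources unformatted

With the hooks configured, running pre-commit run --all-files applies every hook to the entire tree, which is how a repository-wide reformat such as the one in this commit is typically produced.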


@@ -19,18 +19,16 @@ from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
     sharded_tensor_to_param,
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-__all__ = ['ParallelModule']
+__all__ = ["ParallelModule"]
 class ParallelModule(nn.Module, ABC):
     @abstractmethod
-    def from_native_module(module: nn.Module,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
+    ) -> "ParallelModule":
         """
         Convert a native PyTorch module to a parallelized module.
@@ -40,7 +38,6 @@ class ParallelModule(nn.Module, ABC):
                 If this is a list, the process group at the ith index of the list will correspond to the process group
                 in the ith axis of the device mesh. Defaults to None, which means the global process group.
         """
-        pass
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         r"""Saves module state to `destination` dictionary, containing a state
@@ -66,8 +63,9 @@ class ParallelModule(nn.Module, ABC):
         if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
             destination[extra_state_key] = self.get_extra_state()
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                              error_msgs):
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
         r"""Copies parameters and buffers from :attr:`state_dict` into only
         this module, but not its descendants. This is called on every submodule
         in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
@@ -112,9 +110,11 @@ class ParallelModule(nn.Module, ABC):
             if key in state_dict:
                 input_param = state_dict[key]
                 if not torch.overrides.is_tensor_like(input_param):
-                    error_msgs.append('While copying the parameter named "{}", '
-                                      'expected torch.Tensor or Tensor-like object from checkpoint but '
-                                      'received {}'.format(key, type(input_param)))
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "expected torch.Tensor or Tensor-like object from checkpoint but "
+                        "received {}".format(key, type(input_param))
+                    )
                     continue
                 if is_distributed_tensor(param):
@@ -136,19 +136,22 @@ class ParallelModule(nn.Module, ABC):
                 if not is_param_lazy and input_param.shape != param.shape:
                     # local shape should match the one in checkpoint
-                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
-                                      'the shape in current model is {}.'.format(key, input_param.shape, param.shape))
+                    error_msgs.append(
+                        "size mismatch for {}: copying a param with shape {} from checkpoint, "
+                        "the shape in current model is {}.".format(key, input_param.shape, param.shape)
+                    )
                     continue
                 try:
                     with torch.no_grad():
                         param.copy_(input_param)
                 except Exception as ex:
-                    error_msgs.append('While copying the parameter named "{}", '
-                                      'whose dimensions in the model are {} and '
-                                      'whose dimensions in the checkpoint are {}, '
-                                      'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
-                                                                           ex.args))
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )
             elif strict:
                 missing_keys.append(key)
@@ -164,7 +167,7 @@ class ParallelModule(nn.Module, ABC):
         if strict:
             for key in state_dict.keys():
                 if key.startswith(prefix) and key != extra_state_key:
-                    input_name = key[len(prefix):]
-                    input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
+                    input_name = key[len(prefix) :]
+                    input_name = input_name.split(".", 1)[0]  # get the name of param/buffer/child
                     if input_name not in self._modules and input_name not in local_state:
                         unexpected_keys.append(key)