Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -19,18 +19,16 @@ from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
     sharded_tensor_to_param,
     to_global,
     to_global_for_customized_distributed_tensor,
 )
 
-__all__ = ['ParallelModule']
+__all__ = ["ParallelModule"]
 
 
 class ParallelModule(nn.Module, ABC):
-
     @abstractmethod
-    def from_native_module(module: nn.Module,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
+    ) -> "ParallelModule":
         """
         Convert a native PyTorch module to a parallelized module.
 
@@ -40,7 +38,6 @@ class ParallelModule(nn.Module, ABC):
                 If this is a list, the process group at the ith index of the list will correspond to the process group
                 in the ith axis of the device mesh. Defaults to None, which means the global process group.
         """
-        pass
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         r"""Saves module state to `destination` dictionary, containing a state
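The from_native_module contract documented above (take an unmodified PyTorch module plus a process group, or one group per device-mesh axis, and return a drop-in parallel replacement) can be illustrated with a small, purely hypothetical subclass. The class name, constructor, and column-sharding scheme below are illustrative assumptions, not ColossalAI API:

import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup


class ColumnShardedLinear(ParallelModule):  # ParallelModule is the abstract base class shown in this diff
    """Hypothetical example: each rank keeps only its slice of a Linear layer's output dimension."""

    def __init__(self, weight: nn.Parameter, bias, process_group: ProcessGroup):
        super().__init__()
        self.weight = weight
        self.bias = bias
        self.process_group = process_group

    @staticmethod
    def from_native_module(module: nn.Linear, process_group: ProcessGroup = None) -> "ColumnShardedLinear":
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        # Shard the output dimension (dim 0 of the weight) across the process group.
        weight = nn.Parameter(module.weight.detach().chunk(world_size, dim=0)[rank].clone())
        bias = None
        if module.bias is not None:
            bias = nn.Parameter(module.bias.detach().chunk(world_size, dim=0)[rank].clone())
        return ColumnShardedLinear(weight, bias, process_group)

    def forward(self, x):
        # Each rank produces its local slice of the output features.
        return F.linear(x, self.weight, self.bias)

Per-rank usage, after dist.init_process_group(), would then be parallel_linear = ColumnShardedLinear.from_native_module(nn.Linear(1024, 4096)).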
@@ -66,8 +63,9 @@ class ParallelModule(nn.Module, ABC):
         if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
             destination[extra_state_key] = self.get_extra_state()
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                              error_msgs):
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
         r"""Copies parameters and buffers from :attr:`state_dict` into only
         this module, but not its descendants. This is called on every submodule
         in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
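The diff only touches the heads of _save_to_state_dict and _load_from_state_dict; their bodies are mostly elided here. As a rough, non-authoritative sketch of what the save side is for (assuming the d_tensor helpers imported at the top behave as their names suggest: is_distributed_tensor detects a sharded parameter and to_global gathers it back to its full shape), a parallel module would write layout-independent tensors roughly like this:

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Sketch only, not the actual ColossalAI implementation.
        for name, param in self._parameters.items():
            if param is None:
                continue
            param = param if keep_vars else param.detach()
            if is_distributed_tensor(param):
                # Gather the shards so the checkpoint holds the global tensor.
                destination[prefix + name] = to_global(param)
            elif is_customized_distributed_tensor(param):
                destination[prefix + name] = to_global_for_customized_distributed_tensor(param)
            else:
                destination[prefix + name] = param
        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf if keep_vars else buf.detach()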
@@ -112,9 +110,11 @@ class ParallelModule(nn.Module, ABC):
             if key in state_dict:
                 input_param = state_dict[key]
                 if not torch.overrides.is_tensor_like(input_param):
-                    error_msgs.append('While copying the parameter named "{}", '
-                                      'expected torch.Tensor or Tensor-like object from checkpoint but '
-                                      'received {}'.format(key, type(input_param)))
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "expected torch.Tensor or Tensor-like object from checkpoint but "
+                        "received {}".format(key, type(input_param))
+                    )
                     continue
 
                 if is_distributed_tensor(param):
@@ -136,19 +136,22 @@ class ParallelModule(nn.Module, ABC):
 
                 if not is_param_lazy and input_param.shape != param.shape:
                     # local shape should match the one in checkpoint
-                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
-                                      'the shape in current model is {}.'.format(key, input_param.shape, param.shape))
+                    error_msgs.append(
+                        "size mismatch for {}: copying a param with shape {} from checkpoint, "
+                        "the shape in current model is {}.".format(key, input_param.shape, param.shape)
+                    )
                     continue
 
                 try:
                     with torch.no_grad():
                         param.copy_(input_param)
                 except Exception as ex:
-                    error_msgs.append('While copying the parameter named "{}", '
-                                      'whose dimensions in the model are {} and '
-                                      'whose dimensions in the checkpoint are {}, '
-                                      'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
-                                                                           ex.args))
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )
             elif strict:
                 missing_keys.append(key)
 
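The error_msgs entries collected in this method are not raised immediately: nn.Module.load_state_dict aggregates them across all submodules and raises a single RuntimeError at the end (missing and unexpected keys are only added under strict=True). A minimal illustration with a plain nn.Linear, using the same machinery this override plugs into:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
bad_state = {"weight": torch.randn(3, 5), "bias": torch.randn(2)}  # wrong weight shape

try:
    model.load_state_dict(bad_state)
except RuntimeError as err:
    # The message contains the "size mismatch for weight: ..." entry appended to error_msgs.
    print(err)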
@@ -164,7 +167,7 @@ class ParallelModule(nn.Module, ABC):
         if strict:
             for key in state_dict.keys():
                 if key.startswith(prefix) and key != extra_state_key:
-                    input_name = key[len(prefix):]
-                    input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
+                    input_name = key[len(prefix) :]
+                    input_name = input_name.split(".", 1)[0]  # get the name of param/buffer/child
                     if input_name not in self._modules and input_name not in local_state:
                         unexpected_keys.append(key)
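This last hunk is the strict-mode bookkeeping: any checkpoint key under this module's prefix that matches neither a local parameter/buffer nor a child module is recorded in unexpected_keys. With standard nn.Module semantics those keys are either returned from load_state_dict(strict=False) or turned into an error under strict=True, as in this small example:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
state = model.state_dict()
state["stale_buffer"] = torch.zeros(1)  # key with no matching param/buffer/child

result = model.load_state_dict(state, strict=False)
print(result.unexpected_keys)  # ['stale_buffer']
# With strict=True (the default), the same key raises:
# RuntimeError: Error(s) in loading state_dict for Linear: Unexpected key(s) in state_dict: "stale_buffer".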