[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
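The diff below is the mechanical output of the tooling named in the commit message: hook versions were bumped (pre-commit autoupdate) and every tracked file was reformatted (pre-commit run --all-files), which is why the changes are dominated by single quotes becoming double quotes and long signatures being re-wrapped in black's style. For reference, a pre-commit configuration along these lines could drive such a run; this is an illustrative sketch, not the repository's actual .pre-commit-config.yaml, and the rev pins and the CUDA exclude pattern are assumptions:

# Illustrative .pre-commit-config.yaml sketch (assumed, not the repository's actual file).
repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1                  # placeholder version
    hooks:
      - id: black
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6                 # placeholder version
    hooks:
      - id: clang-format
        exclude: \.(cu|cuh)$     # assumed pattern to honor "ignore cuda for clang-format"

With a configuration like this, "pre-commit run --all-files" reformats the whole tree in one sweep, producing the kind of large but behavior-preserving diff recorded in this commit.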


@@ -2,7 +2,6 @@
import functools
import itertools
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Iterator, Optional, Tuple
import torch
@@ -40,7 +39,7 @@ from .zero_hook import ZeroHook
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class ShardedModelV2(nn.Module):
@@ -78,20 +77,22 @@ class ShardedModelV2(nn.Module):
bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
"""
def __init__(self,
module: nn.Module,
shard_strategy: BaseShardStrategy,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
fp32_reduce_scatter: bool = False,
tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False,
bf16: bool = False,
*args,
**kwargs):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
def __init__(
self,
module: nn.Module,
shard_strategy: BaseShardStrategy,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
fp32_reduce_scatter: bool = False,
tensor_placement_policy: str = "cuda",
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False,
bf16: bool = False,
*args,
**kwargs,
):
assert not isinstance(module, ShardedModelV2), "Nested ShardedModelV2 is not supported."
super().__init__()
self.logger = get_dist_logger()
self.bf16 = bf16
@@ -101,13 +102,13 @@ class ShardedModelV2(nn.Module):
sharded_cnt = 0
unshard_cnt = 0
for param in submodule.parameters(recurse=False):
assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
assert hasattr(param, "colo_attr"), "You must use ZeroInitContext to init your module first."
if param.colo_attr.param_is_sharded:
sharded_cnt += 1
else:
unshard_cnt += 1
assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param'
submodule.param_is_sharded = (sharded_cnt > 0)
assert (not sharded_cnt) or (not unshard_cnt), "nn.Module can not both have shard param and unshard param"
submodule.param_is_sharded = sharded_cnt > 0
self.sharded_params = []
self.unshard_params = []
@@ -124,7 +125,7 @@ class ShardedModelV2(nn.Module):
self.rank = dist.get_rank(self.process_group)
self.shard_strategy = shard_strategy
self._use_memory_tracer = tensor_placement_policy == 'auto'
self._use_memory_tracer = tensor_placement_policy == "auto"
if self._use_memory_tracer:
self._memstats_collector = MemStatsCollector()
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
@@ -132,18 +133,19 @@ class ShardedModelV2(nn.Module):
else:
self._memstats_collector = None
self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
tensor_placement_policy
)(mem_stats_collector=self._memstats_collector)
if 'warmup_non_model_data_ratio' in kwargs:
if tensor_placement_policy != 'auto':
self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement')
if "warmup_non_model_data_ratio" in kwargs:
if tensor_placement_policy != "auto":
self.logger.warning("setting warmup_non_model_data_ratio is useless if not use auto placement")
else:
ratio = kwargs['warmup_non_model_data_ratio']
ratio = kwargs["warmup_non_model_data_ratio"]
self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement')
self.logger.info(f"setting warmup_non_model_data_ratio as {ratio} for auto placement")
self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, "colo_attr")]
self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)
# Register hooks
@@ -155,7 +157,7 @@ class ShardedModelV2(nn.Module):
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
self.fp32_reduce_scatter = fp32_reduce_scatter
self._cpu_offload: bool = tensor_placement_policy != 'cuda'
self._cpu_offload: bool = tensor_placement_policy != "cuda"
for param in module.parameters():
# Init `offload_grad`
param.colo_attr.offload_grad = self._cpu_offload
@@ -164,9 +166,11 @@ class ShardedModelV2(nn.Module):
# So we use 1.0 as the default gradient_predivide_factor
# However, if you set gradient_predivide_factor to None, we will set
# gradient_predivide_factor to a value >= 1.0 automatically
self.gradient_predivide_factor: float = gradient_predivide_factor if \
gradient_predivide_factor is not None else \
get_gradient_predivide_factor(self.world_size)
self.gradient_predivide_factor: float = (
gradient_predivide_factor
if gradient_predivide_factor is not None
else get_gradient_predivide_factor(self.world_size)
)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
@@ -194,7 +198,7 @@ class ShardedModelV2(nn.Module):
def cpu_offload(self):
return self._cpu_offload
def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
def dump_memory_stats(self, filename: Optional[str] = "dump_mem_stats.log") -> None:
"""
dummy memory tracer collected information to a file.
try:
@@ -205,18 +209,18 @@ class ShardedModelV2(nn.Module):
exit(0)
"""
if self._use_memory_tracer:
self.logger.error(f'dump memory tracer collected information to a {filename}', ranks=[0])
self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
if gpc.get_global_rank() == 0:
with open(filename, 'w+') as f:
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
f.write('CUDA model data (GB)\n')
f.write('\n')
f.write('CUDA non model data (GB)\n')
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
f.write('CPU non model data (GB)\n')
f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
f.write('\n')
with open(filename, "w+") as f:
f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n")
f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n")
f.write("CUDA model data (GB)\n")
f.write("\n")
f.write("CUDA non model data (GB)\n")
f.write(str(self._memstats_collector._memstats.non_model_data_list("cuda")))
f.write("CPU non model data (GB)\n")
f.write(str(self._memstats_collector._memstats.non_model_data_list("cpu")))
f.write("\n")
def _pre_forward_operations(self, *args):
# the operation will affect the memory tracer behavior in ZeroHook
@@ -224,14 +228,14 @@ class ShardedModelV2(nn.Module):
self._start_collect_memstats()
for p in self.module.parameters():
if hasattr(p, 'colo_attr'):
if hasattr(p, "colo_attr"):
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
self._stateful_tensor_mgr.start_iter()
def _post_forward_operations(self):
for p in self.module.parameters():
if hasattr(p, 'colo_attr'):
if hasattr(p, "colo_attr"):
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
@@ -261,8 +265,9 @@ class ShardedModelV2(nn.Module):
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = colo_device_memory_capacity(
get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
self._cuda_margin_space = (
colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
)
@torch.no_grad()
def _post_backward_operations(self) -> None:
@@ -330,7 +335,7 @@ class ShardedModelV2(nn.Module):
"""
if grad is None:
return
assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
assert not grad.requires_grad, "ShardedModel only works with gradients that don't require gradients"
if not self._require_backward_grad_sync:
return
# used to cheat Pytorch, since we can't return None
@@ -354,16 +359,19 @@ class ShardedModelV2(nn.Module):
grad.data.div_(self.gradient_predivide_factor)
if self.world_size > 1:
grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size())
self.reducer.reduce_scatter_async(grad_chunks,
group=self.reduce_scatter_process_group,
callback_fn=functools.partial(self._reduce_scatter_callback, param))
self.reducer.reduce_scatter_async(
grad_chunks,
group=self.reduce_scatter_process_group,
callback_fn=functools.partial(self._reduce_scatter_callback, param),
)
else:
self._reduce_scatter_callback(param, grad)
torch.cuda.current_stream().wait_stream(self.comm_stream)
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
assert isinstance(reduced_grad,
torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
assert isinstance(
reduced_grad, torch.Tensor
), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
reduced_grad.data = reduced_grad.data.contiguous().view(-1)
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
@@ -372,7 +380,6 @@ class ShardedModelV2(nn.Module):
# FIXME(ver217): refactor the below line when impl eviction policy
def _save_grad(self, param: Parameter, grad: torch.Tensor):
# record whether we have overflow
self.overflow_counter += torch.isinf(grad).any().item()
self.overflow_counter += torch.isnan(grad).any().item()
@@ -384,8 +391,9 @@ class ShardedModelV2(nn.Module):
if self.reuse_fp16_shard:
# make parameters point to gradient
assert param.colo_attr.saved_grad.is_null(
), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
assert (
param.colo_attr.saved_grad.is_null()
), "Gradient accumulation is not supported when reuse_fp16_shard=True"
param.colo_attr.grad_payload_reset(grad.data)
# release the memory of param
@@ -396,7 +404,6 @@ class ShardedModelV2(nn.Module):
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
else:
fp32_grad = cast_tensor_to_fp32(grad)
if param.colo_attr.saved_grad.is_null():
@@ -410,39 +417,44 @@ class ShardedModelV2(nn.Module):
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
return self.module.parameters(recurse=recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
return self.module.named_parameters(prefix, recurse)
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
return self._colo_state_dict(destination,
prefix,
keep_vars,
shard_strategy=self.shard_strategy,
state_dict_func=nn.Module.state_dict,
module_to_load=self.module,
sharded_params=self.sharded_params,
process_group=self.process_group)
def state_dict(self, destination=None, prefix="", keep_vars=False) -> "OrderedDict[str, torch.Tensor]":
return self._colo_state_dict(
destination,
prefix,
keep_vars,
shard_strategy=self.shard_strategy,
state_dict_func=nn.Module.state_dict,
module_to_load=self.module,
sharded_params=self.sharded_params,
process_group=self.process_group,
)
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True) -> None:
for name, p in self.named_parameters():
if name in state_dict:
p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype,
device=p.colo_attr.data_payload.device))
p.colo_attr.data_payload_reset(
state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)
)
# Force re-shard
p.colo_attr.sharded_data_tensor.is_sharded = False
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor])
elif strict:
raise RuntimeError(f'Missing key in state_dict: {name}')
raise RuntimeError(f"Missing key in state_dict: {name}")
def _colo_state_dict(self,
destination=None,
prefix='',
keep_vars=False,
shard_strategy: Optional[BaseShardStrategy] = None,
state_dict_func=None,
module_to_load=None,
sharded_params=[],
process_group=None) -> 'OrderedDict[str, torch.Tensor]':
def _colo_state_dict(
self,
destination=None,
prefix="",
keep_vars=False,
shard_strategy: Optional[BaseShardStrategy] = None,
state_dict_func=None,
module_to_load=None,
sharded_params=[],
process_group=None,
) -> "OrderedDict[str, torch.Tensor]":
if len(sharded_params) == 0:
for param in self.parameters():
if param.colo_attr.param_is_sharded:
@@ -460,15 +472,9 @@ class ShardedModelV2(nn.Module):
p.colo_attr.set_data_none()
return gathered_state_dict
def _colo_load_from_state_dict(self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
shard_strategy=None):
def _colo_load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, shard_strategy=None
):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
@@ -512,10 +518,12 @@ class ShardedModelV2(nn.Module):
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if hasattr(param, 'colo_attr'):
if hasattr(param, "colo_attr"):
param.colo_attr.data_payload_reset(
input_param.to(dtype=param.colo_attr.data_payload.dtype,
device=param.colo_attr.data_payload.device))
input_param.to(
dtype=param.colo_attr.data_payload.dtype, device=param.colo_attr.data_payload.device
)
)
if shard_strategy is not None:
# Force re-shard
param.colo_attr.sharded_data_tensor.is_sharded = False
@@ -531,19 +539,21 @@ class ShardedModelV2(nn.Module):
if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'.format(
key, input_param.shape, param.shape))
error_msgs.append(
"size mismatch for {}: copying a param with shape {} from checkpoint, "
"the shape in current model is {}.".format(key, input_param.shape, param.shape)
)
continue
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
ex.args))
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
elif strict:
missing_keys.append(key)
@@ -559,8 +569,8 @@ class ShardedModelV2(nn.Module):
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
input_name = key[len(prefix) :]
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)