[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -1,6 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from enum import Enum
from os import stat
from typing import Dict, Optional, Tuple
import torch
@@ -74,22 +73,24 @@ class ShardedOptimizerV2(OptimizerWrapper):
https://arxiv.org/abs/2108.05818
"""
def __init__(self,
sharded_model: ShardedModelV2,
optimizer: Optimizer,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None,
verbose: bool = False) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.'
def __init__(
self,
sharded_model: ShardedModelV2,
optimizer: Optimizer,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None,
verbose: bool = False,
) -> None:
assert isinstance(sharded_model, ShardedModelV2), "model must be wrapped with ShardedModel"
assert not isinstance(optimizer, ShardedOptimizerV2), "Nested ShardedOptimizerV2 is not supported."
super().__init__(optimizer)
self.shard_strategy = sharded_model.shard_strategy
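
As a usage aside, the signature reformatted above takes a ShardedModelV2 plus an ordinary torch optimizer; everything else is loss-scaling and placement tuning. A hedged construction sketch follows; the import path and the helper name are assumptions (they vary across ColossalAI releases) and only the keyword arguments mirror the diff:

# Hedged usage sketch -- the import path below is an assumption and may differ
# across ColossalAI releases; the keyword arguments mirror the signature above.
from torch.optim import Adam
from colossalai.zero import ShardedModelV2, ShardedOptimizerV2  # assumed path

def wrap_with_sharded_optimizer(sharded_model: ShardedModelV2) -> ShardedOptimizerV2:
    inner_optim = Adam(sharded_model.parameters(), lr=1e-3)  # plain inner optimizer
    return ShardedOptimizerV2(
        sharded_model,
        inner_optim,
        gpu_margin_mem_ratio=0.0,  # >0 only makes sense with the "auto" placement policy
        initial_scale=2**32,       # dynamic loss-scaling knobs, defaults as in the diff
        growth_factor=2,
        backoff_factor=0.5,
        growth_interval=1000,
        hysteresis=2,
        verbose=True,
    )
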
@@ -97,39 +98,49 @@ class ShardedOptimizerV2(OptimizerWrapper):
self.bf16 = sharded_model.bf16
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0"
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
# and it must set `num_fp32_shards_per_param` correctly
self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
optimizer, 'num_fp32_shards_per_param', 0) >= 2
self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu')
self._should_move_fp32_shards_h2d: bool = (
sharded_model.cpu_offload
and self.gpu_margin_mem_ratio > 0.0
and getattr(optimizer, "num_fp32_shards_per_param", 0) >= 2
)
self.device = sharded_model._tensor_placement_policy.device or torch.device("cpu")
self.optim_state: OptimState = OptimState.UNSCALED
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
# Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self.grad_scaler = DynamicGradScaler(
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger("ShardedOptimizerV2")
self._verbose = verbose
self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward
self._grad_prepared: bool = (
False # this should be set to true when _prepare_grads() and reset to false when backward
)
# Store fp32 param shards
self._register_master_weight()
if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy,
AutoTensorPlacementPolicy):
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"',
ranks=[0])
if self.gpu_margin_mem_ratio != 0.0 and not isinstance(
sharded_model._tensor_placement_policy, AutoTensorPlacementPolicy
):
self._logger.warning(
f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', ranks=[0]
)
if self._verbose:
self._logger.debug(
f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])
f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0]
)
self._use_memory_tracer = self.model.use_memory_tracer
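
The constructor above wires a DynamicGradScaler with growth/backoff knobs. As a hedged illustration of what such knobs typically control (a generic sketch of dynamic loss scaling, not ColossalAI's DynamicGradScaler implementation):

# Generic dynamic loss-scaling sketch -- an assumption about the usual semantics
# of these knobs, not the actual DynamicGradScaler code.
class ToyDynamicScaler:
    def __init__(self, initial_scale=2**32, min_scale=1, growth_factor=2,
                 backoff_factor=0.5, growth_interval=1000, hysteresis=2,
                 max_scale=2**32):
        self.scale = float(initial_scale)
        self.min_scale, self.max_scale = float(min_scale), float(max_scale)
        self.growth_factor, self.backoff_factor = growth_factor, backoff_factor
        self.growth_interval, self.hysteresis = growth_interval, hysteresis
        self._good_steps = 0               # consecutive steps without overflow
        self._overflows_left = hysteresis  # tolerated overflows before backing off

    def update(self, found_inf: bool) -> None:
        if found_inf:
            self._good_steps = 0
            self._overflows_left -= 1
            if self._overflows_left <= 0:
                # back off the scale once `hysteresis` overflows have been seen
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._overflows_left = self.hysteresis
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                # grow the scale after a long run of clean steps
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
            self._overflows_left = self.hysteresis
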
@@ -138,7 +149,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
return self.grad_scaler.scale.item()
def get_memory_usage(self) -> Tuple[int, int]:
""" Get the memory usage of the optimizer. Including master_params (param fp32),
"""Get the memory usage of the optimizer. Including master_params (param fp32),
momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``)
Returns:
@@ -157,7 +168,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
for _, p_fp32 in self.master_params.items():
update_mem_use(p_fp32)
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
state = self.optim.state[p]
for k, v in state.items():
update_mem_use(v)
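
get_memory_usage() above walks the master params and every tensor in the inner optimizer's state and sums their footprints per device. A hedged standalone sketch of that accounting pattern (the helper name and the example tensors are illustrative, not the method's internals):

import torch

def tensor_mem_usage(tensors):
    """Sum byte footprints of tensors, split into (cuda_bytes, cpu_bytes)."""
    cuda_use, cpu_use = 0, 0
    for t in tensors:
        if not isinstance(t, torch.Tensor):
            continue  # optimizer state may also hold scalars such as `step`
        nbytes = t.numel() * t.element_size()
        if t.device.type == "cuda":
            cuda_use += nbytes
        else:
            cpu_use += nbytes
    return cuda_use, cpu_use

# Example: one fp32 master weight on CPU plus two fp32 Adam moments
tensors = [torch.zeros(1024, 1024) for _ in range(3)]
print(tensor_mem_usage(tensors))  # (0, 3 * 1024 * 1024 * 4) bytes, all on CPU
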
@@ -191,7 +202,6 @@ class ShardedOptimizerV2(OptimizerWrapper):
return super().clip_grad_norm(model, max_norm)
def step(self, *args, **kwargs):
self._prepare_grads()
# unscale grads if scaled
if not self.bf16 and self.optim_state == OptimState.SCALED:
@@ -203,7 +213,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
self.grad_scaler.update(found_inf)
if found_inf:
self._logger.warning('found inf during ShardedOptimV2 step')
self._logger.warning("found inf during ShardedOptimV2 step")
self._zero_grad(recover_data=True)
return
@@ -213,14 +223,16 @@ class ShardedOptimizerV2(OptimizerWrapper):
gpu_mem, cpu_mem = self.get_memory_usage()
self._logger.debug(
f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
ranks=[0])
ranks=[0],
)
ret = self.optim.step(*args, **kwargs)
if self._verbose:
gpu_mem, cpu_mem = self.get_memory_usage()
self._logger.debug(
f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
ranks=[0])
ranks=[0],
)
self._copy_master_model_to_model_fp16()
return ret
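
Taken together, the step() hunks above follow a fixed order: prepare grads, unscale (fp16 path only), check for overflow, either roll back or run the inner optimizer on the fp32 master shards, then copy the result back into the fp16 params. A hedged, duck-typed restatement of that flow; method names not visible in the hunks (e.g. _check_overflow, _point_param_fp16_to_master_param) are assumptions, and `opt` is a hypothetical object exposing them:

# Hedged control-flow sketch of the step() path shown in the diff.
def sharded_step_flow(opt, *args, **kwargs):
    opt._prepare_grads()                       # materialize saved grads into p.grad
    if not opt.bf16 and opt.optim_state == "SCALED":
        opt._unscale_grads()                   # divide grads by the loss scale (fp16 path)

    found_inf = opt._check_overflow()          # assumed overflow check across ranks
    opt.grad_scaler.update(found_inf)          # grow or back off the loss scale
    if found_inf:
        opt._zero_grad(recover_data=True)      # restore fp16 payloads, skip this step
        return None

    opt._point_param_fp16_to_master_param()    # let the inner optimizer see fp32 shards
    ret = opt.optim.step(*args, **kwargs)      # plain optimizer update on master weights
    opt._copy_master_model_to_model_fp16()     # write updated fp32 shards back as fp16
    return ret
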
@@ -240,7 +252,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
def _unscale_grads(self):
assert self.optim_state == OptimState.SCALED
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
if p.grad is not None:
p.grad.data.div_(self.loss_scale)
self.optim_state = OptimState.UNSCALED
@@ -260,16 +272,16 @@ class ShardedOptimizerV2(OptimizerWrapper):
# Which leads to wrong accumulation
self.optim.zero_grad(set_to_none=True)
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
# p.colo_attr.sharded_data_tensor stores grad now
# we have to recover fp16 param
reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
reuse_fp16_shard = p.colo_attr.sharded_data_tensor.payload_size == 0
if recover_data and reuse_fp16_shard:
self._copy_master_param_to_param_fp16(p)
else:
# release saved gradient
p.colo_attr.saved_grad.set_null()
self.model.overflow_counter = 0 # set overflow counter to zero
self.model.overflow_counter = 0 # set overflow counter to zero
def sync_grad(self):
pass
@@ -277,8 +289,8 @@ class ShardedOptimizerV2(OptimizerWrapper):
def _register_master_weight(self):
self.master_params: Dict[Parameter, StatefulTensor] = {}
for group in self.optim.param_groups:
for p in group['params']:
assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
for p in group["params"]:
assert hasattr(p, "colo_attr"), "The parameter must be wrapped with ShardedParam"
shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated
if shard_flag:
# we always shard replicated parameters
@@ -296,7 +308,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
fp32_shards_used_cuda_margin_mem = 0
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
if p.colo_attr.saved_grad.is_null():
continue
shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
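
The hunk above sizes each fp32 master shard and charges it against the CUDA margin budget, which is first divided by num_fp32_shards_per_param because the inner optimizer keeps several fp32 copies per parameter. A hedged arithmetic sketch of that budgeting loop with toy numbers; the move decision itself is an assumption about the lines elided from the hunk:

# Toy budgeting sketch: admit fp32 shards to the GPU until the margin is spent.
# Numbers are illustrative; only the bookkeeping mirrors the hunk above.
ELEMENT_SIZE_FP32 = 4  # bytes per fp32 element

def plan_h2d_moves(shard_numels, cuda_margin_bytes, num_fp32_shards_per_param=2):
    # The margin is shared by every fp32 copy the inner optimizer keeps per param
    # (master weight + optimizer states), hence the division.
    budget = cuda_margin_bytes / num_fp32_shards_per_param
    used, moved = 0, []
    for i, numel in enumerate(shard_numels):
        shard_mem = numel * ELEMENT_SIZE_FP32
        if used + shard_mem <= budget:
            moved.append(i)  # this shard (and its optimizer states) may live on GPU
            used += shard_mem
    return moved, used

print(plan_h2d_moves([1_000_000, 2_000_000, 4_000_000], cuda_margin_bytes=32_000_000))
# -> ([0, 1], 12000000): the third shard would exceed the 16 MB per-copy budget
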
@@ -314,7 +326,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
if self._grad_prepared:
return
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
if p.colo_attr.saved_grad.is_null():
continue
p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
@@ -335,7 +347,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
# assign master param pointers to p.data.
# We will not trigger data copy here.
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
self.master_params[p].trans_state(TensorState.COMPUTE)
p.data = self.master_params[p].payload
# Now p.data is sharded
@@ -346,7 +358,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# a chunk.
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
self._copy_master_param_to_param_fp16(p)
def _copy_master_param_to_param_fp16(self, p):
@@ -364,7 +376,8 @@ class ShardedOptimizerV2(OptimizerWrapper):
# in order to use copy, otherwise, the sizes of tensor is not compatible
if p.colo_attr.data_payload.numel() != p.data.numel():
p.colo_attr.data_payload_reset(
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)
)
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
half_dtype = torch.bfloat16 if self.bf16 else torch.float16
@@ -373,7 +386,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
# We gather full fp16 param here
p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True
p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p].trans_state(TensorState.HOLD)
@@ -381,18 +394,18 @@ class ShardedOptimizerV2(OptimizerWrapper):
def state_dict(self):
optim_state_dict = super().state_dict()
scaler_state_dict = self.grad_scaler.state_dict()
optim_state_dict['scaler'] = scaler_state_dict
optim_state_dict["scaler"] = scaler_state_dict
return optim_state_dict
def load_state_dict(self, *args, **kwargs):
if 'scaler' not in args[0]:
self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
if "scaler" not in args[0]:
self._logger.warning("Missing scaler when loading optimizer state dict", ranks=[0])
else:
scaler_state_dict = args[0].pop('scaler')
scaler_state_dict = args[0].pop("scaler")
self.grad_scaler.load_state_dict(scaler_state_dict)
super().load_state_dict(*args, **kwargs)
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
state = self.optim.state[p]
for k, v in state.items():
if isinstance(v, Tensor):
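
The final hunk stores the grad scaler's state under a "scaler" key in the optimizer state dict and pops it back out on load, warning when it is absent. A hedged round-trip sketch of that convention using plain dicts (the function names here are illustrative, not the class's methods):

# Hedged sketch of the save/load convention shown above, using plain dicts.
import warnings

def save_with_scaler(optim_state: dict, scaler_state: dict) -> dict:
    state = dict(optim_state)
    state["scaler"] = scaler_state          # piggy-back the loss-scaler state
    return state

def load_with_scaler(state: dict):
    if "scaler" not in state:
        warnings.warn("Missing scaler when loading optimizer state dict")
        return state, None
    scaler_state = state.pop("scaler")      # restore the scaler alongside the optimizer
    return state, scaler_state

ckpt = save_with_scaler({"param_groups": [], "state": {}}, {"scale": 2.0**32})
optim_part, scaler_part = load_with_scaler(ckpt)
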