Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -37,19 +37,20 @@ class AMPOptimizer(OptimizerWrapper):
        norm_type (float, optional): norm_type used for `clip_grad_norm`.
    """

-    def __init__(self,
-                 optimizer: Optimizer,
-                 module: BaseOffloadModule,
-                 initial_scale: float = 2**16,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 min_scale: float = 1,
-                 max_scale: float = 2**32,
-                 clipping_norm: float = 0.0,
-                 norm_type: float = 2.0):
-
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        module: BaseOffloadModule,
+        initial_scale: float = 2**16,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        min_scale: float = 1,
+        max_scale: float = 2**32,
+        clipping_norm: float = 0.0,
+        norm_type: float = 2.0,
+    ):
        super().__init__(optimizer)

        self.module = module
@@ -69,19 +70,21 @@ class AMPOptimizer(OptimizerWrapper):
        self.__init__optimizer()

        # Grad scaler
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale)
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+        )
        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
        self._logger = get_dist_logger()

    def _set_grad_ptr(self):
        for group in self.param_groups:
-            for fake_param in group['params']:
+            for fake_param in group["params"]:
                region = self.param_to_region[fake_param]
                begin, end = self.param_to_range[fake_param]

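For context, the keyword arguments forwarded to DynamicGradScaler in the hunk above (initial_scale, growth_factor, backoff_factor, growth_interval, hysteresis, min_scale, max_scale) describe a conventional dynamic loss-scaling policy: grow the scale after a long run of overflow-free steps, shrink it once overflows exhaust the hysteresis budget. Below is a minimal, self-contained sketch of such a policy; the class and attribute names are illustrative and do not reproduce ColossalAI's implementation.

class IllustrativeDynamicScaler:
    """Toy dynamic loss scaler mirroring the constructor arguments shown above."""

    def __init__(self, initial_scale=2**16, growth_factor=2.0, backoff_factor=0.5,
                 growth_interval=1000, hysteresis=2, min_scale=1.0, max_scale=2**32):
        self.scale = float(initial_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self.min_scale = float(min_scale)
        self.max_scale = float(max_scale)
        self._good_steps = 0               # consecutive overflow-free steps
        self._overflows_left = hysteresis  # overflows tolerated before backing off

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            self._good_steps = 0
            self._overflows_left -= 1
            if self._overflows_left <= 0:
                # Hysteresis exhausted: shrink the scale, clamped to min_scale.
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._overflows_left = self.hysteresis
        else:
            self._good_steps += 1
            if self._good_steps >= self.growth_interval:
                # Enough clean steps in a row: grow the scale, clamped to max_scale.
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
                self._good_steps = 0
                self._overflows_left = self.hysteresis

With the default values above, the scale stays within [1, 2**32] and only ever changes by the growth or backoff factor.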
@@ -92,7 +95,7 @@ class AMPOptimizer(OptimizerWrapper):
    def _update_fp16_params(self):
        none_tensor = torch.empty([0])
        for group in self.param_groups:
-            for fake_param in group['params']:
+            for fake_param in group["params"]:
                assert fake_param.grad is None
                fake_param.data = none_tensor
                self.param_to_region[fake_param].cpu_grad = None
@@ -130,10 +133,10 @@ class AMPOptimizer(OptimizerWrapper):

        found_inf = self._check_overflow()
        if found_inf:
-            self.optim_state = OptimState.UNSCALED  # no need to unscale grad
-            self.grad_scaler.update(found_inf)  # update gradient scaler
-            self._logger.info(f'Found overflow. Skip step')
-            self.zero_grad()  # reset all gradients
+            self.optim_state = OptimState.UNSCALED  # no need to unscale grad
+            self.grad_scaler.update(found_inf)  # update gradient scaler
+            self._logger.info(f"Found overflow. Skip step")
+            self.zero_grad()  # reset all gradients
            self._update_fp16_params()
            return

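The hunk above is the overflow branch of the optimizer step: if any gradient went non-finite under the current scale, the scaler is updated, the stale gradients are cleared, and the parameter update is skipped. A hedged sketch of how such a check-and-skip step fits together, reusing the IllustrativeDynamicScaler sketched earlier (the function below is illustrative, not the repository's API):

import torch

def illustrative_amp_step(optimizer, scaler):
    # Overflow check: any non-finite gradient means the scaled loss overflowed.
    found_inf = any(
        p.grad is not None and not torch.isfinite(p.grad).all()
        for group in optimizer.param_groups
        for p in group["params"]
    )
    scaler.update(found_inf)  # adjust the loss scale in both cases
    if found_inf:
        optimizer.zero_grad()  # the gradients are unusable; drop them
        return                 # and skip this step entirely
    inv_scale = 1.0 / scaler.scale
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is not None:
                p.grad.mul_(inv_scale)  # unscale before the real update
    optimizer.step()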
@@ -156,11 +159,10 @@ class AMPOptimizer(OptimizerWrapper):
        self.module.backward(loss)

    def __init__optimizer(self):
-
        for group in self.optim.param_groups:
            fake_params_list = list()

-            for param in group['params']:
+            for param in group["params"]:
                region = self.region_manager.get_region(param)
                fake_param = torch.nn.Parameter(torch.empty([0]))
                self.param_to_range[fake_param] = region.param_to_range[param]
@@ -171,7 +173,7 @@ class AMPOptimizer(OptimizerWrapper):
                if param in self.optim.state:
                    self.optim.state[fake_param] = self.optim.state.pop(param)

-            group['params'] = fake_params_list
+            group["params"] = fake_params_list

            # Leverage state_dict() and load_state_dict() to
            # recast preexisting per-param state tensors
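For readers outside the codebase, the last two hunks show the indirection this optimizer is built on: every real parameter in a param group is swapped for an empty placeholder ("fake") parameter, any pre-existing optimizer state is carried over to the placeholder, and side tables map each placeholder back to its offload region and index range. A generic sketch of that substitution, which assumes nothing about ColossalAI's region manager (the function name and the recorded range below are illustrative):

import torch

def install_fake_params(optim: torch.optim.Optimizer, param_to_range: dict) -> None:
    # Replace each real parameter with an empty placeholder, in the spirit of
    # __init__optimizer above.
    for group in optim.param_groups:
        fake_params_list = list()
        for param in group["params"]:
            fake_param = torch.nn.Parameter(torch.empty([0]))
            # Record which slice of flattened storage the placeholder stands for
            # (a real implementation would take this from its region bookkeeping).
            param_to_range[fake_param] = (0, param.numel())
            if param in optim.state:
                # Carry existing per-param state over to the placeholder.
                optim.state[fake_param] = optim.state.pop(param)
            fake_params_list.append(fake_param)
        group["params"] = fake_params_list

After the swap the wrapped optimizer only ever sees the placeholders; consistent with _set_grad_ptr and _update_fp16_params above, real data and gradients are attached to them just before the step and detached again afterwards.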