Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 10:06:44 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
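The diff below shows the reformatting this commit applied to ShardedOptimizerV2 after running the updated pre-commit hooks. The changes are mechanical: single-quoted strings become double-quoted, long signatures and calls are split one argument per line with a trailing comma, paren-aligned continuations become hanging indents, and inline-comment spacing is normalized. A minimal before/after sketch of the pattern, with invented names (the hook configuration itself is not part of this excerpt, but the output is consistent with a black-style formatter):

# Hypothetical illustration of the reformatting pattern in this diff; the
# function and argument names are made up for the example.

# Before: paren-aligned continuation lines and single-quoted strings.
def make_widget(name: str = 'widget',
                size: int = 8) -> None:
    assert size > 0, 'size must be positive'

# After: one argument per line, trailing comma, double quotes.
def make_widget(
    name: str = "widget",
    size: int = 8,
) -> None:
    assert size > 0, "size must be positive"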
@@ -1,6 +1,5 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 from enum import Enum
-from os import stat
 from typing import Dict, Optional, Tuple
 
 import torch
@@ -74,22 +73,24 @@ class ShardedOptimizerV2(OptimizerWrapper):
     https://arxiv.org/abs/2108.05818
     """
 
-    def __init__(self,
-                 sharded_model: ShardedModelV2,
-                 optimizer: Optimizer,
-                 gpu_margin_mem_ratio: float = 0.0,
-                 initial_scale: float = 2**32,
-                 min_scale: float = 1,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 max_scale: float = 2**32,
-                 dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None,
-                 verbose: bool = False) -> None:
-        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
-        assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.'
+    def __init__(
+        self,
+        sharded_model: ShardedModelV2,
+        optimizer: Optimizer,
+        gpu_margin_mem_ratio: float = 0.0,
+        initial_scale: float = 2**32,
+        min_scale: float = 1,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        max_scale: float = 2**32,
+        dp_process_group: Optional[ProcessGroup] = None,
+        mp_process_group: Optional[ProcessGroup] = None,
+        verbose: bool = False,
+    ) -> None:
+        assert isinstance(sharded_model, ShardedModelV2), "model must be wrapped with ShardedModel"
+        assert not isinstance(optimizer, ShardedOptimizerV2), "Nested ShardedOptimizerV2 is not supported."
 
         super().__init__(optimizer)
         self.shard_strategy = sharded_model.shard_strategy
@@ -97,39 +98,49 @@ class ShardedOptimizerV2(OptimizerWrapper):
         self.bf16 = sharded_model.bf16
 
         self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
-        assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
+        assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0"
         # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
         # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
         # and it must set `num_fp32_shards_per_param` correctly
-        self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
-            optimizer, 'num_fp32_shards_per_param', 0) >= 2
-        self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu')
+        self._should_move_fp32_shards_h2d: bool = (
+            sharded_model.cpu_offload
+            and self.gpu_margin_mem_ratio > 0.0
+            and getattr(optimizer, "num_fp32_shards_per_param", 0) >= 2
+        )
+        self.device = sharded_model._tensor_placement_policy.device or torch.device("cpu")
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
         self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
         # Grad scaler
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale)
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+        )
         self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
         self._logger = get_dist_logger("ShardedOptimizerV2")
         self._verbose = verbose
-        self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward
+        self._grad_prepared: bool = (
+            False  # this should be set to true when _prepare_grads() and reset to false when backward
+        )
 
         # Store fp32 param shards
         self._register_master_weight()
-        if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy,
-                                                               AutoTensorPlacementPolicy):
-            self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"',
-                                 ranks=[0])
+        if self.gpu_margin_mem_ratio != 0.0 and not isinstance(
+            sharded_model._tensor_placement_policy, AutoTensorPlacementPolicy
+        ):
+            self._logger.warning(
+                f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', ranks=[0]
+            )
 
         if self._verbose:
             self._logger.debug(
-                f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])
+                f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0]
+            )
 
         self._use_memory_tracer = self.model.use_memory_tracer
 
@@ -138,7 +149,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
         return self.grad_scaler.scale.item()
 
     def get_memory_usage(self) -> Tuple[int, int]:
-        """ Get the memory usage of the optimizer. Including master_params (param fp32),
+        """Get the memory usage of the optimizer. Including master_params (param fp32),
         momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``)
 
         Returns:
@@ -157,7 +168,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
         for _, p_fp32 in self.master_params.items():
             update_mem_use(p_fp32)
         for group in self.optim.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 state = self.optim.state[p]
                 for k, v in state.items():
                     update_mem_use(v)
@@ -191,7 +202,6 @@ class ShardedOptimizerV2(OptimizerWrapper):
         return super().clip_grad_norm(model, max_norm)
 
     def step(self, *args, **kwargs):
-
         self._prepare_grads()
         # unscale grads if scaled
         if not self.bf16 and self.optim_state == OptimState.SCALED:
@@ -203,7 +213,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
             self.grad_scaler.update(found_inf)
 
             if found_inf:
-                self._logger.warning('found inf during ShardedOptimV2 step')
+                self._logger.warning("found inf during ShardedOptimV2 step")
                 self._zero_grad(recover_data=True)
                 return
 
@@ -213,14 +223,16 @@ class ShardedOptimizerV2(OptimizerWrapper):
             gpu_mem, cpu_mem = self.get_memory_usage()
             self._logger.debug(
                 f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
-                ranks=[0])
+                ranks=[0],
+            )
         ret = self.optim.step(*args, **kwargs)
 
         if self._verbose:
             gpu_mem, cpu_mem = self.get_memory_usage()
             self._logger.debug(
                 f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
-                ranks=[0])
+                ranks=[0],
+            )
 
         self._copy_master_model_to_model_fp16()
         return ret
@@ -240,7 +252,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
     def _unscale_grads(self):
         assert self.optim_state == OptimState.SCALED
         for group in self.optim.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 if p.grad is not None:
                     p.grad.data.div_(self.loss_scale)
         self.optim_state = OptimState.UNSCALED
@@ -260,16 +272,16 @@ class ShardedOptimizerV2(OptimizerWrapper):
         # Which leads to wrong accumulation
         self.optim.zero_grad(set_to_none=True)
         for group in self.optim.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 # p.colo_attr.sharded_data_tensor stores grad now
                 # we have to recover fp16 param
-                reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
+                reuse_fp16_shard = p.colo_attr.sharded_data_tensor.payload_size == 0
                 if recover_data and reuse_fp16_shard:
                     self._copy_master_param_to_param_fp16(p)
                 else:
                     # release saved gradient
                     p.colo_attr.saved_grad.set_null()
-        self.model.overflow_counter = 0 # set overflow counter to zero
+        self.model.overflow_counter = 0  # set overflow counter to zero
 
     def sync_grad(self):
         pass
@@ -277,8 +289,8 @@ class ShardedOptimizerV2(OptimizerWrapper):
     def _register_master_weight(self):
         self.master_params: Dict[Parameter, StatefulTensor] = {}
         for group in self.optim.param_groups:
-            for p in group['params']:
-                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
+            for p in group["params"]:
+                assert hasattr(p, "colo_attr"), "The parameter must be wrapped with ShardedParam"
                 shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated
                 if shard_flag:
                     # we always shard replicated parameters
@@ -296,7 +308,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
             fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
             fp32_shards_used_cuda_margin_mem = 0
             for group in self.optim.param_groups:
-                for p in group['params']:
+                for p in group["params"]:
                     if p.colo_attr.saved_grad.is_null():
                         continue
                     shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
@@ -314,7 +326,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
         if self._grad_prepared:
             return
         for group in self.optim.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 if p.colo_attr.saved_grad.is_null():
                     continue
                 p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
@@ -335,7 +347,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
         # assign master param pointers to p.data.
         # We will not trigger data copy here.
         for group in self.optim.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 self.master_params[p].trans_state(TensorState.COMPUTE)
                 p.data = self.master_params[p].payload
                 # Now p.data is sharded
@@ -346,7 +358,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
         # TODO() improve efficiency by gathering tensors into a chunk and transferring
         # a chunk.
         for group in self.optim.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 self._copy_master_param_to_param_fp16(p)
 
     def _copy_master_param_to_param_fp16(self, p):
@@ -364,7 +376,8 @@ class ShardedOptimizerV2(OptimizerWrapper):
         # in order to use copy, otherwise, the sizes of tensor is not compatible
         if p.colo_attr.data_payload.numel() != p.data.numel():
             p.colo_attr.data_payload_reset(
-                torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
+                torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)
+            )
 
         # TODO() optimize this line CPU (fp32) -> GPU (fp16)
         half_dtype = torch.bfloat16 if self.bf16 else torch.float16
@@ -373,7 +386,7 @@ class ShardedOptimizerV2(OptimizerWrapper):
 
         if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
             # We gather full fp16 param here
-            p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True
+            p.colo_attr.sharded_data_tensor.is_sharded = True  # since only gradient is sharded, we should set to True
             self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
 
         self.master_params[p].trans_state(TensorState.HOLD)
@@ -381,18 +394,18 @@ class ShardedOptimizerV2(OptimizerWrapper):
     def state_dict(self):
         optim_state_dict = super().state_dict()
         scaler_state_dict = self.grad_scaler.state_dict()
-        optim_state_dict['scaler'] = scaler_state_dict
+        optim_state_dict["scaler"] = scaler_state_dict
         return optim_state_dict
 
     def load_state_dict(self, *args, **kwargs):
-        if 'scaler' not in args[0]:
-            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
+        if "scaler" not in args[0]:
+            self._logger.warning("Missing scaler when loading optimizer state dict", ranks=[0])
         else:
-            scaler_state_dict = args[0].pop('scaler')
+            scaler_state_dict = args[0].pop("scaler")
             self.grad_scaler.load_state_dict(scaler_state_dict)
         super().load_state_dict(*args, **kwargs)
         for group in self.optim.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 state = self.optim.state[p]
                 for k, v in state.items():
                     if isinstance(v, Tensor):
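The last hunk shown touches state_dict and load_state_dict, which carry the DynamicGradScaler state alongside the usual optimizer state under a "scaler" key. A short sketch of the checkpoint round trip implied by that code; `optimizer` is assumed to be an already-constructed ShardedOptimizerV2 instance, so this is illustrative rather than taken from the repository:

# Hedged sketch of the checkpoint round trip implied by the hunk above.
# Assumes `optimizer` is a constructed ShardedOptimizerV2 instance.
ckpt = optimizer.state_dict()    # regular optimizer state plus a "scaler" entry
assert "scaler" in ckpt          # grad-scaler state travels with the checkpoint

# load_state_dict pops "scaler" and restores the DynamicGradScaler before
# delegating to the wrapped optimizer; a missing key only logs a warning on rank 0.
optimizer.load_state_dict(ckpt)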