[zero] improve adaptability for non-sharded parameters (#708)

* adapt post-backward gradient hooks for non-sharded parameters
* adapt the optimizer for non-sharded parameters
* offload gradients for non-replicated parameters
HELSON
2022-04-11 13:38:51 +08:00
committed by GitHub
parent ab8c6b4a0e
commit a9b8300d54
9 changed files with 114 additions and 111 deletions
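
Taken together, the changes below route a gradient from two per-parameter flags: whether the parameter is replicated across the data-parallel group (`colo_attr.is_replicated`) and whether the parameter itself is sharded. A rough summary of that dispatch, inferred from the diff (the function and the returned descriptions are illustrative only, not part of the codebase):

```python
def gradient_path(is_replicated: bool, param_is_sharded: bool) -> str:
    """Illustrative-only summary of how this commit routes a parameter's gradient."""
    if is_replicated and param_is_sharded:
        # ZeRO-3: both the parameter and its gradient live as shards on each DP rank
        return "reduce-scatter the gradient, keep only this rank's shard"
    if is_replicated and not param_is_sharded:
        # ZeRO-2 style (keep_not_shard=True): full fp16 param, sharded grad/optimizer states
        return "reduce-scatter the gradient, keep the full fp16 parameter"
    # non-replicated parameters (e.g. tensor-parallel shards) need no DP communication
    return "save the gradient locally, optionally offloading it to CPU"
```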

View File

@@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self
return self.data
class NormalNoiseGenerator:
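
The one-line change above makes `ForceFP32Parameter.half()` return the underlying tensor instead of the `Parameter` wrapper; downstream code such as `ZeroInitContext` does `param.data = half_fn(param)` and `p.half()` on these objects, and returning `self.data` keeps the payload in fp32 without storing a `Parameter` inside `.data`. A minimal, self-contained sketch of the resulting behavior (the demo tensor is arbitrary):

```python
import torch

class ForceFP32Parameter(torch.nn.Parameter):
    """A parameter whose half() never downcasts: it hands back its fp32 payload."""

    def half(self, memory_format=None):
        # Return the raw fp32 tensor rather than the Parameter wrapper, so
        # `param.data = param.half()` keeps fp32 data and stores a plain tensor.
        return self.data

p = ForceFP32Parameter(torch.randn(4, dtype=torch.float32))
assert p.half().dtype == torch.float32                   # still fp32
assert not isinstance(p.half(), torch.nn.Parameter)      # plain tensor, not a Parameter
```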

View File

@@ -142,6 +142,7 @@ class CPUAdam(torch.optim.Optimizer):
beta1, beta2 = group['betas']
if target_device.type == 'cpu':
assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size"
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
@@ -151,8 +152,8 @@ class CPUAdam(torch.optim.Optimizer):
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
# adam on cuda
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
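
On the CUDA branch, `torch_adam_update` performs a standard Adam step using the bias-correction terms computed just above. The plain-PyTorch sketch below makes that math explicit; it is an illustration of the usual Adam update rule, not the fused kernel CPUAdam uses on the CPU path, and the L2-style weight decay is an assumption:

```python
import torch

def torch_adam_update_sketch(param, grad, exp_avg, exp_avg_sq,
                             lr, beta1, beta2, eps, step, weight_decay=0.0):
    """Reference Adam step with explicit bias correction (illustrative only)."""
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    if weight_decay != 0.0:
        grad = grad.add(param, alpha=weight_decay)                # classic L2 decay

    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)               # first moment m_t
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment v_t

    denom = (exp_avg_sq.sqrt() / bias_correction2**0.5).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```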

View File

@@ -213,7 +213,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
for param in self.param_list:
assert hasattr(param, 'colo_attr')
if not param.colo_attr.param_is_sharded and param.is_replicated:
if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
param.colo_attr.remove_torch_payload()
@@ -239,9 +239,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self.model_numel_tensor += param.numel()
# mark whether the param is replicated
param.is_replicated = self.is_replicated
# convert parameters to half
param_half = half_fn(param)
param.data = param_half
@@ -261,6 +258,13 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.sharded_data_tensor.payload # set param.data to payload
# mark whether the param is replicated
param.colo_attr.is_replicated = self.is_replicated
# mark whether the param should be kept unsharded
# if True, the param is handled as in ZeRO stage 2
param.colo_attr.keep_not_shard = not self.shard_param
self.param_list.append(param)
# We must cast buffers
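
The two flags added here drive everything later in the commit: `colo_attr.is_replicated` marks parameters that exist identically on every data-parallel rank, and `colo_attr.keep_not_shard` marks parameters that should stay whole because `shard_param=False` (ZeRO-2 style). A tiny sketch of the flag assignment; the dataclass is a stand-in, not the real colo_attr:

```python
from dataclasses import dataclass

@dataclass
class ColoAttrFlags:
    is_replicated: bool   # identical copies of the param on every DP rank
    keep_not_shard: bool  # True -> keep the full fp16 param (ZeRO-2 style)

def mark_param_flags(shard_param: bool, is_replicated: bool) -> ColoAttrFlags:
    # shard_param=True  -> ZeRO-3: the parameter tensor itself is sharded
    # shard_param=False -> ZeRO-2: the parameter stays whole, only grads/optim states shard
    return ColoAttrFlags(is_replicated=is_replicated, keep_not_shard=not shard_param)
```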

View File

@@ -123,7 +123,7 @@ class ShardedModelV2(nn.Module):
ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
]
register_ophooks_recursively(self.module, self._ophook_list)
self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -177,8 +177,8 @@ class ShardedModelV2(nn.Module):
self.logger.error(f'dump memory tracer collected information to {filename}', ranks=[0])
if gpc.get_global_rank() == 0:
with open(filename, 'w+') as f:
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
f.write('CUDA model data (GB)\n')
f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
f.write('\n')
@@ -254,10 +254,6 @@ class ShardedModelV2(nn.Module):
torch.cuda.current_stream().synchronize()
self.reducer.free()
# all reduce gradients for unsharded parameters
reduce_list = [p for p in self.unshard_params if p.is_replicated]
bucket_allreduce(reduce_list, self.process_group)
# 3. shard tensors not dealt with in the zero hook
tensor_list = []
for p in self.sharded_params:
@@ -279,15 +275,6 @@ class ShardedModelV2(nn.Module):
if not self._require_backward_grad_sync:
continue
# move unsharded param grad to saved_grad
if not p.colo_attr.param_is_sharded:
if p.colo_attr.offload_grad:
colo_model_data_move_to_cpu(p.grad)
if p.colo_attr.saved_grad.is_null():
p.colo_attr.saved_grad.reset_payload(p.grad.data)
else:
p.colo_attr.saved_grad.payload.add_(p.grad.data)
p.grad = None
@torch.no_grad()
@@ -316,6 +303,18 @@ class ShardedModelV2(nn.Module):
assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
if not self._require_backward_grad_sync:
return
if param.colo_attr.is_replicated:
self._reduce_scatter_handler(param, grad)
else:
self._save_grad(param, grad)
# used to cheat PyTorch, since we can't return None
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
return empty_grad
def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
new_grad = grad.clone()
@@ -334,9 +333,6 @@ class ShardedModelV2(nn.Module):
self._reduce_scatter_callback(param, new_grad)
orig_grad_data.record_stream(self.comm_stream)
torch.cuda.current_stream().wait_stream(self.comm_stream)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
return empty_grad
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
assert isinstance(reduced_grad,
@@ -345,21 +341,35 @@ class ShardedModelV2(nn.Module):
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)
# FIXME(ver217): remove the below line when impl eviction policy
self._save_grad(param, reduced_grad)
# FIXME(ver217): refactor the below line when impl eviction policy
def _save_grad(self, param: Parameter, grad: torch.Tensor):
# move gradient to cpu
if param.colo_attr.offload_grad:
colo_model_data_move_to_cpu(reduced_grad)
colo_model_data_move_to_cpu(grad)
if self.reuse_fp16_shard:
# make parameters point to gradient
assert param.colo_attr.saved_grad.is_null(
), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
param.colo_attr.sharded_data_tensor.is_sharded = True
param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
param.colo_attr.saved_grad.reset_payload(grad)
param.colo_attr.sharded_data_tensor.reset_payload(grad) # release the memory of param
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
else:
reduced_grad = cast_tensor_to_fp32(reduced_grad)
fp32_grad = cast_tensor_to_fp32(grad)
if param.colo_attr.saved_grad.is_null():
param.colo_attr.saved_grad.reset_payload(reduced_grad)
param.colo_attr.saved_grad.reset_payload(fp32_grad)
else:
param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
# keep saved_grad in HOLD state
param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
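
For gradients that are not reduce-scattered (non-replicated parameters) and whenever `reuse_fp16_shard` is off, `_save_grad` optionally offloads the gradient to CPU, casts it to fp32, and accumulates it into `saved_grad`. The sketch below isolates that branch; the stub class is an assumption added only to make the snippet self-contained:

```python
import torch

class StatefulTensorStub:
    """Minimal stand-in for the real StatefulTensor: just an optional payload."""
    def __init__(self):
        self.payload = None
    def is_null(self):
        return self.payload is None
    def reset_payload(self, t):
        self.payload = t

def save_grad_sketch(saved_grad: StatefulTensorStub, grad: torch.Tensor,
                     offload_grad: bool) -> None:
    if offload_grad:
        grad = grad.cpu()                 # offload the gradient before accumulation
    fp32_grad = grad.float()              # master gradients are kept in fp32
    if saved_grad.is_null():
        saved_grad.reset_payload(fp32_grad)
    else:                                 # accumulate across backward passes
        saved_grad.payload.add_(fp32_grad.view_as(saved_grad.payload))

g = StatefulTensorStub()
save_grad_sketch(g, torch.randn(4, dtype=torch.float16), offload_grad=True)
save_grad_sketch(g, torch.randn(4, dtype=torch.float16), offload_grad=True)
print(g.payload.dtype, g.payload.device)  # torch.float32 cpu
```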

View File

@@ -68,9 +68,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
keep_unsharded (bool, optional): if True, optimizer won't shard unsharded parameters.
In Zero-2, set keep_unsharded to False.
In Zero-3, set keep_unsharded to True.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
dp_process_group (Optional[ProcessGroup], optional): data parallel process group. Defaults to None.
mp_process_group (Optional[ProcessGroup], optional): model parallel process group. Defaults to None.
@@ -91,7 +88,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
growth_interval: float = 1000,
hysteresis: float = 2,
max_scale: int = 2**32,
keep_unsharded: bool = False,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@@ -125,10 +121,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger("ShardedOptimizerV2")
assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \
"Keeping unsharded parameters can't be used with hybrid OS placement right now."
self.keep_unshard = keep_unsharded
# Store fp32 param shards
self._register_master_weight()
@@ -139,6 +131,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if self._use_memory_tracer:
GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
@property
def loss_scale(self):
return self.grad_scaler.scale.item()
def get_memory_usage(self) -> Tuple[int, int]:
""" Get the memory usage of the optimizer. Including master_params (param fp32),
momentum (``self.state[p]['exp_avg']``) and variance (``self.state[p]['exp_avg_sq']``)
@@ -166,6 +162,22 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
return cuda_use, cpu_use
def zero_grad(self, *args, **kwargs):
self._zero_grad()
def backward(self, loss: Tensor) -> None:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self.model.backward(loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
self.model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
return super().clip_grad_norm(model, max_norm)
def step(self, *args, **kwargs):
self._prepare_grads()
self._maybe_move_fp32_shards()
@@ -193,26 +205,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._logger.debug(
f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
ranks=[0])
self._copy_master_param_to_param_fp16()
self._copy_master_model_to_model_fp16()
return ret
def backward(self, loss: Tensor) -> None:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self.model.backward(loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
self.model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
return super().clip_grad_norm(model, max_norm)
@property
def loss_scale(self):
return self.grad_scaler.scale.item()
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(0.0)
@@ -240,9 +235,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
p.grad.data.div_(self.loss_scale)
self.optim_state = OptimState.UNSCALED
def zero_grad(self, *args, **kwargs):
self._zero_grad()
def _zero_grad(self, recover_data: bool = False):
"""zero grad and maybe recover fp16 params
When `reuse_fp16_shard` is enabled,
@@ -262,13 +254,11 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# p.colo_attr.sharded_data_tensor stores grad now
# we have to recover fp16 param
reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
p.colo_attr.saved_grad.set_null()
if recover_data and reuse_fp16_shard:
# We should write like this to trigger ForceFP32Parameter's half method
p.data = self.master_params[p].payload
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
p.colo_attr.remove_torch_payload()
self._copy_master_param_to_param_fp16(p)
else:
# release saved gradient
p.colo_attr.saved_grad.set_null()
def sync_grad(self):
pass
@@ -278,14 +268,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for group in self.optim.param_groups:
for p in group['params']:
assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded and not self.keep_unshard:
# Please use keep_unsharded to control whether to shard unsharded parameters
# As we only store param shard, we shard it here
shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated
if shard_flag:
# we always shard replicated parameters
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
if not is_param_sharded and not self.keep_unshard:
if shard_flag:
# In this branch the param itself does not need to stay sharded,
# so we gather it back to full size here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
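
With `keep_unsharded` removed, `_register_master_weight` now keys entirely off the flags: a replicated parameter that is not yet sharded is sharded just long enough to store this rank's fp32 master shard, then gathered back to full size. A single-process sketch of that shard-copy-gather pattern, where `chunk` stands in for the real shard strategy:

```python
import torch

def register_master_weight_sketch(fp16_param: torch.Tensor, rank: int, world_size: int,
                                  is_replicated: bool, is_sharded: bool) -> torch.Tensor:
    """Illustrative only: return the fp32 master shard this rank would store."""
    shard_flag = (not is_sharded) and is_replicated
    my_slice = fp16_param.chunk(world_size)[rank] if shard_flag else fp16_param
    master_fp32 = my_slice.float()        # fp32 master copy kept by the optimizer
    # the real code then gathers the fp16 param back to full size when shard_flag is True
    return master_fp32

master = register_master_weight_sketch(torch.randn(8, dtype=torch.float16),
                                       rank=0, world_size=4,
                                       is_replicated=True, is_sharded=False)
print(master.shape, master.dtype)         # torch.Size([2]) torch.float32
```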
@@ -328,31 +317,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Now p.data is sharded
# So optimizer states are sharded naturally
def _copy_master_param_to_param_fp16(self):
def _copy_master_model_to_model_fp16(self):
# Copy master param data (fp32) to payload of colo_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# a chunk.
for group in self.optim.param_groups:
for p in group['params']:
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded and not self.keep_unshard:
# We use ZeRO-2 here
# The `p.colo_attr.sharded_data_tensor` saves full fp16 param
# But we only have updated fp32 param shard here
# So we first shard full fp16 param and copy fp32 param shard to it
# Then we will gather them
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
self._copy_master_param_to_param_fp16(p)
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.remove_torch_payload()
def _copy_master_param_to_param_fp16(self, p):
# flush gradient
p.colo_attr.saved_grad.set_null()
if not is_param_sharded and not self.keep_unshard:
# We gather full fp16 param here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.data = self.master_params[p].payload
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.remove_torch_payload()
self.master_params[p].trans_state(TensorState.HOLD)
p.colo_attr.saved_grad.set_null()
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
# We gather full fp16 param here
p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p].trans_state(TensorState.HOLD)
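
After the optimizer step, `_copy_master_model_to_model_fp16` loops over all parameters and calls the per-parameter helper shown above: the updated fp32 master shard is cast back to fp16 and becomes the new payload, and for replicated ZeRO-2 style parameters (`keep_not_shard=True`) the full fp16 tensor is then gathered across ranks. A schematic per-parameter sketch; the helper and `gather_fn` are placeholders, not the library API:

```python
import torch

def copy_master_to_fp16_sketch(master_fp32: torch.Tensor, keep_not_shard: bool,
                               is_replicated: bool, gather_fn=None) -> torch.Tensor:
    """Illustrative only: produce the fp16 payload written back after step()."""
    fp16_payload = master_fp32.half()                 # fp32 master shard -> fp16 payload
    if keep_not_shard and is_replicated and gather_fn is not None:
        fp16_payload = gather_fn(fp16_payload)        # stand-in for shard_strategy.gather
    return fp16_payload

payload = copy_master_to_fp16_sketch(torch.randn(4), keep_not_shard=True, is_replicated=True)
print(payload.dtype)                                  # torch.float16
```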