[gemini]remove registered gradients hooks (#5696)

* fix gemini

fix gemini

* fix

fix
This commit is contained in:
flybird11111
2024-05-09 10:29:49 +08:00
committed by GitHub
parent 22297789ab
commit d4c5ef441e
5 changed files with 93 additions and 46 deletions

View File

@@ -98,8 +98,14 @@ class GeminiDDP(ModelWrapper):
verbose: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
self.enable_gradient_accumulation = enable_gradient_accumulation
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
self.chunk_manager = ChunkManager(
chunk_config_dict,
chunk_init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
)
else:
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
@@ -112,6 +118,7 @@ class GeminiDDP(ModelWrapper):
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=zero_group,
reuse_fp16_chunk=reuse_fp16_chunk,
verbose=verbose,
)
self.gemini_manager = GeminiManager(
@@ -128,7 +135,6 @@ class GeminiDDP(ModelWrapper):
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
@@ -137,14 +143,8 @@ class GeminiDDP(ModelWrapper):
self.zero_group = zero_group or _get_default_group()
self.extra_dp_group = extra_dp_group
self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
self.enable_gradient_accumulation = enable_gradient_accumulation
if self.enable_gradient_accumulation:
self.reuse_fp16_chunk = False
self.accumulating_grads = False # Whether model is accumulating gradients
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
@@ -178,7 +178,29 @@ class GeminiDDP(ModelWrapper):
if is_ddp_ignored(p):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
p._grad_handle = p.register_hook(
partial(
GeminiDDP.grad_handle,
chunk_manager=self.chunk_manager,
param2name=self.param2name,
grads_device=self.grads_device,
master_weights=self.master_weights,
enable_gradient_accumulation=self.enable_gradient_accumulation,
p=p,
)
)
def remove_hooks(self):
for p in self.module.parameters():
if is_ddp_ignored(p):
continue
if p.requires_grad:
assert hasattr(p, "_grad_handle")
p._grad_handle.remove()
delattr(p, "_grad_handle")
def __del__(self):
self.remove_hooks()
def parameters(self, recurse: bool = True):
return self.module.parameters(recurse)
@@ -324,8 +346,8 @@ class GeminiDDP(ModelWrapper):
f"{error_str}",
)
self._setup_grads_ptr()
if self.enable_gradient_accumulation and not self.accumulating_grads:
self.accumulating_grads = True # Turn on the state of gradient accumulation.
if self.enable_gradient_accumulation and not self.chunk_manager.accumulating_grads:
self.chunk_manager.accumulating_grads = True # Turn on the state of gradient accumulation.
self._logger.debug(
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
)
@@ -340,25 +362,34 @@ class GeminiDDP(ModelWrapper):
def backward_by_grad(self, tensor, grad):
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
def grad_handle(self, p, grad):
@staticmethod
def grad_handle(
grad,
chunk_manager: ChunkManager,
param2name: Dict,
grads_device: Dict,
master_weights: bool,
enable_gradient_accumulation: bool,
p: nn.Parameter,
):
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
chunk = chunk_manager.get_chunk(p)
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError(
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
f"Parameter `{param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter."
)
grad_chunk = chunk
if not self.reuse_fp16_chunk:
if not self.accumulating_grads:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
if not chunk_manager.reuse_fp16_chunk:
if not chunk_manager.accumulating_grads:
grad_chunk = chunk_manager.init_grad_chunk(chunk)
else:
assert chunk.grad_chunk is not None
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
if chunk.grad_chunk not in chunk_manager.accessed_chunks:
grad_chunk = chunk_manager.rearrange_accumulated_grad_chunk(chunk)
else:
grad_chunk = chunk.grad_chunk
chunk.grad_chunk.l2_norm = None
@@ -371,33 +402,33 @@ class GeminiDDP(ModelWrapper):
chunk.tensor_trans_state(p, TensorState.HOLD)
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
if not self.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
if not chunk_manager.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
else:
grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
reduced = chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if not self.reuse_fp16_chunk:
if not chunk_manager.reuse_fp16_chunk:
if chunk.keep_gathered:
self.chunk_manager.fake_release_chunk(chunk)
chunk_manager.fake_release_chunk(chunk)
else:
self.chunk_manager.release_chunk(chunk)
chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if self.extra_dp_group is not None:
if chunk.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if self.extra_dp_group is not None:
if chunk.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
self.overflow_counter += grad_chunk.has_inf_or_nan
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
if not (self.master_weights) or (self.enable_gradient_accumulation):
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
@@ -513,11 +544,11 @@ class GeminiDDP(ModelWrapper):
# get copies of fp32 parameters in CPU
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
params = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
# get the mapping between copies and fp16 parameters
p_mapping = dict()
if self.reuse_fp16_chunk:
if self.chunk_manager.reuse_fp16_chunk:
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
@@ -713,7 +744,7 @@ class GeminiDDP(ModelWrapper):
name = self.param2name[p]
fp32_to_name[fp32_p] = name
params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
params_to_load = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
chunk_list = self.chunk_manager.get_chunks(params_to_load)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
@@ -728,7 +759,9 @@ class GeminiDDP(ModelWrapper):
shard_fn = tensor.shard_fn
gather_fn = tensor.gather_fn
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
parameter_name = (
fp32_to_name[tensor] if self.chunk_manager.reuse_fp16_chunk else self.param2name[tensor]
)
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
load(
parameter_name,
@@ -900,7 +933,7 @@ class GeminiDDP(ModelWrapper):
gathered_param = param if keep_vars else param.detach()
else:
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
param_to_save = fp16_to_fp32[param] if self.chunk_manager.reuse_fp16_chunk else param
if param_to_save not in gathered_param_buffer:
chunk = self.chunk_manager.get_chunk(param_to_save)
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))