mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 04:02:17 +00:00
[gemini]remove registered gradients hooks (#5696)
* fix gemini fix gemini * fix fix
This commit is contained in:
parent
22297789ab
commit
d4c5ef441e
@ -20,7 +20,12 @@ class ChunkManager:
|
|||||||
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
|
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_configuration,
|
||||||
|
init_device: Optional[torch.device] = None,
|
||||||
|
reuse_fp16_chunk: bool = True,
|
||||||
|
) -> None:
|
||||||
self.device = init_device or get_accelerator().get_current_device()
|
self.device = init_device or get_accelerator().get_current_device()
|
||||||
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
|
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
|
||||||
self.kwargs_config = chunk_configuration
|
self.kwargs_config = chunk_configuration
|
||||||
@ -33,6 +38,10 @@ class ChunkManager:
|
|||||||
self.accessed_chunks: Set[Chunk] = set()
|
self.accessed_chunks: Set[Chunk] = set()
|
||||||
self.accessed_mem: int = 0
|
self.accessed_mem: int = 0
|
||||||
self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
|
self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
|
||||||
|
self.reuse_fp16_chunk = reuse_fp16_chunk
|
||||||
|
# Whether model is accumulating gradients,
|
||||||
|
self.accumulating_grads = False
|
||||||
|
self.overflow_counter = 0
|
||||||
|
|
||||||
def register_tensor(
|
def register_tensor(
|
||||||
self,
|
self,
|
||||||
|
@ -19,6 +19,7 @@ def init_chunk_manager(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
init_device: Optional[torch.device] = None,
|
init_device: Optional[torch.device] = None,
|
||||||
hidden_dim: Optional[int] = None,
|
hidden_dim: Optional[int] = None,
|
||||||
|
reuse_fp16_chunk: bool = True,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ChunkManager:
|
) -> ChunkManager:
|
||||||
@ -50,5 +51,9 @@ def init_chunk_manager(
|
|||||||
)
|
)
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
chunk_manager = ChunkManager(config_dict, init_device)
|
chunk_manager = ChunkManager(
|
||||||
|
config_dict,
|
||||||
|
init_device,
|
||||||
|
reuse_fp16_chunk=reuse_fp16_chunk,
|
||||||
|
)
|
||||||
return chunk_manager
|
return chunk_manager
|
||||||
|
@ -98,8 +98,14 @@ class GeminiDDP(ModelWrapper):
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert mixed_precision in (torch.float16, torch.bfloat16)
|
assert mixed_precision in (torch.float16, torch.bfloat16)
|
||||||
|
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
|
||||||
|
self.enable_gradient_accumulation = enable_gradient_accumulation
|
||||||
if chunk_config_dict is not None:
|
if chunk_config_dict is not None:
|
||||||
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
|
self.chunk_manager = ChunkManager(
|
||||||
|
chunk_config_dict,
|
||||||
|
chunk_init_device,
|
||||||
|
reuse_fp16_chunk=reuse_fp16_chunk,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# some ugly hotfix for the compatibility with Lightning
|
# some ugly hotfix for the compatibility with Lightning
|
||||||
if search_range_m is None:
|
if search_range_m is None:
|
||||||
@ -112,6 +118,7 @@ class GeminiDDP(ModelWrapper):
|
|||||||
min_chunk_size_m=min_chunk_size_m,
|
min_chunk_size_m=min_chunk_size_m,
|
||||||
strict_ddp_flag=strict_ddp_mode,
|
strict_ddp_flag=strict_ddp_mode,
|
||||||
process_group=zero_group,
|
process_group=zero_group,
|
||||||
|
reuse_fp16_chunk=reuse_fp16_chunk,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
self.gemini_manager = GeminiManager(
|
self.gemini_manager = GeminiManager(
|
||||||
@ -128,7 +135,6 @@ class GeminiDDP(ModelWrapper):
|
|||||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
||||||
self.fp32_params: List[torch.Tensor] = list()
|
self.fp32_params: List[torch.Tensor] = list()
|
||||||
self.fp16_params: List[ColoParameter] = list()
|
self.fp16_params: List[ColoParameter] = list()
|
||||||
self.overflow_counter = 0
|
|
||||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
||||||
self.param2name: Dict[nn.Parameter, str] = dict()
|
self.param2name: Dict[nn.Parameter, str] = dict()
|
||||||
self.name2param: Dict[str, nn.Parameter] = dict()
|
self.name2param: Dict[str, nn.Parameter] = dict()
|
||||||
@ -137,14 +143,8 @@ class GeminiDDP(ModelWrapper):
|
|||||||
self.zero_group = zero_group or _get_default_group()
|
self.zero_group = zero_group or _get_default_group()
|
||||||
self.extra_dp_group = extra_dp_group
|
self.extra_dp_group = extra_dp_group
|
||||||
|
|
||||||
self.reuse_fp16_chunk = master_weights
|
|
||||||
self.master_weights = master_weights
|
self.master_weights = master_weights
|
||||||
|
|
||||||
self.enable_gradient_accumulation = enable_gradient_accumulation
|
|
||||||
if self.enable_gradient_accumulation:
|
|
||||||
self.reuse_fp16_chunk = False
|
|
||||||
self.accumulating_grads = False # Whether model is accumulating gradients
|
|
||||||
|
|
||||||
self._logger = get_dist_logger()
|
self._logger = get_dist_logger()
|
||||||
|
|
||||||
if self.gemini_manager._premade_memstats_:
|
if self.gemini_manager._premade_memstats_:
|
||||||
@ -178,7 +178,29 @@ class GeminiDDP(ModelWrapper):
|
|||||||
if is_ddp_ignored(p):
|
if is_ddp_ignored(p):
|
||||||
continue
|
continue
|
||||||
if p.requires_grad:
|
if p.requires_grad:
|
||||||
p.register_hook(partial(self.grad_handle, p))
|
p._grad_handle = p.register_hook(
|
||||||
|
partial(
|
||||||
|
GeminiDDP.grad_handle,
|
||||||
|
chunk_manager=self.chunk_manager,
|
||||||
|
param2name=self.param2name,
|
||||||
|
grads_device=self.grads_device,
|
||||||
|
master_weights=self.master_weights,
|
||||||
|
enable_gradient_accumulation=self.enable_gradient_accumulation,
|
||||||
|
p=p,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_hooks(self):
|
||||||
|
for p in self.module.parameters():
|
||||||
|
if is_ddp_ignored(p):
|
||||||
|
continue
|
||||||
|
if p.requires_grad:
|
||||||
|
assert hasattr(p, "_grad_handle")
|
||||||
|
p._grad_handle.remove()
|
||||||
|
delattr(p, "_grad_handle")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.remove_hooks()
|
||||||
|
|
||||||
def parameters(self, recurse: bool = True):
|
def parameters(self, recurse: bool = True):
|
||||||
return self.module.parameters(recurse)
|
return self.module.parameters(recurse)
|
||||||
@ -324,8 +346,8 @@ class GeminiDDP(ModelWrapper):
|
|||||||
f"{error_str}",
|
f"{error_str}",
|
||||||
)
|
)
|
||||||
self._setup_grads_ptr()
|
self._setup_grads_ptr()
|
||||||
if self.enable_gradient_accumulation and not self.accumulating_grads:
|
if self.enable_gradient_accumulation and not self.chunk_manager.accumulating_grads:
|
||||||
self.accumulating_grads = True # Turn on the state of gradient accumulation.
|
self.chunk_manager.accumulating_grads = True # Turn on the state of gradient accumulation.
|
||||||
self._logger.debug(
|
self._logger.debug(
|
||||||
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
|
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
|
||||||
)
|
)
|
||||||
@ -340,25 +362,34 @@ class GeminiDDP(ModelWrapper):
|
|||||||
def backward_by_grad(self, tensor, grad):
|
def backward_by_grad(self, tensor, grad):
|
||||||
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
|
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
|
||||||
|
|
||||||
def grad_handle(self, p, grad):
|
@staticmethod
|
||||||
|
def grad_handle(
|
||||||
|
grad,
|
||||||
|
chunk_manager: ChunkManager,
|
||||||
|
param2name: Dict,
|
||||||
|
grads_device: Dict,
|
||||||
|
master_weights: bool,
|
||||||
|
enable_gradient_accumulation: bool,
|
||||||
|
p: nn.Parameter,
|
||||||
|
):
|
||||||
setattr(p, "_gemini_reduced", True)
|
setattr(p, "_gemini_reduced", True)
|
||||||
empty_grad = torch.empty_like(grad)
|
empty_grad = torch.empty_like(grad)
|
||||||
free_storage(empty_grad)
|
free_storage(empty_grad)
|
||||||
with torch._C.DisableTorchFunction():
|
with torch._C.DisableTorchFunction():
|
||||||
chunk = self.chunk_manager.get_chunk(p)
|
chunk = chunk_manager.get_chunk(p)
|
||||||
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
|
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
|
f"Parameter `{param2name[p]}` failed at the gradient reduction. "
|
||||||
"Some unsupported torch function is operated upon this parameter."
|
"Some unsupported torch function is operated upon this parameter."
|
||||||
)
|
)
|
||||||
grad_chunk = chunk
|
grad_chunk = chunk
|
||||||
if not self.reuse_fp16_chunk:
|
if not chunk_manager.reuse_fp16_chunk:
|
||||||
if not self.accumulating_grads:
|
if not chunk_manager.accumulating_grads:
|
||||||
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
|
grad_chunk = chunk_manager.init_grad_chunk(chunk)
|
||||||
else:
|
else:
|
||||||
assert chunk.grad_chunk is not None
|
assert chunk.grad_chunk is not None
|
||||||
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
|
if chunk.grad_chunk not in chunk_manager.accessed_chunks:
|
||||||
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
|
grad_chunk = chunk_manager.rearrange_accumulated_grad_chunk(chunk)
|
||||||
else:
|
else:
|
||||||
grad_chunk = chunk.grad_chunk
|
grad_chunk = chunk.grad_chunk
|
||||||
chunk.grad_chunk.l2_norm = None
|
chunk.grad_chunk.l2_norm = None
|
||||||
@ -371,33 +402,33 @@ class GeminiDDP(ModelWrapper):
|
|||||||
chunk.tensor_trans_state(p, TensorState.HOLD)
|
chunk.tensor_trans_state(p, TensorState.HOLD)
|
||||||
|
|
||||||
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
|
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
|
||||||
if not self.accumulating_grads:
|
if not chunk_manager.accumulating_grads:
|
||||||
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
|
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
|
||||||
else:
|
else:
|
||||||
grad_chunk.add_tensor_to_chunk_slice(p, grad)
|
grad_chunk.add_tensor_to_chunk_slice(p, grad)
|
||||||
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
|
reduced = chunk_manager.reduce_chunk(grad_chunk)
|
||||||
if reduced:
|
if reduced:
|
||||||
if not self.reuse_fp16_chunk:
|
if not chunk_manager.reuse_fp16_chunk:
|
||||||
if chunk.keep_gathered:
|
if chunk.keep_gathered:
|
||||||
self.chunk_manager.fake_release_chunk(chunk)
|
chunk_manager.fake_release_chunk(chunk)
|
||||||
else:
|
else:
|
||||||
self.chunk_manager.release_chunk(chunk)
|
chunk_manager.release_chunk(chunk)
|
||||||
if grad_chunk.is_gathered:
|
if grad_chunk.is_gathered:
|
||||||
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
|
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
|
||||||
if self.extra_dp_group is not None:
|
if chunk.extra_dp_group is not None:
|
||||||
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
|
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
|
||||||
else:
|
else:
|
||||||
grad_chunk.cuda_shard.div_(chunk.pg_size)
|
grad_chunk.cuda_shard.div_(chunk.pg_size)
|
||||||
if self.extra_dp_group is not None:
|
if chunk.extra_dp_group is not None:
|
||||||
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
|
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
|
||||||
# check overflow elements
|
# check overflow elements
|
||||||
self.overflow_counter += grad_chunk.has_inf_or_nan
|
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
|
||||||
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
|
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
|
||||||
if chunk.l2_norm_flag:
|
if chunk.l2_norm_flag:
|
||||||
grad_chunk.set_l2_norm()
|
grad_chunk.set_l2_norm()
|
||||||
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
|
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
|
||||||
if not (self.master_weights) or (self.enable_gradient_accumulation):
|
if not (master_weights) or (enable_gradient_accumulation):
|
||||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
|
||||||
return empty_grad
|
return empty_grad
|
||||||
|
|
||||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||||
@ -513,11 +544,11 @@ class GeminiDDP(ModelWrapper):
|
|||||||
|
|
||||||
# get copies of fp32 parameters in CPU
|
# get copies of fp32 parameters in CPU
|
||||||
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
|
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
|
||||||
params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
|
params = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
|
||||||
param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
|
param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
|
||||||
# get the mapping between copies and fp16 parameters
|
# get the mapping between copies and fp16 parameters
|
||||||
p_mapping = dict()
|
p_mapping = dict()
|
||||||
if self.reuse_fp16_chunk:
|
if self.chunk_manager.reuse_fp16_chunk:
|
||||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||||
name = self.param2name[p]
|
name = self.param2name[p]
|
||||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||||
@ -713,7 +744,7 @@ class GeminiDDP(ModelWrapper):
|
|||||||
name = self.param2name[p]
|
name = self.param2name[p]
|
||||||
fp32_to_name[fp32_p] = name
|
fp32_to_name[fp32_p] = name
|
||||||
|
|
||||||
params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
|
params_to_load = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
|
||||||
chunk_list = self.chunk_manager.get_chunks(params_to_load)
|
chunk_list = self.chunk_manager.get_chunks(params_to_load)
|
||||||
for chunk in chunk_list:
|
for chunk in chunk_list:
|
||||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
|
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
|
||||||
@ -728,7 +759,9 @@ class GeminiDDP(ModelWrapper):
|
|||||||
shard_fn = tensor.shard_fn
|
shard_fn = tensor.shard_fn
|
||||||
gather_fn = tensor.gather_fn
|
gather_fn = tensor.gather_fn
|
||||||
|
|
||||||
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
|
parameter_name = (
|
||||||
|
fp32_to_name[tensor] if self.chunk_manager.reuse_fp16_chunk else self.param2name[tensor]
|
||||||
|
)
|
||||||
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
|
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
|
||||||
load(
|
load(
|
||||||
parameter_name,
|
parameter_name,
|
||||||
@ -900,7 +933,7 @@ class GeminiDDP(ModelWrapper):
|
|||||||
gathered_param = param if keep_vars else param.detach()
|
gathered_param = param if keep_vars else param.detach()
|
||||||
else:
|
else:
|
||||||
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
|
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
|
||||||
param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
|
param_to_save = fp16_to_fp32[param] if self.chunk_manager.reuse_fp16_chunk else param
|
||||||
if param_to_save not in gathered_param_buffer:
|
if param_to_save not in gathered_param_buffer:
|
||||||
chunk = self.chunk_manager.get_chunk(param_to_save)
|
chunk = self.chunk_manager.get_chunk(param_to_save)
|
||||||
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
|
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
|
||||||
|
@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
|||||||
self.module = module
|
self.module = module
|
||||||
|
|
||||||
def check_local_overflow(self) -> bool:
|
def check_local_overflow(self) -> bool:
|
||||||
return self.module.overflow_counter > 0
|
return self.module.chunk_manager.overflow_counter > 0
|
||||||
|
|
||||||
def pre_zero_grad(self) -> None:
|
def pre_zero_grad(self) -> None:
|
||||||
self.module.overflow_counter = 0
|
self.module.chunk_manager.overflow_counter = 0
|
||||||
|
|
||||||
|
|
||||||
class GeminiOptimizer(OptimizerWrapper):
|
class GeminiOptimizer(OptimizerWrapper):
|
||||||
@ -202,7 +202,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
chunk16 = self.param_to_chunk16[fake_param]
|
chunk16 = self.param_to_chunk16[fake_param]
|
||||||
begin, end = self.param_to_range[fake_param]
|
begin, end = self.param_to_range[fake_param]
|
||||||
|
|
||||||
grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
|
grad_chunk16 = chunk16 if self.module.chunk_manager.reuse_fp16_chunk else chunk16.grad_chunk
|
||||||
fake_param.data = grad_chunk16.payload[begin:end]
|
fake_param.data = grad_chunk16.payload[begin:end]
|
||||||
fake_param.grad = fake_param.data
|
fake_param.grad = fake_param.data
|
||||||
|
|
||||||
@ -221,14 +221,14 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
|
|
||||||
def _clear_global_norm(self) -> None:
|
def _clear_global_norm(self) -> None:
|
||||||
for c16 in self.chunk16_set:
|
for c16 in self.chunk16_set:
|
||||||
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
|
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
|
||||||
grad_chunk.l2_norm = None
|
grad_chunk.l2_norm = None
|
||||||
|
|
||||||
def _calc_global_norm(self) -> float:
|
def _calc_global_norm(self) -> float:
|
||||||
norm_sqr: float = 0.0
|
norm_sqr: float = 0.0
|
||||||
group_to_norm = dict()
|
group_to_norm = dict()
|
||||||
for c16 in self.chunk16_set:
|
for c16 in self.chunk16_set:
|
||||||
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
|
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
|
||||||
assert grad_chunk.l2_norm is not None
|
assert grad_chunk.l2_norm is not None
|
||||||
|
|
||||||
if grad_chunk.is_gathered:
|
if grad_chunk.is_gathered:
|
||||||
@ -275,7 +275,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
self._logger.info(f"Found overflow. Skip step")
|
self._logger.info(f"Found overflow. Skip step")
|
||||||
self._clear_global_norm() # clear recorded norm
|
self._clear_global_norm() # clear recorded norm
|
||||||
self.zero_grad() # reset all gradients
|
self.zero_grad() # reset all gradients
|
||||||
if self.module.reuse_fp16_chunk:
|
if self.module.chunk_manager.reuse_fp16_chunk:
|
||||||
self._update_fp16_params()
|
self._update_fp16_params()
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -288,7 +288,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
self.zero_grad()
|
self.zero_grad()
|
||||||
if self.module.master_weights:
|
if self.module.master_weights:
|
||||||
self._update_fp16_params()
|
self._update_fp16_params()
|
||||||
self.module.accumulating_grads = False
|
self.module.chunk_manager.accumulating_grads = False
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
|
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
|
||||||
|
@ -26,7 +26,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
|
|||||||
chunk_manager = model.chunk_manager
|
chunk_manager = model.chunk_manager
|
||||||
param_list = [p for p in model.parameters()]
|
param_list = [p for p in model.parameters()]
|
||||||
chunk_list = chunk_manager.get_chunks(param_list)
|
chunk_list = chunk_manager.get_chunks(param_list)
|
||||||
if not model.reuse_fp16_chunk:
|
if not model.chunk_manager.reuse_fp16_chunk:
|
||||||
chunk_list = [chunk.grad_chunk for chunk in chunk_list]
|
chunk_list = [chunk.grad_chunk for chunk in chunk_list]
|
||||||
for chunk in chunk_list:
|
for chunk in chunk_list:
|
||||||
chunk_manager.access_chunk(chunk)
|
chunk_manager.access_chunk(chunk)
|
||||||
|
Loading…
Reference in New Issue
Block a user