mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-28 11:45:23 +00:00)

[plugin] support get_grad_norm (#6115)
commit a15ab139ad (parent 13ffa08cfa)
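This commit adds a get_grad_norm() accessor to the plugin optimizer wrappers: each wrapper caches the total gradient norm it computes during step() and returns it afterwards. A minimal usage sketch of the intended call pattern (the model/criterion/booster setup is illustrative, only get_grad_norm() itself comes from this diff):

def train_step(booster, model, optimizer, criterion, batch):
    # Forward/backward through the booster, then step the wrapped optimizer.
    outputs = model(**batch)
    loss = criterion(outputs)
    booster.backward(loss, optimizer)
    optimizer.step()
    # New in this commit: the wrapper exposes the norm it computed for clipping.
    grad_norm = optimizer.get_grad_norm()
    if grad_norm is not None:  # None if no norm was computed (e.g. clipping disabled)
        print(f"grad norm: {grad_norm:.4f}")
    optimizer.zero_grad()
    return loss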
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor, inf
@@ -84,6 +84,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                         self.master_to_working_map[master_p] = p
                     master_params.append(master_p)
             group["params"] = master_params
+        self._current_grad_norm: Optional[float] = None
 
     def backward(self, loss: Tensor, *args, **kwargs):
         loss = self.mixed_precision.pre_backward(loss)
@@ -187,6 +188,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+            self._current_grad_norm = total_norm
         self._unscale_and_clip_grads(total_norm)
 
         self.optim.step(*args, **kwargs)
@@ -212,3 +214,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
 
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
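The pattern above, repeated for each wrapper below, is simply to stash the norm that was already computed for clipping and hand it back from get_grad_norm(). A stripped-down sketch of the same idea against a plain PyTorch optimizer (the class name and the use of clip_grad_norm_ are mine, not ColossalAI's):

from typing import Optional

import torch

class NormCachingOptimizer:
    """Toy wrapper illustrating the caching pattern added in this commit."""

    def __init__(self, optim: torch.optim.Optimizer, max_norm: float = 1.0):
        self.optim = optim
        self.max_norm = max_norm
        self._current_grad_norm: Optional[float] = None

    def step(self):
        params = [p for g in self.optim.param_groups for p in g["params"] if p.grad is not None]
        # clip_grad_norm_ returns the pre-clipping total norm, standing in for _compute_grad_norm.
        total_norm = torch.nn.utils.clip_grad_norm_(params, self.max_norm)
        self._current_grad_norm = float(total_norm)
        self.optim.step()

    def get_grad_norm(self, norm_type=2, **kwargs) -> Optional[float]:
        return self._current_grad_norm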
@@ -293,6 +293,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.pp_pg = pp_process_group
         self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        self._current_grad_norm: Optional[float] = None
         super().__init__(optim)
 
     def backward(self, loss: Tensor, *args, **kwargs):
@@ -364,6 +365,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+            self._current_grad_norm = total_norm
 
             # Clip the gradients to prevent exploding gradients.
             self._clip_grad_norm(total_norm)
@@ -477,6 +479,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     def get_master_to_working_map(self):
         return None
 
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+
 
 class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(
@@ -135,6 +135,18 @@ class OptimizerWrapper:
         """
         return self.optim
 
+    def get_grad_norm(self, norm_type: Union[float, int] = 2.0, **kwargs) -> Optional[float]:
+        """
+        Returns the gradient norm of an iterable of parameters. This method should be called after optimizer.step().
+
+        Args:
+            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+
+        Returns:
+            Optional[float]: Total norm of the gradients (viewed as a single vector). If there are no valid gradients, returns None.
+        """
+        raise NotImplementedError("The method get_grad_norm is not implemented yet.")
+
 
 class DistributedOptim(Optimizer):
     def setup_distributed(
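For clarity, "total norm of the gradients (viewed as a single vector)" means the p-norm of all gradient elements concatenated together. A plain-PyTorch reference implementation of that definition (not the code the wrappers use, which also reduces the norm across parallel groups):

from typing import Iterable, Optional

import torch

def reference_total_grad_norm(parameters: Iterable[torch.nn.Parameter], norm_type: float = 2.0) -> Optional[float]:
    grads = [p.grad.detach().flatten() for p in parameters if p.grad is not None]
    if not grads:
        return None  # mirrors the Optional[float] contract above
    # Concatenate every gradient into one vector and take its p-norm
    # (norm_type=float("inf") gives the infinity norm).
    return torch.linalg.vector_norm(torch.cat(grads), ord=norm_type).item()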
@@ -1,7 +1,7 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, OrderedDict, Set, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -195,6 +195,7 @@ class GeminiOptimizer(OptimizerWrapper):
             self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
 
         self._register_states = disposable(self._register_states_)
+        self._current_grad_norm: Optional[float] = None
 
     def _set_grad_ptr(self):
         for group in self.param_groups:
@@ -255,6 +256,7 @@ class GeminiOptimizer(OptimizerWrapper):
 
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
+            self._current_grad_norm = total_norm
             clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
             if clip > 1:
                 div_scale = clip * div_scale
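If I read the surrounding Gemini code correctly, div_scale carries the scale that gradients are later divided by, so total_norm / div_scale is the unscaled norm and clip measures how far it overshoots max_norm; when it does, div_scale is inflated so that the same division also clips. A small numeric illustration with invented values:

# Invented numbers: scaled norm 40.0, current div_scale 8.0, max_norm 1.0.
total_norm, div_scale, max_norm = 40.0, 8.0, 1.0

clip = ((total_norm / div_scale) + 1e-6) / max_norm  # (5.0 + 1e-6) / 1.0, i.e. about 5.0
if clip > 1:
    div_scale = clip * div_scale  # about 5.0 * 8.0 = 40.0

# Dividing gradients by the enlarged div_scale both unscales and clips:
# the effective norm 40.0 / 40.0 lands back at roughly max_norm.
assert abs(total_norm / div_scale - max_norm) < 1e-3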
@@ -846,6 +848,9 @@
                 f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
             )
 
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+
 
 class GeminiAdamOptimizer(GeminiOptimizer):
     def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
@@ -218,6 +218,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             )
         elif self._dtype is torch.bfloat16:
             self.mixed_precision_mixin = BF16MixedPrecisionMixin()
+        self._current_grad_norm: Optional[float] = None
 
     def __del__(self):
         for hook in self.grad_handles:
@@ -551,6 +552,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             # unscale and clip grads
             global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
+            self._current_grad_norm = global_norm
             self._unscale_and_clip_grads(grad_partition_groups, global_norm)
 
             # update the parameters
@@ -934,3 +936,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _force_wait_all_gather(self):
         for param in self._working_param_to_padded_working_param.keys():
             wait_all_gather_handle(param)
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
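In the ZeRO path the cached value is a global norm assembled from per-parameter-group norms. Assuming the usual L2 composition (square, sum, square root), which is what calculate_global_norm_from_list appears to perform, the reduction looks like this sketch (the helper name in the diff is real; this standalone function is mine):

import math
from typing import List

def combine_group_norms(norm_groups: List[float]) -> float:
    # L2 norms of disjoint parameter groups compose as the root of the summed squares.
    return math.sqrt(sum(n ** 2 for n in norm_groups))

# Two groups with norms 3.0 and 4.0 combine to a global norm of 5.0.
assert combine_group_norms([3.0, 4.0]) == 5.0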
@@ -76,6 +76,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
 
         booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True)
         optimizer.step()
+        grad_norm = optimizer.get_grad_norm()
+        assert grad_norm is None or isinstance(grad_norm, float)
 
     except Exception as e:
         return repr(e)
@@ -54,6 +54,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
 
         booster.backward(loss, optimizer)
         optimizer.step()
+        grad_norm = optimizer.get_grad_norm()
+        assert grad_norm is None or isinstance(grad_norm, float)
 
     except NotImplementedError:
         print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")
@@ -50,6 +50,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
 
         booster.backward(loss, optimizer)
         optimizer.step()
+        grad_norm = optimizer.get_grad_norm()
+        assert grad_norm is None or isinstance(grad_norm, float)
 
     except Exception as e:
         return repr(e)
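The tests only assert the return type, so callers should tolerate both a None result (no norm was computed) and wrappers that still raise the base-class NotImplementedError. A defensive logging sketch (the function name and print are illustrative):

def log_grad_norm(optimizer, step: int) -> None:
    try:
        grad_norm = optimizer.get_grad_norm()
    except NotImplementedError:
        return  # this wrapper does not expose a gradient norm
    if grad_norm is not None:
        print(f"step {step}: grad_norm={grad_norm:.4f}")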