[utils] Impl clip_grad_norm for ColoTensor and ZeroOptimizer (#1442)
@@ -1,11 +1,13 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from pprint import pp
import random
import socket
from pathlib import Path
-from typing import Callable, List, Union
+from typing import Callable, List, Union, Dict, Optional
import functools

import torch
from torch._six import inf
from torch.nn.parameter import Parameter

@@ -22,9 +24,11 @@ from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PAR
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env

from .multi_tensor_apply import multi_tensor_applier

from colossalai.tensor import ColoParameter, ProcessGroup
from collections import defaultdict


def print_rank_0(msg: str, logger=None):
    """Print messages and save logs (optional). This is executed only if you are the rank-0 GPU.

@@ -162,6 +166,121 @@ def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Te
# ======== Gradient Clipping =========


def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
    if len(params) == 0:
        return 0.0
    grads = [p.grad for p in params]
    use_cuda_kernel = grads[0].device.type == 'cuda'
    if norm_type == inf:
        local_lp = max([g.abs().max() for g in grads])
    elif norm_type == 2.0 and use_cuda_kernel:
        # The fused CUDA kernel returns the l2 norm; raise it back to the
        # p-th power so partial results stay summable across ranks.
        local_lp = _calc_l2_norm(grads)**norm_type
    else:
        local_lp = _calc_lp(grads, norm_type)
    if isinstance(local_lp, torch.Tensor):
        return local_lp.item()
    return local_lp

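_calc_lp and _calc_l2_norm are pre-existing helpers earlier in this file, not shown in the diff. A minimal pure-PyTorch sketch of the semantics the new code relies on (no fused kernel; the sketch name is hypothetical):

import torch

def _calc_lp_sketch(grads, norm_type):
    # Sum of |g|^p over all gradient tensors: the p-th power of the
    # local p-norm, so partial results can be summed across ranks.
    return sum(g.abs().pow(norm_type).sum().item() for g in grads)

grads = [torch.tensor([3.0, 4.0]), torch.tensor([12.0])]
lp = _calc_lp_sketch(grads, 2.0)   # 9 + 16 + 144 = 169
norm = lp ** 0.5                   # 13.0, equals torch.cat(grads).norm(2)
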
def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
    if len(params) == 0:
        return 0.0
    # Group parameters by tensor-parallel process group; replicated
    # parameters need no reduction and share the None bucket.
    buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
    for p in params:
        if p.is_replicate():
            buckets[None].append(p)
        else:
            buckets[p.get_process_group().tp_process_group()].append(p)
    total_lp = 0.0
    for group, bucket in buckets.items():
        local_lp = _compute_local_lp(bucket, norm_type)
        if group is not None:
            # Sharded parameters: combine partial lp values across the
            # tensor-parallel group (MAX for the inf-norm, SUM otherwise).
            local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
            if norm_type == inf:
                dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
            else:
                dist.all_reduce(local_lp_tensor, group=group)
            local_lp = local_lp_tensor.item()
        if norm_type == inf:
            total_lp = max(total_lp, local_lp)
        else:
            total_lp += local_lp
    return total_lp

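The bucketing works because tensor-parallel shards are disjoint slices of one logical tensor: summing per-rank lp values (the p-th powers) over the TP group reconstructs the full lp, while a replicated parameter already holds the full value on every rank. The code assumes torch.distributed is available as dist from earlier in the file (not visible in this hunk). A single-process sketch of the identity the all_reduce exploits:

import torch

# A logical gradient sharded across two tensor-parallel ranks:
full = torch.tensor([3.0, 4.0, 12.0, 0.0])
shard0, shard1 = full[:2], full[2:]

# Each rank computes its local sum of squares (local_lp with p=2)...
lp0 = shard0.pow(2).sum().item()   # 25.0
lp1 = shard1.pow(2).sum().item()   # 144.0

# ...and all_reduce(SUM) over the TP group recovers the full lp.
assert lp0 + lp1 == full.pow(2).sum().item()   # 169.0
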
def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
    # Each pipeline stage owns a disjoint set of parameters, so a second
    # reduction over the pipeline group yields the global lp.
    if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
        total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
        if norm_type == inf:
            dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
        else:
            dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
        total_lp = total_lp_tensor.item()
    return total_lp

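Pipeline stages likewise hold disjoint parameter subsets, so the same combination rule applies across stages: SUM for finite p, MAX for the inf-norm. A toy illustration of the two reductions:

# Per-stage partial results for two pipeline stages:
stage_lps = [25.0, 144.0]           # sum(|g|^2) per stage, p = 2
total_l2 = sum(stage_lps) ** 0.5    # 13.0 after the final root

stage_maxes = [0.7, 2.5]            # max(|g|) per stage, p = inf
total_inf = max(stage_maxes)        # 2.5, no root needed
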
def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grad_dtype = None
    cpu_grad_params: List[ColoParameter] = []
    cuda_grad_params: List[ColoParameter] = []
    for p in parameters:
        if p.grad is None:
            continue
        assert isinstance(p, ColoParameter)
        # All gradients must share one dtype; mixing dtypes would make the
        # fused kernels and reductions inconsistent.
        if grad_dtype is None:
            grad_dtype = p.grad.dtype
        assert p.grad.dtype == grad_dtype, f'Expected all grads to be {grad_dtype}, got {p.grad.dtype}'
        if p.grad.device.type == 'cuda':
            cuda_grad_params.append(p)
        else:
            cpu_grad_params.append(p)
    norm_type = float(norm_type)
    cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
    cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
    if norm_type == inf:
        total_lp = max(cpu_lp, cuda_lp)
    else:
        total_lp = cpu_lp + cuda_lp
    return _compute_pp_grad_lp(total_lp, norm_type)

def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
    norm_type = float(norm_type)
    total_norm = _compute_grad_lp(parameters, norm_type)
    # The distributed passes accumulate sum(|g|^p); take the p-th root at
    # the end (the inf-norm is already a max, so no root is needed).
    if norm_type != inf:
        total_norm = total_norm**(1 / norm_type)
    return total_norm

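On a single process with ordinary dense gradients, compute_grad_norm should agree with PyTorch's reference implementation. A sanity-check sketch (plain Parameters stand in for the ColoParameter the real code requires):

import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for p in params:
    p.grad = torch.randn_like(p)

# Reference norm from PyTorch's own utility (max_norm=inf: no clipping).
ref = torch.nn.utils.clip_grad_norm_(params, max_norm=float('inf'), norm_type=2.0)

# The same quantity, computed the way compute_grad_norm does it:
total_lp = sum(p.grad.abs().pow(2.0).sum() for p in params)
assert torch.allclose(ref, total_lp ** (1 / 2.0))
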
def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        cuda_grads: List[torch.Tensor] = []
        cpu_grads: List[torch.Tensor] = []
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        for p in parameters:
            if p.grad is None:
                continue
            if p.grad.device.type == 'cuda':
                cuda_grads.append(p.grad.detach())
            else:
                cpu_grads.append(p.grad.detach())
        if len(cuda_grads) > 0:
            # Scale all CUDA grads in a single fused kernel launch.
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads],
                                 clip_coef)
        for g in cpu_grads:
            g.mul_(clip_coef)

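colossal_C.multi_tensor_scale is ColossalAI's fused CUDA kernel for scaling a list of tensors in one launch, presumably imported near the top of the file (not visible in this hunk). A plain-PyTorch sketch of the same scaling step, minus the fusion:

import torch

def scale_grads_(grads, clip_coef):
    # In-place g <- g * clip_coef; only invoked when clip_coef < 1.
    for g in grads:
        g.mul_(clip_coef)

g = torch.tensor([3.0, 4.0])           # norm 5.0
coef = 1.0 / (5.0 + 1e-6)              # max_norm / (total_norm + eps)
scale_grads_([g], coef)
print(g.norm())                        # ~1.0
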
def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
    total_norm = compute_grad_norm(parameters, norm_type)
    _clip_grad_norm(parameters, max_norm, total_norm)
    return total_norm

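clip_grad_norm returns the pre-clipping norm, mirroring torch.nn.utils.clip_grad_norm_. One caveat: parameters is iterated twice (once to compute the norm, once to clip), so callers should pass a list rather than a one-shot generator. A single-device sketch of the contract, using the PyTorch reference:

import torch

p = torch.nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([3.0, 4.0])

returned = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(float(returned))   # 5.0 -- the norm *before* clipping
print(p.grad.norm())     # ~1.0 -- gradient rescaled to max_norm
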
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters whose gradients are in fp32.