[utils] Impl clip_grad_norm for ColoTensor and ZeroOptimizer (#1442)

ver217
2022-08-11 22:58:58 +08:00
committed by GitHub
parent 74bee5f7e8
commit 821c6172e2
3 changed files with 232 additions and 5 deletions


@@ -1,11 +1,13 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from pprint import pp
import random
import socket
from pathlib import Path
from typing import Callable, List, Union
from typing import Callable, List, Union, Dict, Optional
import functools
import torch
from torch._six import inf
from torch.nn.parameter import Parameter
@@ -22,9 +24,11 @@ from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PAR
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from .multi_tensor_apply import multi_tensor_applier
from colossalai.tensor import ColoParameter, ProcessGroup
from collections import defaultdict
def print_rank_0(msg: str, logger=None):
    """Print messages and save logs (optional). This is executed only if you are the rank-0 gpu.
@@ -162,6 +166,121 @@ def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Te
# ======== Gradient Clipping =========
def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
    """Compute the local Lp accumulation of the grads: sum of |g|^p, or max |g| for the inf-norm."""
    if len(params) == 0:
        return 0.0
    grads = [p.grad for p in params]
    use_cuda_kernel = grads[0].device.type == 'cuda'
    if norm_type == inf:
        local_lp = max([g.abs().max() for g in grads])
    elif norm_type == 2.0 and use_cuda_kernel:
        # the fused L2-norm kernel returns the norm itself, so square it to get sum(|g|^2)
        local_lp = _calc_l2_norm(grads)**norm_type
    else:
        local_lp = _calc_lp(grads, norm_type)
    if isinstance(local_lp, torch.Tensor):
        return local_lp.item()
    return local_lp

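# Illustrative sketch (not part of the diff above): _compute_local_lp returns the accumulated
# sum of |g|^p (or the max |g| for the inf-norm), not the norm itself; the p-th root is taken
# later in compute_grad_norm. A plain-PyTorch equivalent of the non-fused path, assuming the
# _calc_lp / _calc_l2_norm helpers are defined elsewhere in this file, could look like:
def _local_lp_reference(grads: List[torch.Tensor], norm_type: float) -> float:
    if norm_type == inf:
        return max(g.abs().max().item() for g in grads)
    return sum(g.abs().pow(norm_type).sum().item() for g in grads)
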
def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
    """Compute the Lp accumulation over params, all-reducing sharded params within their TP groups."""
    if len(params) == 0:
        return 0.0
    # replicated params need no reduction; sharded params are bucketed by their TP process group
    buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
    for p in params:
        if p.is_replicate():
            buckets[None].append(p)
        else:
            buckets[p.get_process_group().tp_process_group()].append(p)
    total_lp = 0.0
    for group, bucket in buckets.items():
        local_lp = _compute_local_lp(bucket, norm_type)
        if group is not None:
            local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
            if norm_type == inf:
                dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
            else:
                dist.all_reduce(local_lp_tensor, group=group)
            local_lp = local_lp_tensor.item()
        if norm_type == inf:
            total_lp = max(total_lp, local_lp)
        else:
            total_lp += local_lp
    return total_lp

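# Worked example (hypothetical values, not part of the diff): a tensor-parallel parameter is
# sharded across two TP ranks. Each rank computes sum(|g|^2) over its own shard, e.g. 3.0 on
# rank 0 and 5.0 on rank 1; the SUM all_reduce above yields 8.0, the full-parameter
# contribution, on both ranks. Replicated parameters land in the `None` bucket and need no
# reduction, since every rank already holds the identical full gradient.
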
def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
    """Reduce the Lp accumulation across pipeline stages, which hold disjoint sets of parameters."""
    if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
        total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
        if norm_type == inf:
            dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
        else:
            dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
        total_lp = total_lp_tensor.item()
    return total_lp

def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grad_dtype = None
    cpu_grad_params: List[ColoParameter] = []
    cuda_grad_params: List[ColoParameter] = []
    for p in parameters:
        if p.grad is None:
            continue
        assert isinstance(p, ColoParameter)
        if grad_dtype is None:
            grad_dtype = p.grad.dtype
        assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
        # split by device so CPU and CUDA grads can go through different norm kernels
        if p.grad.device.type == 'cuda':
            cuda_grad_params.append(p)
        else:
            cpu_grad_params.append(p)
    norm_type = float(norm_type)
    cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
    cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
    if norm_type == inf:
        total_lp = max(cpu_lp, cuda_lp)
    else:
        total_lp = cpu_lp + cuda_lp
    return _compute_pp_grad_lp(total_lp, norm_type)

def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
    norm_type = float(norm_type)
    total_norm = _compute_grad_lp(parameters, norm_type)
    if norm_type != inf:
        # the accumulated value is sum(|g|^p); take the p-th root to get the actual norm
        total_norm = total_norm**(1 / norm_type)
    return total_norm

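# Quick sanity check (illustrative sketch, not part of the diff): the accumulated Lp value
# turns into the norm by taking the p-th root, which matches torch's own norm.
def _grad_norm_sanity_check() -> None:
    g = torch.tensor([3.0, 4.0])
    total_lp = g.abs().pow(2.0).sum().item()              # 25.0
    assert abs(total_lp**(1 / 2.0) - g.norm(2).item()) < 1e-6    # both are 5.0
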
def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
    clip_coef = max_norm / (total_norm + 1e-6)
    # only scale when the total norm exceeds max_norm
    if clip_coef < 1.0:
        cuda_grads: List[torch.Tensor] = []
        cpu_grads: List[torch.Tensor] = []
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        for p in parameters:
            if p.grad is None:
                continue
            if p.grad.device.type == 'cuda':
                cuda_grads.append(p.grad.detach())
            else:
                cpu_grads.append(p.grad.detach())
        if len(cuda_grads) > 0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef)
        for g in cpu_grads:
            g.mul_(clip_coef)

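# Reference sketch (not part of the diff): as used above with [cuda_grads, cuda_grads], the
# fused colossal_C.multi_tensor_scale kernel scales every CUDA grad in place in a single
# launch; a plain per-tensor equivalent is simply:
def _scale_grads_reference(grads: List[torch.Tensor], clip_coef: float) -> None:
    for g in grads:
        g.detach().mul_(clip_coef)
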
def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
    total_norm = compute_grad_norm(parameters, norm_type)
    _clip_grad_norm(parameters, max_norm, total_norm)
    return total_norm

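# Usage sketch (hypothetical names, not part of the diff): with ColoParameter-based model
# parameters whose .grad fields have been populated by backward():
#     loss.backward()
#     grad_norm = clip_grad_norm(model.parameters(), max_norm=1.0, norm_type=2.0)
#     optimizer.step()
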
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters whose gradients are in fp32.