From 821c6172e235935fdfacee47ad5e7a8c67893038 Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 11 Aug 2022 22:58:58 +0800
Subject: [PATCH] [utils] Impl clip_grad_norm for ColoTensor and ZeroOptimizer
 (#1442)

---
 colossalai/utils/common.py                    | 123 +++++++++++++++++-
 colossalai/zero/zero_optimizer.py             |  35 ++++-
 .../test_utils/test_norm_gradient_clipping.py |  79 +++++++++++
 3 files changed, 232 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_utils/test_norm_gradient_clipping.py

diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 19748770d..ccc136858 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -1,11 +1,13 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import os
+from pprint import pp
 import random
 import socket
 from pathlib import Path
-from typing import Callable, List, Union
+from typing import Callable, List, Union, Dict, Optional
 import functools
+
 import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter
@@ -22,9 +24,11 @@ from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PAR
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
-
 from .multi_tensor_apply import multi_tensor_applier
+from colossalai.tensor import ColoParameter, ProcessGroup
+from collections import defaultdict
+
 
 def print_rank_0(msg: str, logger=None):
     """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
@@ -162,6 +166,121 @@ def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Te
 
 # ======== Gradient Clipping =========
 
 
+def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
+    if len(params) == 0:
+        return 0.0
+    grads = [p.grad for p in params]
+    use_cuda_kernel = grads[0].device.type == 'cuda'
+    if norm_type == inf:
+        local_lp = max([g.abs().max() for g in grads])
+    elif norm_type == 2.0 and use_cuda_kernel:
+        local_lp = _calc_l2_norm(grads)**norm_type
+    else:
+        local_lp = _calc_lp(grads, norm_type)
+    if isinstance(local_lp, torch.Tensor):
+        return local_lp.item()
+    return local_lp
+
+
+def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
+    if len(params) == 0:
+        return 0.0
+    buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
+    for p in params:
+        if p.is_replicate():
+            buckets[None].append(p)
+        else:
+            buckets[p.get_process_group().tp_process_group()].append(p)
+    total_lp = 0.0
+    for group, bucket in buckets.items():
+        local_lp = _compute_local_lp(bucket, norm_type)
+        if group is not None:
+            local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
+            if norm_type == inf:
+                dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
+            else:
+                dist.all_reduce(local_lp_tensor, group=group)
+            local_lp = local_lp_tensor.item()
+        if norm_type == inf:
+            total_lp = max(total_lp, local_lp)
+        else:
+            total_lp += local_lp
+    return total_lp
+
+
+def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
+    if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+        total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
+        if norm_type == inf:
+            dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
+        else:
+            dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
+        total_lp = total_lp_tensor.item()
+    return total_lp
+
+
+def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grad_dtype = None
+    cpu_grad_params: List[ColoParameter] = []
+    cuda_grad_params: List[ColoParameter] = []
+    for p in parameters:
+        if p.grad is None:
+            continue
+        assert isinstance(p, ColoParameter)
+        if grad_dtype is None:
+            grad_dtype = p.grad.dtype
+        assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
+        if p.grad.device.type == 'cuda':
+            cuda_grad_params.append(p)
+        else:
+            cpu_grad_params.append(p)
+    norm_type = float(norm_type)
+    cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
+    cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
+    if norm_type == inf:
+        total_lp = max(cpu_lp, cuda_lp)
+    else:
+        total_lp = cpu_lp + cuda_lp
+    return _compute_pp_grad_lp(total_lp, norm_type)
+
+
+def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
+    norm_type = float(norm_type)
+    total_norm = _compute_grad_lp(parameters, norm_type)
+    if norm_type != inf:
+        total_norm = total_norm**(1 / norm_type)
+    return total_norm
+
+
+def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1.0:
+        cuda_grads: List[torch.Tensor] = []
+        cpu_grads: List[torch.Tensor] = []
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+        for p in parameters:
+            if p.grad is None:
+                continue
+            if p.grad.device.type == 'cuda':
+                cuda_grads.append(p.grad.detach())
+            else:
+                cpu_grads.append(p.grad.detach())
+        if len(cuda_grads) > 0:
+            dummy_overflow_buf = torch.cuda.IntTensor([0])
+            multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef)
+        for g in cpu_grads:
+            g.mul_(clip_coef)
+
+
+def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
+    total_norm = compute_grad_norm(parameters, norm_type)
+    _clip_grad_norm(parameters, max_norm, total_norm)
+    return total_norm
+
+
 def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     """Clips gradient norm of an iterable of parameters whose gradients are in fp32.
diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py
index bc23123b0..55b4d7ee9 100644
--- a/colossalai/zero/zero_optimizer.py
+++ b/colossalai/zero/zero_optimizer.py
@@ -8,9 +8,11 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils import get_current_device, disposable
+from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm
 from collections import defaultdict, abc as container_abcs
 from copy import deepcopy
 from itertools import chain
+from torch._six import inf
 
 
 class OptimState(Enum):
@@ -143,11 +145,38 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self._update_fp16_params()
         return ret
 
-    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
+    def compute_grad_norm(self, norm_type: float = 2.0) -> float:
+        norm_type = float(norm_type)
+        if not self.chunk_manager.enable_distributed_storage:
+            return compute_grad_norm(self.module.parameters(), norm_type)
+
+        non_distributed_params = []
+        distributed_params = []
+        for p in self.module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                non_distributed_params.append(p)
+            else:
+                distributed_params.append(p)
+        non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type)
+        distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)],
+                                               device=get_current_device())
+        if norm_type == inf:
+            dist.all_reduce(distributed_norm_tensor,
+                            op=dist.ReduceOp.MAX,
+                            group=self.chunk_manager.process_group.dp_process_group())
+            total_norm = max(non_distributed_norm, distributed_norm_tensor.item())
+        else:
+            dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group())
+            total_norm = non_distributed_norm + distributed_norm_tensor.item()
+        total_norm = total_norm**(1 / norm_type)
+        return total_norm
+
+    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
         if self.optim_state == OptimState.SCALED:
             self._unscale_grads()
-        # TODO(ver217): fix zero clip grad norm
-        return super().clip_grad_norm(model, max_norm)
+        total_norm = self.compute_grad_norm(norm_type)
+        _clip_grad_norm(self.module.parameters(), max_norm, total_norm)
+        return total_norm
 
     def backward(self, loss: torch.Tensor):
         loss = self.loss_scale * loss
diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_utils/test_norm_gradient_clipping.py
new file mode 100644
index 000000000..7690dbb38
--- /dev/null
+++ b/tests/test_utils/test_norm_gradient_clipping.py
@@ -0,0 +1,79 @@
+from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup
+from colossalai.tensor.colo_parameter import ColoParameter
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.logging import disable_existing_loggers
+from colossalai.utils import free_port, get_current_device
+from torch.nn.utils import clip_grad_norm_
+from functools import partial
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils.common import clip_grad_norm
+from torch.nn.parameter import Parameter
+
+
+def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
+    return abs(num - other) <= atol + rtol * other
+
+
+def shard_param(p: ColoParameter) -> None:
+    pg = p.get_process_group()
+    p._redistribute(distspec.shard([0], [pg.tp_world_size()]))
+    p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()
+
+
+def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None:
+    pg = colo_p.get_process_group()
+    if p.shape != colo_p.shape:
+        grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()]
+    else:
+        grad = p.grad
+    assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}'
+
+
+@parameterize('dtype', [torch.float])
+@parameterize('device', ['mixed', 'cuda', 'cpu'])
+@parameterize('norm_type', [2.0, 3.0, float('inf')])
+def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float):
+    print(f'{world_size}, {dtype}, {device}, {norm_type}')
+    cuda_device = get_current_device()
+    devices = [cuda_device] * 4
+    if device == 'cpu':
+        devices = [torch.device('cpu')] * 4
+    elif device == 'mixed':
+        devices = [cuda_device] * 2 + [torch.device('cpu')] * 2
+    pg = ProcessGroup(tp_degree=world_size)
+    params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)]
+    colo_params = [
+        ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4)
+    ]
+    for p, colo_p in zip(params, colo_params):
+        grad = torch.rand_like(p)
+        p.grad = grad
+        colo_p.grad = grad.clone().detach()
+    shard_param(colo_params[0])
+    shard_param(colo_params[2])
+    torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type)
+    colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type)
+    assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}'
+    for p, colo_p in zip(params, colo_params):
+        check_grad_equal(p, colo_p)
+
+
+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_grad_clip_norm(world_size=world_size)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_zero_clip_grad(world_size: int):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_zero_clip_grad(2)
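
For reviewers trying the change out, here is a rough usage sketch (not part of the patch) of the new public entry point colossalai.utils.common.clip_grad_norm. It mirrors the API exercised by the test above; the training-step wrapper, model, optimizer, and criterion names are illustrative only, and it assumes colossalai.launch has already initialized the runtime and that the model's parameters are ColoParameters.

# Hypothetical training-step excerpt; assumes colossalai.launch(...) has run and
# the model was built so that its parameters are ColoParameters.
from colossalai.utils.common import clip_grad_norm


def train_step(model, optimizer, criterion, data, label, max_norm=1.0):
    optimizer.zero_grad()
    loss = criterion(model(data), label)
    loss.backward()
    # clip_grad_norm buckets replicated vs. tensor-parallel-sharded parameters,
    # all-reduces the local p-norms across the TP (and pipeline) groups, then
    # rescales every gradient in place when the global norm exceeds max_norm.
    total_norm = clip_grad_norm(model.parameters(), max_norm, norm_type=2.0)
    optimizer.step()
    return loss, total_norm

With ZeroOptimizer, the equivalent call is optimizer.clip_grad_norm(model, max_norm, norm_type), which additionally all-reduces the norm over the data-parallel group when distributed (chunked) gradient storage is enabled.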