Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 21:09:18 +00:00
Feature/zero (#279)
* add zero1 (#209)
* add zero1
* add test zero1
* update zero stage 1 develop (#212)
* Implement naive zero3 (#240)
* naive zero3 works well
* add zero3 param manager
* add TODOs in comments
* add gather full param ctx
* fix sub module streams
* add offload
* fix bugs of hook and add unit tests
* fix bugs of hook and add unit tests (#252)
* add gather full param ctx
* fix sub module streams
* add offload
* fix bugs of hook and add unit tests
* polish code and add state dict hook
* fix bug
* update unit test
* refactor reconstructed zero code
* clip_grad supports zero3 and add unit test
* add unit test for Zero3ParameterManager
* [WIP] initialize the shard param class
* [WIP] Yet another sharded model implementation (#274)
* [WIP] initialize the shard param class
* [WIP] Yet another implementation of shardModel, using a better hook method
* torch.concat -> torch.cat
* fix test_zero_level_1.py::test_zero_level_1 unit test
* remove deepspeed implementation and refactor for the reconstructed zero module
* polish zero dp unit tests

Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
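The "gather full param ctx" mentioned above reflects the core ZeRO-3 idea: each data-parallel rank keeps only a shard of every parameter and temporarily all-gathers the full tensor when a submodule needs it. Below is a minimal, hypothetical sketch of that pattern; the names ShardedParam and gather_full_param are illustrative assumptions, not the actual ColossalAI API.

from contextlib import contextmanager

import torch
import torch.distributed as dist


class ShardedParam:
    """Toy stand-in for a ZeRO-sharded parameter: each rank keeps 1/world_size of the data."""

    def __init__(self, full_tensor: torch.Tensor):
        # Assumes dist.init_process_group() has been called and that
        # numel is divisible by world_size; real code pads the tail.
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        self.shard = full_tensor.flatten().chunk(world_size)[rank].clone()


@contextmanager
def gather_full_param(param: ShardedParam):
    """Temporarily reconstruct the full flat tensor from all ranks' shards."""
    world_size = dist.get_world_size()
    shards = [torch.empty_like(param.shard) for _ in range(world_size)]
    dist.all_gather(shards, param.shard)  # collect every rank's shard
    full = torch.cat(shards)              # full flat parameter
    try:
        yield full
    finally:
        del full                          # drop the gathered copy, keep only the shard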
@@ -2,9 +2,12 @@
 # -*- encoding: utf-8 -*-
+import random
+import socket
+from typing import List, Union
 
 import torch
 from torch._six import inf
+from torch.nn.parameter import Parameter
 
 
 try:
     import colossal_C
@@ -14,7 +17,8 @@ except:
 from contextlib import contextmanager
 
 import torch.distributed as dist
-from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
+from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS,
+                                  TENSOR_PARALLEL_ATTRIBUTES)
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import moe_env
@@ -134,6 +138,10 @@ def _calc_lp(grads, norm_type):
         norm += grad_norm**norm_type
     return norm
 
+def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+    if torch.is_tensor(norm) and norm.device.type != 'cuda':
+        norm = norm.to(torch.cuda.current_device())
+    return norm
 
 # ======== Gradient Clipping =========
 
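For context, the new _move_norm_to_cuda helper exists because NCCL collectives such as dist.all_reduce only accept CUDA tensors, while a norm computed from CPU gradients lives on the CPU. A small self-contained illustration (not part of the diff; requires a CUDA device):

import torch


def _move_norm_to_cuda(norm):
    # Same logic as the helper added above.
    if torch.is_tensor(norm) and norm.device.type != 'cuda':
        norm = norm.to(torch.cuda.current_device())
    return norm


cpu_norm = torch.tensor(3.0)              # e.g. an Lp norm computed from CPU grads
cuda_norm = _move_norm_to_cuda(cpu_norm)  # now safe to pass to dist.all_reduce
print(cuda_norm.device)                   # cuda:0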
@@ -163,17 +171,27 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     # - grad should not be none
     # - parameter should not be shared
     # - should not be a replica due to tensor model parallelism
-    params = []
+    params: List[Parameter] = []
+    has_zero_shared_param: bool = False
     for param in parameters:
         if param.grad is not None:
             # Make sure the grads are in fp32
-            assert param.grad.type() == 'torch.cuda.FloatTensor', \
-                f'expected gradient to be dtype torch.cuda.FloatTensor, but got {param.grad.type()}'
+            assert param.grad.dtype == torch.float, \
+                f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
+            if hasattr(param, 'zero_is_sharded'):
+                has_zero_shared_param = True
             params.append(param)
 
+    if len(params) == 0:
+        return 0.0
     # Norm parameters.
     max_norm = float(max_norm)
     norm_type = float(norm_type)
 
+    # Parameters can be on CPU or CUDA
+    # If parameters are on CPU, disable CUDA kernels
+    enable_cuda_kernels = params[0].grad.device.type == 'cuda'
+
     # Calculate norm.
     if norm_type == inf:
         total_norm = max(p.grad.data.abs().max() for p in params)
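The hasattr(param, 'zero_is_sharded') check above relies on the ZeRO wrapper tagging its parameters with a marker attribute; PyTorch tensors accept arbitrary Python attributes, so no subclass is needed. A self-contained sketch of that convention (the wrapper that actually sets the flag is assumed):

import torch

p = torch.nn.Parameter(torch.randn(8))
p.zero_is_sharded = True                  # set by the ZeRO sharding wrapper

params = [p, torch.nn.Parameter(torch.randn(8))]
has_zero_shared_param = any(hasattr(q, 'zero_is_sharded') for q in params)
assert has_zero_shared_param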
@@ -184,28 +202,49 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
                         op=dist.ReduceOp.MAX,
                         group=gpc.get_group(ParallelMode.MODEL),
                         async_op=False)
+        if has_zero_shared_param:
+            dist.all_reduce(total_norm_cuda,
+                            op=dist.ReduceOp.MAX,
+                            group=gpc.get_group(ParallelMode.DATA),
+                            async_op=False)
         total_norm = total_norm_cuda[0].item()
     else:
         tensor_parallel_grads = []
         no_tensor_parallel_grads = []
         moe_parallel_grads = []  # used to collect moe tensor parallel gradients
+        zero_sharded_grads = []
         for p in params:
             if is_model_parallel_parameter(p):
                 reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
                 tensor_parallel_grads.append(p.grad.data / reductor)
             elif is_moe_parallel_parameter(p):
                 moe_parallel_grads.append(p.grad.data)
+            elif hasattr(p, 'zero_is_sharded'):
+                zero_sharded_grads.append(p.grad.data)
             else:
                 no_tensor_parallel_grads.append(p.grad.data)
 
-        if norm_type == 2.0:
-            tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
-            no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
-            moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
+        if norm_type == 2.0 and enable_cuda_kernels:
+            tensor_parallel_norm = _calc_l2_norm(
+                tensor_parallel_grads) ** norm_type
+            no_tensor_parallel_norm = _calc_l2_norm(
+                no_tensor_parallel_grads) ** norm_type
+            moe_parallel_norm = _calc_l2_norm(
+                moe_parallel_grads) ** norm_type
+            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type
         else:
             tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
             no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
             moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
+            zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
+
+        # If grads are on CPU, the norms are also on CPU. Cast them to CUDA tensors
+        if not enable_cuda_kernels:
+            tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
+            no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm)
+            moe_parallel_norm = _move_norm_to_cuda(moe_parallel_norm)
+            zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm)
 
         # Sum across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
             dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
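Each branch above accumulates norm ** norm_type rather than the norm itself because the p-th power of a p-norm is additive over disjoint gradient groups, which is what lets the partial sums be combined across groups, all-reduced with SUM, and collapsed with a single root at the end. A self-contained check:

import torch

g1, g2 = torch.randn(10), torch.randn(7)
norm_type = 2.0
# Per-group p-th powers are exactly additive.
partial = g1.norm(norm_type) ** norm_type + g2.norm(norm_type) ** norm_type
total = partial ** (1.0 / norm_type)
assert torch.allclose(total, torch.cat([g1, g2]).norm(norm_type))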
@@ -213,20 +252,32 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         if len(moe_parallel_grads) > 0:
             dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
             no_tensor_parallel_norm += moe_parallel_norm
+        # Sum across all zero sharded GPUs
+        if len(zero_sharded_grads) > 0:
+            dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA))
+            no_tensor_parallel_norm += zero_sharded_norm
         total_norm = tensor_parallel_norm + no_tensor_parallel_norm
         if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
-        total_norm = total_norm**(1.0 / norm_type)
-        if type(total_norm) == 'torch.cuda.FloatTensor':
+            dist.all_reduce(total_norm,
+                            op=dist.ReduceOp.SUM,
+                            group=gpc.get_group(ParallelMode.PIPELINE))
+        total_norm = total_norm ** (1.0 / norm_type)
+        if torch.is_tensor(total_norm):
             total_norm = total_norm.item()
 
     # Scale.
     clip_coeff = max_norm / (total_norm + 1.0e-6)
     if clip_coeff < 1.0:
-        grads = [p.grad.detach() for p in params]
-        dummy_overflow_buf = torch.cuda.IntTensor([0])
-        multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
+        if enable_cuda_kernels:
+            grads = [p.grad.detach() for p in params]
+            dummy_overflow_buf = torch.cuda.IntTensor([0])
+            multi_tensor_applier(colossal_C.multi_tensor_scale,
+                                 dummy_overflow_buf,
+                                 [grads, grads],
+                                 clip_coeff)
+        else:
+            for p in params:
+                p.grad.detach().mul_(clip_coeff)
     return total_norm
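The new else branch gives clip_grad_norm_fp32 a pure-PyTorch fallback for the case where gradients live on the CPU and the fused colossal_C kernel cannot run. A simplified, self-contained sketch of that scaling path (single process, no parallel groups):

import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for p in params:
    p.grad = torch.randn_like(p)          # pretend backward() has run

max_norm = 1.0
total_norm = float(sum(p.grad.norm(2.0) ** 2 for p in params)) ** 0.5
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
    for p in params:
        p.grad.detach().mul_(clip_coeff)  # in-place scale, same as the fallback above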