[MOE] add unitest for MOE experts layout, gradient handler and kernel (#469)

HELSON
2022-03-21 13:35:04 +08:00
committed by GitHub
parent 1559c0df41
commit 7544347145
13 changed files with 263 additions and 499 deletions


@@ -8,7 +8,6 @@ import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter
-
 try:
     import colossal_C
 except:
@@ -17,11 +16,9 @@ except:
 from contextlib import contextmanager
 import torch.distributed as dist
-from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS,
-                                  TENSOR_PARALLEL_ATTRIBUTES)
+from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES)
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.global_variables import moe_env
 from colossalai.global_variables import tensor_parallel_env as env
 from .multi_tensor_apply import multi_tensor_applier
@@ -116,7 +113,10 @@ def is_model_parallel_parameter(p):
 def is_moe_parallel_parameter(p):
-    return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1
+    # FIXME(HHC): clip_grad need to changed to adapted for MoE
+    # This return value must set to False, otherwise it will raise
+    # an error in training
+    return False
 
 
 def _calc_l2_norm(grads):
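
The hunk above short-circuits `is_moe_parallel_parameter` to `False`, so MoE gradients temporarily fall through to the ordinary buckets instead of `moe_parallel_grads` until clipping is adapted for MoE. For reference, a minimal sketch of what the removed predicate checked (the function name and the explicit `moe_data_parallel_size` argument are ours; the real code read this from the global `moe_env`):

```python
import torch
from torch.nn.parameter import Parameter

def is_moe_parallel_parameter_sketch(p: Parameter, moe_data_parallel_size: int) -> bool:
    # Mirrors the removed line: a parameter counts as MoE-parallel when it
    # carries the `moe_param` tag and the MoE data-parallel world is sharded.
    return hasattr(p, 'moe_param') and moe_data_parallel_size > 1

p = Parameter(torch.zeros(4))
p.moe_param = True    # hypothetical tag; in the library it is set by the MoE layers
assert is_moe_parallel_parameter_sketch(p, moe_data_parallel_size=2)
```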
@@ -127,7 +127,7 @@ def _calc_l2_norm(grads):
             colossal_C.multi_tensor_l2norm,
             dummy_overflow_buf,
             [grads],
-            False # no per-parameter norm
+            False  # no per-parameter norm
         )
     return norm
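
`_calc_l2_norm` dispatches to the fused `colossal_C.multi_tensor_l2norm` kernel with per-parameter norms disabled. A pure-PyTorch sketch of the same reduction, useful when the CUDA extension is unavailable (the function name is ours, not the library's):

```python
import torch

def calc_l2_norm_fallback(grads):
    # Same reduction as the fused kernel: one global L2 norm over all
    # gradient elements, computed as the norm of the per-tensor norms.
    if len(grads) == 0:
        return 0.0
    return torch.norm(torch.stack([torch.norm(g.detach(), 2) for g in grads]), 2)

grads = [torch.randn(3), torch.randn(2, 2)]
flat = torch.cat([g.view(-1) for g in grads])
assert torch.allclose(calc_l2_norm_fallback(grads), torch.norm(flat, 2))
```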
@@ -139,11 +139,13 @@ def _calc_lp(grads, norm_type):
         norm += grad_norm**norm_type
     return norm
 
+def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+    if torch.is_tensor(norm) and norm.device.type != 'cuda':
+        norm = norm.to(torch.cuda.current_device())
+    return norm
+
 # ======== Gradient Clipping =========
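
`_move_norm_to_cuda` exists because the non-fused `_calc_lp` path can produce values living on the CPU, while the later `dist.all_reduce` needs the norm on the current CUDA device. A small usage sketch against the function shown above (assumes a CUDA device is available):

```python
import torch

partial = torch.tensor(3.0)               # a partial norm living on the CPU
partial = _move_norm_to_cuda(partial)     # moved to torch.cuda.current_device()
assert partial.device.type == 'cuda'
assert _move_norm_to_cuda(4.0) == 4.0     # plain floats pass through unchanged
```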
@@ -212,7 +214,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     else:
         tensor_parallel_grads = []
         no_tensor_parallel_grads = []
-        moe_parallel_grads = [] # used to collect moe tensor parallel gradients
+        moe_parallel_grads = []  # used to collect moe tensor parallel gradients
         zero_sharded_grads = []
         for p in params:
             if is_model_parallel_parameter(p):
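
The hunk above shows only the first branch of the bucketing loop. A hedged reconstruction of the full dispatch, inferred from the list names (the `elif`/`else` arms are assumptions, not visible in this diff):

```python
# Inferred shape of the bucketing loop; the zero_sharded_grads branch is
# omitted because its predicate is not visible in this hunk.
for p in params:
    if is_model_parallel_parameter(p):
        tensor_parallel_grads.append(p.grad.data)
    elif is_moe_parallel_parameter(p):
        moe_parallel_grads.append(p.grad.data)
    else:
        no_tensor_parallel_grads.append(p.grad.data)
```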
@@ -226,13 +228,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
                 no_tensor_parallel_grads.append(p.grad.data)
 
         if norm_type == 2.0 and enable_cuda_kernels:
-            tensor_parallel_norm = _calc_l2_norm(
-                tensor_parallel_grads) ** norm_type
-            no_tensor_parallel_norm = _calc_l2_norm(
-                no_tensor_parallel_grads) ** norm_type
-            moe_parallel_norm = _calc_l2_norm(
-                moe_parallel_grads) ** norm_type
-            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type
+            tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
+            no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
+            moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
+            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
         else:
             tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
             no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
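
Both branches accumulate `norm**norm_type` rather than the norm itself: p-th powers are additive across disjoint buckets, so the p-th root can be taken once at the end. A quick self-contained check for `norm_type == 2.0`:

```python
import torch

# Two disjoint gradient buckets: sums of squared L2 norms combine exactly,
# so the square root can be deferred until the very end.
bucket_a, bucket_b = torch.randn(5), torch.randn(7)
partial = torch.norm(bucket_a, 2)**2 + torch.norm(bucket_b, 2)**2
full = torch.norm(torch.cat([bucket_a, bucket_b]), 2)
assert torch.allclose(partial**0.5, full)
```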
@@ -259,10 +258,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         no_tensor_parallel_norm += zero_sharded_norm
         total_norm = tensor_parallel_norm + no_tensor_parallel_norm
         if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-            dist.all_reduce(total_norm,
-                            op=dist.ReduceOp.SUM,
-                            group=gpc.get_group(ParallelMode.PIPELINE))
-        total_norm = total_norm ** (1.0 / norm_type)
+            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
+        total_norm = total_norm**(1.0 / norm_type)
     if torch.is_tensor(total_norm):
         total_norm = total_norm.item()
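
Before the final root, the summed p-th powers are all-reduced across the pipeline group, since each pipeline stage owns a disjoint slice of the model's parameters. A minimal sketch of that step (assumes an initialized process group; `pipeline_group` stands in for `gpc.get_group(ParallelMode.PIPELINE)`):

```python
import torch
import torch.distributed as dist

def reduce_norm_across_stages(total_norm: torch.Tensor, pipeline_group, norm_type: float = 2.0) -> float:
    # Each stage contributes the sum of its own grads' p-th powers; after the
    # SUM all-reduce, a single p-th root yields the global gradient norm.
    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pipeline_group)
    return total_norm.item()**(1.0 / norm_type)
```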
@@ -272,10 +269,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     if enable_cuda_kernels:
         grads = [p.grad.detach() for p in params]
         dummy_overflow_buf = torch.cuda.IntTensor([0])
-        multi_tensor_applier(colossal_C.multi_tensor_scale,
-                             dummy_overflow_buf,
-                             [grads, grads],
-                             clip_coeff)
+        multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
     else:
         for p in params:
             p.grad.detach().mul_(clip_coeff)
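
For context, this in-place scaling only fires when the clip coefficient is below one. A sketch of the surrounding control flow using the names from this diff (the `1.0e-6` epsilon is an assumption, not read from this hunk):

```python
# clip_coeff >= 1.0 would scale gradients *up*, so that case is skipped.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
    for p in params:
        p.grad.detach().mul_(clip_coeff)
```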