Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-09 21:09:18 +00:00
[MOE] add unittest for MOE experts layout, gradient handler and kernel (#469)
@@ -8,7 +8,6 @@ import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter
 
-
 try:
     import colossal_C
 except:
@@ -17,11 +16,9 @@ except:
 from contextlib import contextmanager
 
 import torch.distributed as dist
-from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS,
-                                  TENSOR_PARALLEL_ATTRIBUTES)
+from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES)
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import moe_env
 from colossalai.global_variables import tensor_parallel_env as env
-
 from .multi_tensor_apply import multi_tensor_applier
@@ -116,7 +113,10 @@ def is_model_parallel_parameter(p):
 
 
 def is_moe_parallel_parameter(p):
-    return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1
+    # FIXME(HHC): clip_grad needs to be changed to adapt to MoE
+    # This return value must be set to False, otherwise it will raise
+    # an error in training
+    return False
 
 
 def _calc_l2_norm(grads):
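The predicate removed in this hunk is the one the FIXME points at: until clip_grad_norm_fp32 can clip MoE gradients separately, the function simply reports False. Once that support lands, the check would presumably be restored along the lines removed above. A minimal sketch of that restored form, reusing only names that already appear in this file (treat it as an illustration, not the project's final API):

    from colossalai.global_variables import moe_env

    def is_moe_parallel_parameter(p):
        # A parameter counts as MoE-parallel when it is tagged with the
        # 'moe_param' attribute and experts are spread over more than one
        # data-parallel rank, so its gradient needs separate norm handling.
        return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1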
@@ -127,7 +127,7 @@ def _calc_l2_norm(grads):
             colossal_C.multi_tensor_l2norm,
             dummy_overflow_buf,
             [grads],
-            False # no per-parameter norm
+            False  # no per-parameter norm
         )
     return norm
 
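For reference, the fused colossal_C.multi_tensor_l2norm call above computes a single L2 norm over the whole gradient list (the False flag disables per-parameter norms). When the extension is not built, the same value can be obtained with plain PyTorch; the helper name below is ours, added only for illustration:

    import torch

    def _calc_l2_norm_fallback(grads):
        # Global L2 norm over a list of gradient tensors:
        # sqrt(sum_i ||g_i||_2^2), returned as a scalar tensor.
        if len(grads) == 0:
            return torch.tensor(0.0)
        return torch.norm(torch.stack([torch.norm(g.detach(), 2.0) for g in grads]), 2.0)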
@@ -139,11 +139,13 @@ def _calc_lp(grads, norm_type):
         norm += grad_norm**norm_type
     return norm
 
 
 def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
     if torch.is_tensor(norm) and norm.device.type != 'cuda':
         norm = norm.to(torch.cuda.current_device())
     return norm
 
+
 # ======== Gradient Clipping =========
+
 
@@ -212,7 +214,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     else:
         tensor_parallel_grads = []
         no_tensor_parallel_grads = []
-        moe_parallel_grads = [] # used to collect moe tensor parallel gradients
+        moe_parallel_grads = []  # used to collect moe tensor parallel gradients
         zero_sharded_grads = []
         for p in params:
             if is_model_parallel_parameter(p):
@@ -226,13 +228,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
                 no_tensor_parallel_grads.append(p.grad.data)
 
         if norm_type == 2.0 and enable_cuda_kernels:
-            tensor_parallel_norm = _calc_l2_norm(
-                tensor_parallel_grads) ** norm_type
-            no_tensor_parallel_norm = _calc_l2_norm(
-                no_tensor_parallel_grads) ** norm_type
-            moe_parallel_norm = _calc_l2_norm(
-                moe_parallel_grads) ** norm_type
-            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type
+            tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
+            no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
+            moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
+            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
         else:
             tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
            no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
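Both branches accumulate the same per-group quantity, namely the sum of each gradient's norm raised to the norm_type power; the 2-norm branch just lets the fused kernel do the reduction and then squares the result. A small self-contained check of that identity (the tensor values are made up for the example):

    import torch

    grads = [torch.randn(3, 4), torch.randn(5)]
    norm_type = 2.0
    # Fused-kernel style: global L2 norm of the group, raised to norm_type.
    fused_style = torch.norm(torch.stack([torch.norm(g, 2.0) for g in grads]), 2.0) ** norm_type
    # _calc_lp style: accumulate each gradient's norm raised to norm_type.
    lp_style = sum(torch.norm(g, norm_type) ** norm_type for g in grads)
    assert torch.allclose(fused_style, lp_style)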
@@ -259,10 +258,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         no_tensor_parallel_norm += zero_sharded_norm
         total_norm = tensor_parallel_norm + no_tensor_parallel_norm
         if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-            dist.all_reduce(total_norm,
-                            op=dist.ReduceOp.SUM,
-                            group=gpc.get_group(ParallelMode.PIPELINE))
-        total_norm = total_norm ** (1.0 / norm_type)
+            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
+        total_norm = total_norm**(1.0 / norm_type)
         if torch.is_tensor(total_norm):
             total_norm = total_norm.item()
 
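A plain SUM all-reduce is the right collective here because each pipeline stage owns a disjoint set of parameters: the stages' partial sums of p-th powers add up to the global sum, and only then is the 1/p root taken. A toy single-process illustration of that identity (no distributed setup; the per-stage gradient lists are invented for the example):

    import torch

    norm_type = 2.0
    stage_grads = [torch.randn(4), torch.randn(6)]  # pretend each entry lives on a different pipeline stage
    partial_sums = [torch.norm(g, norm_type) ** norm_type for g in stage_grads]
    total_norm = sum(partial_sums) ** (1.0 / norm_type)  # what all_reduce(SUM) plus the final root produce
    reference = torch.norm(torch.cat(stage_grads), norm_type)
    assert torch.allclose(total_norm, reference)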
@@ -272,10 +269,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         if enable_cuda_kernels:
             grads = [p.grad.detach() for p in params]
             dummy_overflow_buf = torch.cuda.IntTensor([0])
-            multi_tensor_applier(colossal_C.multi_tensor_scale,
-                                 dummy_overflow_buf,
-                                 [grads, grads],
-                                 clip_coeff)
+            multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
         else:
             for p in params:
                 p.grad.detach().mul_(clip_coeff)
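The fused multi_tensor_scale call and the else branch do the same thing: scale every gradient in place by clip_coeff. By the usual clip_grad_norm convention that coefficient is max_norm / (total_norm + 1e-6) and is applied only when it is below 1, although the exact expression used by this file is outside the hunks shown. A minimal pure-PyTorch sketch of the scaling step, with a function name of our own choosing:

    def scale_grads_(params, clip_coeff):
        # The non-fused equivalent of the multi_tensor_scale call above:
        # multiply every gradient in place by the clip coefficient.
        for p in params:
            if p.grad is not None:
                p.grad.detach().mul_(clip_coeff)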