Added MoE parallel (#127)

This commit is contained in:
HELSON
2022-01-07 15:08:36 +08:00
committed by GitHub
parent 42741dd4a3
commit dceae85195
26 changed files with 858 additions and 18 deletions

View File

@@ -17,6 +17,7 @@ import torch.distributed as dist
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env
from .multi_tensor_apply import multi_tensor_applier
@@ -91,6 +92,10 @@ def is_model_parallel_parameter(p):
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
def is_moe_parallel_parameter(p):
return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1
def _calc_l2_norm(grads):
norm = 0.0
if len(grads) > 0:
@@ -165,26 +170,37 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
else:
tensor_parallel_grads = []
no_tensor_parallel_grads = []
moe_parallel_grads = [] # used to collect moe tensor parallel gradients
for p in params:
if is_model_parallel_parameter(p):
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type)
tensor_parallel_grads.append(p.grad.data / reductor)
elif is_moe_parallel_parameter(p):
moe_parallel_grads.append(p.grad.data)
else:
no_tensor_parallel_grads.append(p.grad.data)
if norm_type == 2.0:
tensor_parallel_norm = _calc_l2_norm(
tensor_parallel_grads) ** norm_type
no_tensor_parallel_norm = _calc_l2_norm(
no_tensor_parallel_grads) ** norm_type
moe_parallel_norm = _calc_l2_norm(
moe_parallel_grads) ** norm_type
else:
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_grads = _calc_lp(
no_tensor_parallel_norm = _calc_lp(
no_tensor_parallel_grads, norm_type)
moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
dist.all_reduce(tensor_parallel_norm,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR))
# Sum across all moe-tensor-parallel GPUs
if len(moe_parallel_grads) > 0:
dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
no_tensor_parallel_norm += moe_parallel_norm
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
dist.all_reduce(total_norm,