Mirror of https://github.com/hpcaitech/ColossalAI.git
[moe] merge moe into main (#4978)

* update moe module
* support openmoe
colossalai/moe/_operation.py  (new file, 275 lines)
@@ -0,0 +1,275 @@
from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from colossalai.moe.manager import MOE_MANAGER

MOE_KERNEL = None


def load_moe():
    global MOE_KERNEL
    from colossalai.kernel.op_builder import MOEBuilder

    MOE_KERNEL = MOEBuilder().load()

class AllGather(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.unsqueeze(0), None

        buffer_shape = (comm_size,) + inputs.shape
        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
        if not overlap:
            dist.all_gather(buffer_list, inputs, group=group)
            return outputs, None
        else:
            handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        return (
            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )

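# Illustrative usage sketch (not part of this file): `ep_group` below is a
# hypothetical expert-parallel ProcessGroup. With overlap=True the all-gather is
# issued asynchronously, so the returned handle must be waited on before the
# gathered buffer is read:
#
#     out, handle = AllGather.apply(x, ep_group, True)   # out: (world_size, *x.shape)
#     ...                                                 # overlap other work here
#     handle.wait()                                       # out is now valid
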
class ReduceScatter(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.squeeze(0), None

        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        output_shape = inputs.shape[1:]
        outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
        if not overlap:
            dist.reduce_scatter(outputs, buffer_list, group=group)
            return outputs, None
        else:
            handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # TODO: support async backward
        return (
            AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )

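# Note (a sketch, assuming a group of size w): ReduceScatter is the
# communication inverse of AllGather above, so each serves as the other's
# backward pass. AllGather stacks the per-rank tensors into a leading dimension,
# (...,) -> (w, ...), while ReduceScatter removes it, (w, ...) -> (...,), with
# each rank receiving the sum over ranks of its own slice of that dimension.
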
class AllToAll(torch.autograd.Function):
    """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
    operation in torch.distributed.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        if ctx is not None:
            ctx.comm_grp = group
        if not inputs.is_contiguous():
            inputs = inputs.contiguous()
        if dist.get_world_size(group) == 1:
            return inputs, None
        output = torch.empty_like(inputs)
        if not overlap:
            dist.all_to_all_single(output, inputs, group=group)
            return output, None
        else:
            handle = dist.all_to_all_single(output, inputs, group=group, async_op=True)
            return output, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        return (
            AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0],
            None,
            None,
        )

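# Usage sketch (illustrative, assuming `ep_group` is a hypothetical
# expert-parallel ProcessGroup): for a dispatch buffer of shape [e, c, h]
# (e experts, capacity c, hidden size h), all_to_all_single exchanges equal
# slices of the leading dimension so that each rank ends up holding the tokens
# routed to its local experts; the tensor shape itself is unchanged.
#
#     dispatched, _ = AllToAll.apply(expert_inputs, ep_group, False)
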
class MoeDispatch(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, tokens, mask, dest_idx, ec):
        s = tokens.size(0)
        h = tokens.size(1)
        dtype = tokens.dtype

        if MOE_KERNEL is None:
            load_moe()
        if tokens.dtype != torch.float32:
            tokens = tokens.to(torch.float32)
        expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
        if expert_input.dtype != dtype:
            expert_input = expert_input.to(dtype)
        ctx.save_for_backward(mask, dest_idx)
        ctx.s = s
        ctx.h = h
        ctx.ec = ec
        ctx.dtype = dtype

        return expert_input

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad):
        mask, dest_idx = ctx.saved_tensors
        if output_grad.dtype != torch.float32:
            output_grad = output_grad.to(torch.float32)
        d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
        if d_tokens.dtype != ctx.dtype:
            d_tokens = d_tokens.to(ctx.dtype)
        return d_tokens, None, None, None

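# Note on MoeDispatch: the fused CUDA kernel runs in float32, so half-precision
# tokens are upcast before dispatch_forward / dispatch_backward and the result
# is cast back to the original dtype afterwards. `mask` and `dest_idx` encode
# the (non-differentiable) routing decision, so their gradient slots are None.
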
class MoeCombine(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
        assert logits.dtype == torch.float32

        s = logits.size(0)
        e = logits.size(1)
        c = ec // e
        h = expert_tokens.size(-1)
        dtype = expert_tokens.dtype

        if expert_tokens.dtype != torch.float32:
            expert_tokens = expert_tokens.to(torch.float32)
        if MOE_KERNEL is None:
            load_moe()
        output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
        if output.dtype != dtype:
            output = output.to(dtype)

        ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
        ctx.s = s
        ctx.e = e
        ctx.c = c
        ctx.h = h
        ctx.dtype = dtype

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, tokens_grad):
        expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
        if tokens_grad.dtype != torch.float32:
            tokens_grad = tokens_grad.to(torch.float32)

        d_expert, d_logits = MOE_KERNEL.combine_backward(
            ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx
        )
        if d_expert.dtype != ctx.dtype:
            d_expert = d_expert.to(ctx.dtype)

        return d_expert, d_logits, None, None, None

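# Shape bookkeeping for MoeCombine (a sketch of the assumed layout): `logits`
# is (s, e) for s tokens and e experts, `ec` is taken to be the total number of
# capacity slots, so the per-expert capacity is c = ec // e, and h is the
# hidden size of `expert_tokens`. For example, with e = 4 experts and ec = 32
# slots, each expert holds up to c = 8 tokens.
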
def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
    dim0 = inputs.size(0)
    flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
    if flag and use_kernel:
        if MOE_KERNEL is None:
            load_moe()
        return MOE_KERNEL.cumsum_sub_one(inputs)
    else:
        return torch.cumsum(inputs, dim=0) - 1

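# moe_cumsum computes an inclusive cumulative sum minus one, typically used to
# turn a 0/1 routing mask into per-expert slot positions. A quick worked
# example:
#     inputs = [1, 0, 1, 1]  ->  cumsum = [1, 1, 2, 3]  ->  result = [0, 0, 1, 2]
# The fused kernel path is only taken when dim0 satisfies the size/alignment
# flag above and use_kernel is set; otherwise the plain torch.cumsum fallback
# is used.
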
class MoeInGradScaler(torch.autograd.Function):
    """
    Scale the gradient back by the number of experts
    because the batch size increases in the moe stage
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        if ctx is not None:
            ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        if ctx.ep_size != 1:
            grad = grad * ctx.ep_size
        return grad, None

class MoeOutGradScaler(torch.autograd.Function):
    """
    Scale the gradient by the number of experts
    because the batch size increases in the moe stage
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        if ctx.ep_size != 1:
            grad = grad / ctx.ep_size
        return grad, None

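# Usage sketch for the two scalers (illustrative, with a hypothetical ep_size):
# both are identity functions in forward; in backward MoeInGradScaler multiplies
# the gradient by ep_size and MoeOutGradScaler divides it by ep_size. Wrapping
# the expert computation as
#
#     x = MoeInGradScaler.apply(x, ep_size)    # entering the MoE stage
#     x = expert(x)
#     x = MoeOutGradScaler.apply(x, ep_size)   # leaving the MoE stage
#
# means the gradient inside the stage is divided by ep_size (compensating for
# the ep_size-fold larger effective batch there, per the class docstrings) and
# is scaled back up on the way out, so layers outside the stage are unaffected.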