[moe] implement tp

botbw
2024-07-16 06:03:57 +00:00
committed by Hongxin Liu
parent 0b5bbe9ce4
commit dc583aa576
8 changed files with 79 additions and 40 deletions


@@ -443,7 +443,7 @@ def all_to_all_uneven(
 # ===========================================================
 # This code section was modified from
 # https://github.com/microsoft/DeepSpeed/blob/3d347276ce80e1a29e777c839d1d7fabe8e5f034/deepspeed/moe/mappings.py
 # Copyright (c) Microsoft Corporation.
@@ -492,8 +492,9 @@ def _drop_tokens(input_, dim: int, tp_group: ProcessGroup):
     total_chunks = tp_group.size()
     this_chunk = tp_group.rank()
-    assert input_.shape[
-        dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
+    assert (
+        input_.shape[dim] % total_chunks == 0
+    ), f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
     chunk_size = input_.shape[dim] // total_chunks
     return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
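
The arithmetic in this hunk reduces to cutting one rank's slice out of the token dimension with torch.narrow. A minimal single-process sketch of the chunking, with plain hypothetical integers standing in for tp_group.size() and tp_group.rank() (no process group required):

import torch

# Hypothetical stand-ins for tp_group.size() and tp_group.rank().
total_chunks, this_chunk = 4, 1

input_ = torch.arange(8).reshape(8, 1)  # 8 tokens along dim 0, divisible by 4
chunk_size = input_.shape[0] // total_chunks  # each rank keeps 2 tokens
kept = torch.narrow(input_, 0, this_chunk * chunk_size, chunk_size)
print(kept.squeeze().tolist())  # rank 1 keeps tokens [2, 3]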
@@ -531,15 +532,20 @@ def gather_tokens(input_, dim: int, tp_group: ProcessGroup):
     if tp_group.size() == 1:
         # no tensor parallelism for non-experts
         return input_
-    assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
-    return _GatherTokens.apply(input_, dim)
+    assert (
+        input_.requires_grad
+    ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+    return _GatherTokens.apply(input_, dim, tp_group)


 def drop_tokens(input_, dim: int, tp_group: ProcessGroup):
     if tp_group.size() == 1:
         # no tensor parallelism for non-experts
         return input_
-    assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+    assert (
+        input_.requires_grad
+    ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
     return _DropTokens.apply(input_, dim, tp_group)


 # ===========================================================
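
The _GatherTokens and _DropTokens autograd classes invoked above are not part of these hunks. The sketch below is a plausible reconstruction of how such a pair is usually structured (after the DeepSpeed mappings module credited in the first hunk), assuming an initialized torch.distributed process group; it is illustrative, not the actual ColossalAI implementation:

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def _gather(input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
    # All-gather every rank's chunk and re-concatenate them along `dim`.
    chunks = [torch.empty_like(input_) for _ in range(tp_group.size())]
    dist.all_gather(chunks, input_.contiguous(), group=tp_group)
    return torch.cat(chunks, dim=dim)


def _drop(input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
    # Keep only this rank's chunk along `dim`, mirroring _drop_tokens above.
    chunk_size = input_.shape[dim] // tp_group.size()
    return torch.narrow(input_, dim, tp_group.rank() * chunk_size, chunk_size)


class _GatherTokens(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, tp_group):
        ctx.dim, ctx.tp_group = dim, tp_group
        return _gather(input_, dim, tp_group)

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of a gather is a drop: each rank keeps its own slice.
        return _drop(grad_output, ctx.dim, ctx.tp_group), None, None


class _DropTokens(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, tp_group):
        ctx.dim, ctx.tp_group = dim, tp_group
        return _drop(input_, dim, tp_group)

    @staticmethod
    def backward(ctx, grad_output):
        # Conversely, the gradient of a drop is a gather.
        return _gather(grad_output, ctx.dim, ctx.tp_group), None, None

This pairing is also why both wrappers assert input_.requires_grad: gather's backward drops and drop's backward gathers, so if autograd never runs backward on one rank, its peers block indefinitely inside the collective.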