[moe] implement tp
@@ -443,7 +443,7 @@ def all_to_all_uneven(
 
 # ===========================================================
-# This code section was modified from
+# This code section was modified from
 # https://github.com/microsoft/DeepSpeed/blob/3d347276ce80e1a29e777c839d1d7fabe8e5f034/deepspeed/moe/mappings.py
 
 # Copyright (c) Microsoft Corporation.
@@ -492,8 +492,9 @@ def _drop_tokens(input_, dim: int, tp_group: ProcessGroup):
 
     total_chunks = tp_group.size()
     this_chunk = tp_group.rank()
-    assert input_.shape[
-        dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
+    assert (
+        input_.shape[dim] % total_chunks == 0
+    ), f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
     chunk_size = input_.shape[dim] // total_chunks
 
     return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
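For context, `_drop_tokens` keeps only the current rank's slice of the input along `dim`: the tensor is split into `tp_group.size()` even chunks and rank i keeps chunk i, which is exactly what the `torch.narrow` call computes. A minimal single-process sketch of that arithmetic, with plain integers standing in for the `ProcessGroup` calls (the values here are hypothetical):

import torch

# Stand-ins for tp_group.size() and tp_group.rank() (hypothetical values).
total_chunks = 4  # tensor parallel world size
this_chunk = 1    # this rank's index

input_ = torch.arange(16).reshape(8, 2)  # 8 tokens along dim 0
dim = 0

assert input_.shape[dim] % total_chunks == 0
chunk_size = input_.shape[dim] // total_chunks  # 8 // 4 == 2

# Rank 1 keeps a contiguous view of rows 2..3.
kept = torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
print(kept.shape)  # torch.Size([2, 2])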
@@ -531,15 +532,20 @@ def gather_tokens(input_, dim: int, tp_group: ProcessGroup):
 
     if tp_group.size() == 1:
         # no tensor parallelism for non-experts
         return input_
-    assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
-    return _GatherTokens.apply(input_, dim)
+    assert (
+        input_.requires_grad
+    ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+    return _GatherTokens.apply(input_, dim, tp_group)
 
 
 def drop_tokens(input_, dim: int, tp_group: ProcessGroup):
     if tp_group.size() == 1:
         # no tensor parallelism for non-experts
         return input_
-    assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+    assert (
+        input_.requires_grad
+    ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
     return _DropTokens.apply(input_, dim, tp_group)
 
 
 # ===========================================================
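The `_GatherTokens` and `_DropTokens` autograd functions themselves are outside this hunk, but the asserts hint at their contract: their backward passes issue collectives, so if autograd never runs on one rank, the remaining ranks block in the collective forever. In the DeepSpeed mappings module this file credits, gather and drop are mirror images: the backward of a gather is a drop and vice versa. A rough sketch of that pairing (simplified signatures and helpers, not the actual ColossalAI classes):

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def _split(input_, dim: int, group: ProcessGroup):
    # Keep only this rank's even chunk along `dim`.
    chunk = input_.shape[dim] // group.size()
    return torch.narrow(input_, dim, group.rank() * chunk, chunk)


def _gather(input_, dim: int, group: ProcessGroup):
    # Collect every rank's chunk and concatenate along `dim`.
    chunks = [torch.empty_like(input_) for _ in range(group.size())]
    dist.all_gather(chunks, input_.contiguous(), group=group)
    return torch.cat(chunks, dim=dim)


class _DropTokens(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, group):
        ctx.dim, ctx.group = dim, group
        return _split(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        # Undo the drop: every rank reassembles the full gradient.
        return _gather(grad_output, ctx.dim, ctx.group), None, None


class _GatherTokens(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, group):
        ctx.dim, ctx.group = dim, group
        return _gather(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        # Undo the gather: each rank keeps only its own gradient slice.
        return _split(grad_output, ctx.dim, ctx.group), None, None

This pairing also suggests why `gather_tokens` and `drop_tokens` now forward `tp_group` into `apply(...)`: the backward collective must run on the same process group as the forward one.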