Added TPExpert for special situations

1SAA
2022-02-27 22:28:39 +08:00
committed by Frank Lee
parent 36b8477228
commit 82023779bb
7 changed files with 192 additions and 41 deletions


@@ -1,6 +1,4 @@
-import os
 from functools import partial
-from pathlib import Path
 import pytest
 import torch
 import torch.nn as nn
@@ -9,10 +7,10 @@ import colossalai
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
-from colossalai.nn.layer.moe import Top2Router, MoeLayer
+from colossalai.nn.layer.moe import Top2Router, MoeLayer, Experts
 from colossalai.context.random import moe_set_seed
 from colossalai.global_variables import moe_env
 
 BATCH_SIZE = 32
 NUM_EXPERTS = 4
 CONFIG = dict(parallel=dict(moe=dict(size=4)))
@@ -24,17 +22,17 @@ def check_equal(A, B, atol=1e-06):
 def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     moe_set_seed(42)
     # torch.set_printoptions(precision=30)
     torch.backends.cuda.matmul.allow_tf32 = False
     local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
     torch.manual_seed(rs + local_rank)
     moe_env.reset_loss()
-    tokens = torch.randn(BATCH_SIZE, hidden_size,
-                         dtype=data_type, device=get_current_device(), requires_grad=True)
+    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
     # print(f"tokens:\n{tokens}")
     router = Top2Router(1)
-    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
+    expert = Experts(nn.Identity, 4)
+    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
     if data_type == torch.float16:
         layer = layer.half()
     layer.cuda_mode = False
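
The substantive change in this hunk is the fourth argument of MoeLayer: instead of a bare module instance (nn.Identity()), the test now builds an Experts container from an expert class and an expert count, and passes that container in. A minimal sketch of the new construction path, using only the calls visible in this diff (it still presumes a process group initialized via colossalai.launch, as in the test itself):

import torch.nn as nn
from colossalai.nn.layer.moe import Top2Router, MoeLayer, Experts

HIDDEN_SIZE = 128    # the test's default hidden_size
NUM_EXPERTS = 4      # matches NUM_EXPERTS above

router = Top2Router(1)                       # same router call as the test
experts = Experts(nn.Identity, NUM_EXPERTS)  # new: expert class + count, not an instance
layer = MoeLayer(HIDDEN_SIZE, NUM_EXPERTS, router, experts)

Passing the class rather than an instance lets the container decide how many copies to build and where to place them, which is presumably what the TPExpert work in the other changed files relies on.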
@@ -88,8 +86,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
 def test_moe_top2(rs, hidden_size, data_type):
     world_size = 4
-    run_func = partial(run_routing, world_size=world_size, port=free_port(),
-                       rs=rs, hidden_size=hidden_size, data_type=data_type)
+    run_func = partial(run_routing,
+                       world_size=world_size,
+                       port=free_port(),
+                       rs=rs,
+                       hidden_size=hidden_size,
+                       data_type=data_type)
     mp.spawn(run_func, nprocs=world_size)
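
The launcher rewrite above is purely cosmetic (one keyword argument per line), but the pattern it reformats is worth spelling out: torch.multiprocessing.spawn invokes the target once per process and passes each worker its rank as the first positional argument, so the test binds every other parameter with functools.partial and leaves exactly that slot open. A self-contained sketch of the same pattern, with a stand-in worker body and an illustrative port rather than the test's:

from functools import partial

import torch.multiprocessing as mp


def worker(rank, world_size, port):
    # mp.spawn supplies `rank`; world_size and port arrive
    # pre-bound through functools.partial.
    print(f"rank {rank}/{world_size} would initialize on port {port}")


if __name__ == '__main__':
    world_size = 4
    run_func = partial(worker, world_size=world_size, port=29500)  # port is illustrative
    mp.spawn(run_func, nprocs=world_size)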