[moe] fix moe bugs (#1633)

Authored by HELSON on 2022-09-23 15:33:57 +08:00; committed by GitHub
parent 702dbc5288
commit a088022efc
8 changed files with 287 additions and 249 deletions

View File

@@ -32,7 +32,7 @@ def run_test(rank, world_size, port):
         moe_layer = MoeLayer(DIM, num_experts, router, exp)
         layer_list.append(moe_layer)
-    model = nn.Sequential(*layer_list)
+    model = nn.ModuleList(layer_list)
     model = model.to(get_current_device())
     sync_moe_model_param(model)
@@ -49,8 +49,9 @@ def run_test(rank, world_size, port):
     grad = torch.randn_like(data)
     MOE_CONTEXT.reset_loss()
-    outputs = model(data)
-    outputs.backward(grad)
+    for layer in layer_list:
+        data, _ = layer(data)
+    data.backward(grad)
     grad_handler.handle_gradient()
     assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
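The MoE layer's forward now returns a pair (output, auxiliary value), so nn.Sequential, which passes a single tensor between children, can no longer chain the layers; the test switches to nn.ModuleList and loops over the layers explicitly. A minimal sketch of that pattern, with a hypothetical TupleLayer standing in for MoeLayer:

```python
import torch
import torch.nn as nn

# Minimal sketch, assuming a tuple-returning layer: nn.Sequential forwards a single
# tensor between children, so a layer whose forward returns (output, aux) cannot be
# chained by it. The layers are kept in an nn.ModuleList instead and the activation
# is threaded through them by hand. TupleLayer is a hypothetical stand-in for MoeLayer.
class TupleLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        out = self.proj(x)
        aux = out.detach().abs().mean()   # placeholder for the MoE auxiliary value
        return out, aux

layer_list = [TupleLayer(16) for _ in range(2)]
model = nn.ModuleList(layer_list)         # nn.Sequential(*layer_list) would break at the second layer

data = torch.randn(4, 16)
grad = torch.randn(4, 16)
for layer in model:
    data, _ = layer(data)                 # unpack (output, aux) at every layer
data.backward(grad)
```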

View File

@@ -44,7 +44,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine
     layer.use_kernel = False
-    old_out = layer(tokens)
+    old_out, _ = layer(tokens)
     ech = old_out.shape
     grad = torch.randn(ech, device=get_current_device())
     old_out.backward(grad)    # get gradient
@@ -58,7 +58,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.gate_weight.grad.zero_()
     layer.use_kernel = True
-    new_out = layer(tokens)    # get ouputs through colossal kernel
+    new_out, _ = layer(tokens)    # get ouputs through colossal kernel
     if data_type == torch.float32:
         check_equal(old_out, new_out)
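Both the matrix-multiplication fallback and the fused-kernel path now return the same pair, so the test unpacks the first element before comparing outputs and gradients between the two paths. A self-contained sketch of that comparison pattern, with a tuple-returning stand-in layer and torch.allclose in place of the test's check_equal helper:

```python
import torch
import torch.nn as nn

# Sketch of the comparison pattern: run the same layer twice (reference path vs.
# optimized path), unpack the (output, aux) pair each time, and check that the
# outputs and the gradient accumulated on a parameter agree. PairLayer is a
# stand-in for the MoE layer under test.
class PairLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gate_weight = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        out = x @ self.gate_weight
        return out, out.detach().mean()     # second element mimics the aux value

layer = PairLayer(8)
tokens = torch.randn(4, 8, requires_grad=True)
grad = torch.randn(4, 8)

old_out, _ = layer(tokens)                  # "use_kernel = False" path in the test
old_out.backward(grad)
old_grad = layer.gate_weight.grad.clone()

tokens.grad.zero_()
layer.gate_weight.grad.zero_()

new_out, _ = layer(tokens)                  # "use_kernel = True" path in the test
new_out.backward(grad)

assert torch.allclose(old_out, new_out)
assert torch.allclose(old_grad, layer.gate_weight.grad)
```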

View File

@@ -19,20 +19,39 @@ from colossalai.utils import get_current_device
 from tests.test_zero.common import CONFIG
 
-class MoeModel(CheckpointModule):
+class MoeModel(nn.Module):
 
     def __init__(self, checkpoint: bool = False):
-        super().__init__(checkpoint)
-        self.proj1 = nn.Linear(4, 16)
-        expert_cls = nn.Linear
-        expert_args_dict = dict(in_features=16, out_features=16)
-        self.moe = MoeModule(dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict)
-        self.proj2 = nn.Linear(16, 4)
+
+        class TestSubModule(CheckpointModule):
+
+            def __init__(self):
+                super().__init__(checkpoint)
+                expert_cls = nn.Linear
+                expert_args_dict = dict(in_features=16, out_features=16)
+                self.moe = MoeModule(dim_model=16,
+                                     num_experts=8,
+                                     use_residual=True,
+                                     expert_cls=expert_cls,
+                                     **expert_args_dict)
+                self.proj = nn.Linear(16, 4)
+
+            def _forward(self, x):
+                x, y = self.moe(x)
+                x = self.proj(x)
+                return x, y
+
+        super().__init__()
+        self.test_embed = nn.Linear(4, 16)
+        self.test_transform = TestSubModule()
 
     def forward(self, x):
-        x = self.proj1(x)
-        x = self.moe(x)
-        x = self.proj2(x)
+        MOE_CONTEXT.reset_loss()
+        x = self.test_embed(x)
+        x, y = self.test_transform(x)
+        MOE_CONTEXT.add_loss(y)
 
         return x
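The rewritten MoeModel nests the checkpointed part in a TestSubModule and makes the top-level forward manage the auxiliary loss itself: it resets the accumulated loss, runs the MoE block, and registers the block's auxiliary output with MOE_CONTEXT. A simplified sketch of that accumulate-in-context pattern, with a stand-in accumulator instead of colossalai's MOE_CONTEXT:

```python
import torch
import torch.nn as nn

# Simplified sketch: the model resets the accumulated auxiliary loss at the start
# of forward and registers the MoE block's auxiliary output before returning.
# _AuxLossContext is a stand-in for colossalai's MOE_CONTEXT.
class _AuxLossContext:
    def __init__(self):
        self.loss = 0.0

    def reset_loss(self):
        self.loss = 0.0

    def add_loss(self, value):
        self.loss = self.loss + value

CTX = _AuxLossContext()

class TinyMoeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.test_embed = nn.Linear(4, 16)
        self.moe_block = nn.Linear(16, 16)    # stand-in for the checkpointed MoE submodule
        self.proj = nn.Linear(16, 4)

    def forward(self, x):
        CTX.reset_loss()                      # mirrors MOE_CONTEXT.reset_loss()
        x = self.test_embed(x)
        x = self.moe_block(x)
        aux = x.pow(2).mean()                 # placeholder auxiliary value
        CTX.add_loss(aux)                     # mirrors MOE_CONTEXT.add_loss(y)
        return self.proj(x)

model = TinyMoeModel()
out = model(torch.randn(2, 4))
loss = out.sum() + 0.01 * CTX.loss            # task loss plus weighted auxiliary loss
loss.backward()
```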

View File

@@ -4,6 +4,8 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+from colossalai.nn import MoeLoss
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
@@ -26,7 +28,8 @@ def run_model_test(enable_autocast, shard_strategy_class):
     shard_strategy = shard_strategy_class()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, _, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
     with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                          shard_strategy=shard_strategy,
@@ -59,7 +62,6 @@ def run_model_test(enable_autocast, shard_strategy_class):
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     MOE_CONTEXT.setup(seed=42)
-    MOE_CONTEXT.reset_loss()
     run_model_test()
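The generic criterion from the component registry is replaced with MoeLoss, which combines a standard loss with the auxiliary loss gathered during the forward pass, weighted by aux_weight; the explicit MOE_CONTEXT.reset_loss() in run_dist can go, presumably because the model now resets it at the start of every forward. A hedged sketch of what such a combined criterion does (an illustration, not colossalai's MoeLoss implementation; get_aux_loss is a hypothetical accessor):

```python
import torch
import torch.nn as nn

# Hedged sketch of a combined MoE criterion: apply the wrapped loss, then add the
# auxiliary loss collected during forward, scaled by aux_weight. Not colossalai's
# MoeLoss; get_aux_loss is a hypothetical accessor for the accumulated aux term.
class CombinedMoeLoss(nn.Module):
    def __init__(self, aux_weight, loss_fn, get_aux_loss):
        super().__init__()
        self.aux_weight = aux_weight
        self.loss_fn = loss_fn()              # e.g. torch.nn.CrossEntropyLoss
        self.get_aux_loss = get_aux_loss

    def forward(self, logits, targets):
        return self.loss_fn(logits, targets) + self.aux_weight * self.get_aux_loss()

# usage sketch
aux_store = {"loss": torch.tensor(0.0)}       # stands in for the MoE context
criterion = CombinedMoeLoss(0.01, nn.CrossEntropyLoss, lambda: aux_store["loss"])
logits = torch.randn(8, 10, requires_grad=True)
targets = torch.randint(0, 10, (8,))
criterion(logits, targets).backward()
```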

View File

@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
+from colossalai.nn import MoeLoss
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -60,7 +61,8 @@ def _run_test_sharded_optim_v2(cpu_offload,
         return
     MOE_CONTEXT.reset_loss()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, optimizer_class, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
     with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
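The sharded-optimizer test gets the same treatment: the dataloader's criterion is swapped for MoeLoss so the auxiliary term actually reaches backward. A toy end-to-end sketch (not colossalai code) of why that wiring matters for the gate parameters:

```python
import torch
import torch.nn as nn

# Toy end-to-end sketch: the model stashes an auxiliary load-balancing term during
# forward, the criterion adds it to the task loss, and only then do the gate
# parameters receive a gradient.
aux_buffer = []

class GatedBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate = nn.Linear(16, 4)
        self.proj = nn.Linear(16, 10)

    def forward(self, x):
        scores = torch.softmax(self.gate(x), dim=-1)
        aux_buffer.append(scores.var(dim=0).mean())      # toy balancing term
        return self.proj(x)

model = GatedBlock()
task_loss_fn = nn.CrossEntropyLoss()
x, y = torch.randn(8, 16), torch.randint(0, 10, (8,))

aux_buffer.clear()                                       # like MOE_CONTEXT.reset_loss()
logits = model(x)
loss = task_loss_fn(logits, y) + 0.01 * sum(aux_buffer)  # like MoeLoss(aux_weight=0.01, ...)
loss.backward()
assert model.gate.weight.grad is not None                # gate is trained only via the aux term
```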