[moe] fix moe bugs (#1633)
@@ -32,7 +32,7 @@ def run_test(rank, world_size, port):
         moe_layer = MoeLayer(DIM, num_experts, router, exp)
         layer_list.append(moe_layer)

-    model = nn.Sequential(*layer_list)
+    model = nn.ModuleList(layer_list)
     model = model.to(get_current_device())
     sync_moe_model_param(model)

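The switch from nn.Sequential(*layer_list) to nn.ModuleList(layer_list) follows from the interface change in this commit: MoeLayer now returns a tuple rather than a bare tensor, and Sequential feeds each layer's output straight into the next layer. A minimal sketch of the failure mode, using a hypothetical TupleLayer (not part of ColossalAI):

import torch
import torch.nn as nn

# ModuleList only registers the layers' parameters; the caller decides how to
# chain them, so tuple-returning layers can be unpacked explicitly.
class TupleLayer(nn.Module):

    def forward(self, x):
        return x + 1, x.mean()    # (output, auxiliary value), like a MoE layer

model = nn.ModuleList([TupleLayer(), TupleLayer()])

x = torch.zeros(4)
for layer in model:
    x, _ = layer(x)    # unpack at each step; Sequential cannot do this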
@@ -49,8 +49,9 @@ def run_test(rank, world_size, port):
     grad = torch.randn_like(data)

     MOE_CONTEXT.reset_loss()
-    outputs = model(data)
-    outputs.backward(grad)
+    for layer in layer_list:
+        data, _ = layer(data)
+    data.backward(grad)
     grad_handler.handle_gradient()

     assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
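The assertion checks that, after grad_handler.handle_gradient(), the first expert's weight gradient is identical on every rank of its data-parallel group. Conceptually the handler all-reduces expert gradients over that group; a hedged sketch of that idea (names and the averaging convention are assumptions, not the ColossalAI API, and an initialized process group is assumed):

import torch.distributed as dist

def sync_expert_grads(expert_params, dp_group):
    """All-reduce expert gradients over their data-parallel group (sketch)."""
    world_size = dist.get_world_size(dp_group)
    for param in expert_params:
        if param.grad is not None:
            dist.all_reduce(param.grad, group=dp_group)
            param.grad.div_(world_size)    # average, as in plain data parallelism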
@@ -44,7 +44,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f

     # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine
     layer.use_kernel = False
-    old_out = layer(tokens)
+    old_out, _ = layer(tokens)
     ech = old_out.shape
     grad = torch.randn(ech, device=get_current_device())
     old_out.backward(grad)    # get gradient
@@ -58,7 +58,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.gate_weight.grad.zero_()

     layer.use_kernel = True
-    new_out = layer(tokens)    # get outputs through the colossal kernel
+    new_out, _ = layer(tokens)    # get outputs through the colossal kernel

     if data_type == torch.float32:
         check_equal(old_out, new_out)
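Both kernel-test hunks make the same adjustment: layer(tokens) now returns a tuple, so the output tensor must be unpacked before comparing the matmul fallback against the fused kernel path. The overall pattern, as a hedged sketch (the tolerances are an assumption; the real test uses check_equal for float32):

import torch

def check_kernel_equivalence(layer, tokens):
    layer.use_kernel = False
    ref_out, _ = layer(tokens)      # matmul-based dispatch and combine

    layer.use_kernel = True
    fused_out, _ = layer(tokens)    # fused kernel path

    torch.testing.assert_close(ref_out, fused_out, rtol=1e-5, atol=1e-5)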
@@ -19,20 +19,39 @@ from colossalai.utils import get_current_device
 from tests.test_zero.common import CONFIG


-class MoeModel(CheckpointModule):
+class MoeModel(nn.Module):

     def __init__(self, checkpoint: bool = False):
-        super().__init__(checkpoint)
-        self.proj1 = nn.Linear(4, 16)
-        expert_cls = nn.Linear
-        expert_args_dict = dict(in_features=16, out_features=16)
-        self.moe = MoeModule(dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict)
-        self.proj2 = nn.Linear(16, 4)
+
+        class TestSubModule(CheckpointModule):
+
+            def __init__(self):
+                super().__init__(checkpoint)
+                expert_cls = nn.Linear
+                expert_args_dict = dict(in_features=16, out_features=16)
+                self.moe = MoeModule(dim_model=16,
+                                     num_experts=8,
+                                     use_residual=True,
+                                     expert_cls=expert_cls,
+                                     **expert_args_dict)
+                self.proj = nn.Linear(16, 4)
+
+            def _forward(self, x):
+                x, y = self.moe(x)
+                x = self.proj(x)
+                return x, y
+
+        super().__init__()
+        self.test_embed = nn.Linear(4, 16)
+        self.test_transform = TestSubModule()

     def forward(self, x):
-        x = self.proj1(x)
-        x = self.moe(x)
-        x = self.proj2(x)
+        MOE_CONTEXT.reset_loss()
+
+        x = self.test_embed(x)
+        x, y = self.test_transform(x)
+
+        MOE_CONTEXT.add_loss(y)
         return x
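The rewritten MoeModel moves auxiliary-loss bookkeeping into the model itself: forward resets the accumulated load-balancing loss, runs the MoE submodule (whose second return value y is the auxiliary term), and registers it via MOE_CONTEXT.add_loss(y). A hedged stand-in for that accumulator (AuxLossTracker is illustrative, not the MOE_CONTEXT implementation):

class AuxLossTracker:
    """Accumulates auxiliary (load-balancing) loss across MoE layers."""

    def __init__(self):
        self.loss = 0.0

    def reset_loss(self):
        self.loss = 0.0    # called once at the start of each forward pass

    def add_loss(self, aux):
        self.loss = self.loss + aux    # each MoE layer contributes its term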
@@ -4,6 +4,8 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp

+from colossalai.nn import MoeLoss
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
@@ -26,7 +28,8 @@ def run_model_test(enable_autocast, shard_strategy_class):
     shard_strategy = shard_strategy_class()

     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, _, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)

     with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                          shard_strategy=shard_strategy,
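Instead of taking the criterion from the component registry, the test now builds a MoeLoss that wraps CrossEntropyLoss and adds the accumulated auxiliary loss with weight 0.01. A minimal sketch of that shape, assuming a tracker like the one above (not the real MoeLoss implementation):

import torch.nn as nn

class SketchMoeLoss(nn.Module):
    """Task loss plus a weighted auxiliary loss (sketch only)."""

    def __init__(self, aux_weight, loss_fn, tracker):
        super().__init__()
        self.aux_weight = aux_weight
        self.loss_fn = loss_fn()    # e.g. torch.nn.CrossEntropyLoss
        self.tracker = tracker

    def forward(self, logits, targets):
        return self.loss_fn(logits, targets) + self.aux_weight * self.tracker.loss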
@@ -59,7 +62,6 @@ def run_model_test(enable_autocast, shard_strategy_class):
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     MOE_CONTEXT.setup(seed=42)
-    MOE_CONTEXT.reset_loss()
     run_model_test()
@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
+from colossalai.nn import MoeLoss
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -60,7 +61,8 @@ def _run_test_sharded_optim_v2(cpu_offload,
         return
     MOE_CONTEXT.reset_loss()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, optimizer_class, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)

     with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
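The sharded-optimizer test gets the same treatment as the model test, and the new criterion stays a drop-in replacement for the registry one. Usage, grounded in the constructor shown above (outputs and labels assumed to come from the test's dataloader loop):

criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
loss = criterion(outputs, labels)    # auxiliary loss is folded in by the criterion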