[moe] fix moe bugs (#1633)

HELSON
2022-09-23 15:33:57 +08:00
committed by GitHub
parent 702dbc5288
commit a088022efc
8 changed files with 287 additions and 249 deletions

@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
+from colossalai.nn import MoeLoss
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -60,7 +61,8 @@ def _run_test_sharded_optim_v2(cpu_offload,
         return
     MOE_CONTEXT.reset_loss()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, optimizer_class, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
     with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
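
The substantive change in this hunk is that the test no longer takes the criterion returned by the component registry; it builds a `MoeLoss` explicitly so the MoE auxiliary loss is part of the objective. Below is a minimal sketch of that criterion: the constructor call is taken verbatim from the diff, while the training-step usage and the description of the auxiliary (load-balancing) term are assumptions, not part of this patch.

```python
import torch
from colossalai.nn import MoeLoss

# loss_fn is passed as a class; MoeLoss instantiates it and adds the MoE
# auxiliary loss, scaled by aux_weight (assumption based on the parameter names).
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)

# Assumed usage, mirroring a typical test step:
#   outputs = model(data)               # MoE layers accumulate their auxiliary loss
#   loss = criterion(outputs, labels)   # cross-entropy plus the weighted auxiliary term
#   loss.backward()
```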