[moe] fix moe bugs (#1633)
@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
+from colossalai.nn import MoeLoss
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -60,7 +61,8 @@ def _run_test_sharded_optim_v2(cpu_offload,
         return
     MOE_CONTEXT.reset_loss()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, optimizer_class, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
 
     with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
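
The substance of the fix: the test previously used the generic criterion returned by the component registry, which ignores the MoE auxiliary (load-balancing) loss; wrapping it in MoeLoss makes the test loss include that term. As a rough illustration of the pattern, here is a minimal sketch of such a wrapper. MoeLossSketch and its explicit aux_loss argument are assumptions made for this sketch only; ColossalAI's actual MoeLoss tracks the auxiliary term internally via MOE_CONTEXT rather than taking it as an argument.

import torch
import torch.nn as nn


class MoeLossSketch(nn.Module):
    """Hypothetical stand-in for colossalai.nn.MoeLoss: adds a weighted
    MoE auxiliary (load-balancing) loss to the ordinary task loss."""

    def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
        super().__init__()
        # loss_fn is a class (e.g. nn.CrossEntropyLoss), instantiated here,
        # matching the call style used in the patched test.
        self.loss_fn = loss_fn(*args, **kwargs)
        self.aux_weight = aux_weight

    def forward(self, logits, target, aux_loss=None):
        # Task loss plus the router's auxiliary loss. Passing aux_loss
        # explicitly is an assumption for this sketch; the real MoeLoss
        # pulls the accumulated term from MOE_CONTEXT.
        if aux_loss is None:
            aux_loss = torch.tensor(0.0)
        return self.loss_fn(logits, target) + self.aux_weight * aux_loss


# Usage mirroring the patched test:
criterion = MoeLossSketch(aux_weight=0.01, loss_fn=nn.CrossEntropyLoss)
loss = criterion(torch.randn(4, 10), torch.randint(0, 10, (4,)))

The small aux_weight (0.01) keeps the load-balancing term from dominating the task loss while still penalizing uneven expert routing.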