mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[moe] merge moe into main (#4978)
* update moe module * support openmoe
This commit is contained in:
@@ -2,120 +2,91 @@ import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.context import MOE_CONTEXT
|
||||
from colossalai.legacy.amp import convert_to_apex_amp
|
||||
from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
|
||||
from colossalai.nn import MoeLoss
|
||||
from colossalai.nn.optimizer import CPUAdam
|
||||
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.legacy.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
|
||||
from colossalai.zero.legacy.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
|
||||
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
|
||||
from colossalai.zero.low_level._utils import has_inf_or_nan
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||
from tests.test_zero.test_legacy.common import CONFIG, check_sharded_model_params
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
|
||||
|
||||
|
||||
def _run_step(model, optimizer, data, label, criterion, grad_handler):
|
||||
def split_ddp_grad(grad, world_size):
|
||||
with torch.no_grad():
|
||||
grad = grad.clone().detach().flatten()
|
||||
padding_size = (world_size - grad.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
grad = torch.nn.functional.pad(grad, [0, padding_size])
|
||||
splited_grad = grad.split(grad.numel() // world_size)
|
||||
return splited_grad
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss = loss.float()
|
||||
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
|
||||
loss = loss.float()
|
||||
if isinstance(model, ShardedModelV2):
|
||||
if isinstance(model, LowLevelZeroModel):
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
return y
|
||||
|
||||
if grad_handler is not None:
|
||||
|
||||
def run_zero_optim_test(local_rank, world_size, stage=1):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
zero_model = MoeModel()
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters())
|
||||
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
|
||||
booster = Booster(plugin=plugin)
|
||||
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
|
||||
|
||||
torch_model = MoeModel()
|
||||
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
|
||||
torch_param.data.copy_(zero_param.data)
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters())
|
||||
torch_model = torch_model.cuda()
|
||||
grad_handler = MoeGradientHandler(torch_model)
|
||||
|
||||
for _ in range(2):
|
||||
data = torch.randn(16, 4).cuda() / (local_rank + 1)
|
||||
label = torch.randint(0, 4, (16,)).cuda()
|
||||
run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
grad_handler.handle_gradient()
|
||||
|
||||
optimizer.step()
|
||||
torch_optimizer.step()
|
||||
zero_optimizer.step()
|
||||
|
||||
for (torch_name, torch_param), (zero_name, zero_param) in zip(
|
||||
torch_model.named_parameters(), zero_model.named_parameters()
|
||||
):
|
||||
assert torch.allclose(
|
||||
torch_param.data, zero_param.data
|
||||
), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
|
||||
|
||||
torch_optimizer.zero_grad()
|
||||
zero_optimizer.zero_grad()
|
||||
|
||||
|
||||
@parameterize("cpu_offload", [True])
|
||||
@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug
|
||||
@parameterize("reuse_fp16_shard", [True, False])
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def _run_test_sharded_optim_v2(
|
||||
cpu_offload, shard_strategy_class, use_cpuadam, reuse_fp16_shard, gpu_margin_mem_ratio=0.0
|
||||
):
|
||||
shard_strategy = shard_strategy_class()
|
||||
if use_cpuadam and cpu_offload is False:
|
||||
return
|
||||
MOE_CONTEXT.reset_loss()
|
||||
get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model")
|
||||
_, train_dataloader, _, optimizer_class, _ = get_components_func()
|
||||
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
|
||||
|
||||
with ZeroInitContext(
|
||||
target_device=torch.device("cpu") if cpu_offload else get_current_device(),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True,
|
||||
):
|
||||
zero_model = MoeModel(checkpoint=True)
|
||||
|
||||
zero_model = ShardedModelV2(
|
||||
zero_model,
|
||||
shard_strategy,
|
||||
tensor_placement_policy="cpu" if cpu_offload else "cuda",
|
||||
reuse_fp16_shard=reuse_fp16_shard,
|
||||
)
|
||||
|
||||
# check whether parameters are identical in ddp
|
||||
for name, p in zero_model.named_parameters():
|
||||
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
|
||||
assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))
|
||||
|
||||
model = MoeModel(checkpoint=True).half()
|
||||
col_model_deepcopy(zero_model, model)
|
||||
model = model.cuda().float()
|
||||
|
||||
if use_cpuadam:
|
||||
optimizer_class = CPUAdam
|
||||
optim = optimizer_class(model.parameters(), lr=1e-3)
|
||||
sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
|
||||
sharded_optim = ShardedOptimizerV2(
|
||||
zero_model, sharded_optim, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio
|
||||
)
|
||||
|
||||
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False)
|
||||
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
|
||||
apex_grad_handler = MoeGradientHandler(model)
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 5:
|
||||
break
|
||||
data, label = data.cuda(), label.cuda()
|
||||
_run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler)
|
||||
_run_step(zero_model, sharded_optim, data, label, criterion, None)
|
||||
check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
|
||||
for param in model.parameters():
|
||||
assert not has_inf_or_nan(param)
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
MOE_MANAGER.setup(seed=42, parallel="EP")
|
||||
run_zero_optim_test(rank, world_size, stage=1)
|
||||
run_zero_optim_test(rank, world_size, stage=2)
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
MOE_CONTEXT.setup(seed=42)
|
||||
_run_test_sharded_optim_v2()
|
||||
|
||||
|
||||
# use_cpuadam = True can be used with cpu_offload = False
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_zero_optim(world_size):
|
||||
spawn(_run_dist, world_size)
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_zero_optim(world_size=4)
|
||||
test_moe_zero_optim(world_size=2)
|
||||
|
Reference in New Issue
Block a user