[moe] support optimizer checkpoint (#5015)
* Refactor MoE Manager setup method
* unshard optim ckpt
* optim io
* update transformer version
* update requirements
* update ckpt
* update ckpt
* update ckpt
* fix engine
* fix engine
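For context on the headline change, below is a minimal, hypothetical sketch of the optimizer-checkpoint round trip the commit title refers to, written against ColossalAI's generic Booster API. The plugin choice, optimizer, model, port, and checkpoint path are illustrative assumptions, not code from this commit.

# Hypothetical sketch: save and reload optimizer state through the Booster
# API. Everything concrete here (plugin, shapes, paths) is an assumption.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam

# Single-process launch, mirroring the launch call used in the test below.
colossalai.launch(config=dict(), rank=0, world_size=1, host="localhost", port=29500, backend="nccl")

model = torch.nn.Linear(256, 256).cuda()  # stand-in for a MoE model
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=False gathers sharded optimizer states into a single file (cf. the
# "unshard optim ckpt" bullet); load_optimizer restores them afterwards.
booster.save_optimizer(optimizer, "optim.ckpt", shard=False)
booster.load_optimizer(optimizer, "optim.ckpt")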
@@ -14,16 +14,16 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, syn
 def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
     assert batch_size % world_size == 0
 
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
 
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed, parallel=None)
+    MOE_MANAGER.setup(parallel=None)
     local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed, parallel="EP")
+    MOE_MANAGER.setup(parallel="EP")
     ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed, parallel="TP")
+    MOE_MANAGER.setup(parallel="TP")
     tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     ep_model = ep_model.to(get_current_device())
     tp_model = tp_model.to(get_current_device())
@@ -44,7 +44,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
     torch.cuda.manual_seed(seed)
     tp_data = torch.randn(batch_size, dim, device=get_current_device())
     micro_batch_size = batch_size // world_size
-    ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)]
+    ep_data = tp_data.detach()[micro_batch_size * rank : micro_batch_size * (rank + 1)]
 
     out_local = local_model(tp_data)
     MOE_MANAGER.reset_loss()
@@ -52,8 +52,8 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
     MOE_MANAGER.reset_loss()
     out_ep = ep_model(ep_data)
     MOE_MANAGER.reset_loss()
-    assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)])
-    assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)])
+    assert torch.allclose(out_ep, out_tp[micro_batch_size * rank : micro_batch_size * (rank + 1)])
+    assert torch.allclose(out_ep, out_local[micro_batch_size * rank : micro_batch_size * (rank + 1)])
 
     out_local.mean().backward()
     out_tp.mean().backward()
@@ -77,5 +77,5 @@ def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
     spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42)
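The allclose assertions above hinge on a simple alignment: each EP rank forwards only its contiguous slice of the full batch, so its output must equal the matching row-slice of the full-batch (local or TP) output. A standalone sketch of that slicing invariant, with made-up shapes and an elementwise stand-in for the deterministic MLP forward:

# Standalone illustration of the per-rank micro-batch slicing the test
# relies on; shapes and the stand-in "model" f are assumptions for clarity.
import torch

world_size, batch_size, dim = 2, 8, 256
micro_batch_size = batch_size // world_size
f = torch.nn.functional.gelu  # elementwise, so rows are independent

full_batch = torch.randn(batch_size, dim)
out_full = f(full_batch)  # "local"/"TP" path: the whole batch at once

for rank in range(world_size):
    # EP path: each rank forwards only its contiguous micro-batch...
    ep_data = full_batch[micro_batch_size * rank : micro_batch_size * (rank + 1)]
    out_ep = f(ep_data)
    # ...so its output equals the matching row-slice of the full-batch
    # output, which is exactly what the test's allclose checks assert.
    assert torch.allclose(out_ep, out_full[micro_batch_size * rank : micro_batch_size * (rank + 1)])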