[moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* unshard optim ckpt

* optim io

* update transformer version

* update requirements

* update ckpt

* update ckpt

* update ckpt

* fix engine

* fix engine
Xuanlei Zhao
2023-11-08 23:07:03 +08:00
committed by GitHub
parent 67f5331754
commit f71e63b0f3
20 changed files with 738 additions and 150 deletions
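
Only the refactored unit test is reproduced below; the optimizer checkpoint IO itself lives in the other changed files. For orientation, optimizer checkpointing in ColossalAI typically goes through the Booster checkpoint API, which this PR extends to cover MoE optimizer states (the "unshard optim ckpt" and "optim io" items above). A rough sketch of that API follows; the plugin, model, and paths are illustrative assumptions, not code from this commit.

# A rough sketch of Booster-based optimizer checkpointing (illustrative only;
# plugin, model, and paths are assumptions, not taken from this commit).
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config=dict())      # assumes torchrun set up the process group

model = torch.nn.Linear(256, 256).cuda()         # stands in for an MoE (SparseMLP-based) model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Run one step so the optimizer actually has state to checkpoint.
loss = model(torch.randn(8, 256, device="cuda")).mean()
loss.backward()
optimizer.step()

booster.save_optimizer(optimizer, "optim_ckpt.pt", shard=False)   # single unsharded file
booster.load_optimizer(optimizer, "optim_ckpt.pt")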


@@ -14,16 +14,16 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, syn
 def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
     assert batch_size % world_size == 0
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed, parallel=None)
+    MOE_MANAGER.setup(parallel=None)
     local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed, parallel="EP")
+    MOE_MANAGER.setup(parallel="EP")
     ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed, parallel="TP")
+    MOE_MANAGER.setup(parallel="TP")
     tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     ep_model = ep_model.to(get_current_device())
     tp_model = tp_model.to(get_current_device())
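
The hunk above is the setup refactor named in the commit message: MOE_MANAGER.setup no longer takes the seed as its first positional argument, only the parallel mode, with seeding handled separately (the test seeds via torch.cuda.manual_seed further down). A minimal sketch of the new call pattern, mirroring this hunk; the colossalai.moe import path and the launch_from_torch launcher are assumptions (the test itself calls colossalai.launch inside each spawned worker):

import colossalai
from colossalai.moe import MOE_MANAGER, SparseMLP   # import path assumed from this test's context

colossalai.launch_from_torch(config=dict())          # distributed init, analogous to colossalai.launch above

# Re-initialize the global manager before each configuration, then pick the
# parallel mode; the seed positional argument has been dropped from setup().
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)                     # plain (non-parallel) experts
local_model = SparseMLP(num_experts=8, hidden_size=256, intermediate_size=512)

MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")                     # expert parallelism
ep_model = SparseMLP(num_experts=8, hidden_size=256, intermediate_size=512)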
@@ -44,7 +44,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     torch.cuda.manual_seed(seed)
     tp_data = torch.randn(batch_size, dim, device=get_current_device())
     micro_batch_size = batch_size // world_size
-    ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)]
+    ep_data = tp_data.detach()[micro_batch_size * rank : micro_batch_size * (rank + 1)]
     out_local = local_model(tp_data)
     MOE_MANAGER.reset_loss()
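
The only change in this hunk is spacing inside the slice; the logic is unchanged: every rank builds the identical full batch from the shared seed and keeps one contiguous micro-batch of it, which is also why the assertions in the next hunk compare out_ep against the matching slice of out_tp and out_local. A standalone sketch of that partitioning, using the test's own parameters (batch_size=8, dim=256, world_size=2, seed=42):

import torch

batch_size, dim, world_size = 8, 256, 2      # values used by test_moe_ep_tp / spawn(run_test, 2, ...)
micro_batch_size = batch_size // world_size  # 4 samples per rank

torch.manual_seed(42)
tp_data = torch.randn(batch_size, dim)       # identical on every rank thanks to the shared seed
slices = [tp_data[micro_batch_size * r : micro_batch_size * (r + 1)] for r in range(world_size)]
assert torch.cat(slices).equal(tp_data)      # the per-rank slices tile the full batch exactly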
@@ -52,8 +52,8 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     MOE_MANAGER.reset_loss()
     out_ep = ep_model(ep_data)
     MOE_MANAGER.reset_loss()
-    assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)])
-    assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)])
+    assert torch.allclose(out_ep, out_tp[micro_batch_size * rank : micro_batch_size * (rank + 1)])
+    assert torch.allclose(out_ep, out_local[micro_batch_size * rank : micro_batch_size * (rank + 1)])
     out_local.mean().backward()
     out_tp.mean().backward()
@@ -77,5 +77,5 @@ def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
     spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42)