[moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* unshard optim ckpt

* optim io

* update transformer version

* update requirements

* update ckpt

* update ckpt

* update ckpt

* fix engine

* fix engine
This commit is contained in:
Xuanlei Zhao
2023-11-08 23:07:03 +08:00
committed by GitHub
parent 67f5331754
commit f71e63b0f3
20 changed files with 738 additions and 150 deletions

View File

@@ -20,21 +20,23 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
# Here we do not need TF32, since it brings absolute error on results
torch.backends.cuda.matmul.allow_tf32 = False
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
local_rank = dist.get_rank()
MOE_MANAGER.setup(42, parallel="EP") # MOE environment initialization
MOE_MANAGER.setup(parallel="EP") # MOE environment initialization
MOE_MANAGER.reset_loss()
torch.manual_seed(rs + local_rank) # set each process has different random seed
torch.manual_seed(rs + local_rank) # set each process has different random seed
# get randomized data
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
layer = SparseMLP(hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_experts=NUM_EXPERTS,
router_top_k=topk,
router_capacity_factor_train=1.0)
layer = SparseMLP(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_experts=NUM_EXPERTS,
router_top_k=topk,
router_capacity_factor_train=1.0,
)
layer = layer.to(get_current_device())
if data_type == torch.float16:
layer = layer.half()
@@ -55,7 +57,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
layer.gate_weight.grad.zero_()
layer.enable_kernel = True
new_out = layer(tokens) # get outputs through colossal kernel
new_out = layer(tokens) # get outputs through colossal kernel
if data_type == torch.float32:
check_equal(old_out, new_out)
@@ -90,5 +92,5 @@ def test_moe_kernel(rs, hidden_size, data_type, topk):
spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk)
if __name__ == '__main__':
if __name__ == "__main__":
test_moe_kernel(2, 256, torch.float16, 2)