add moe context, moe utilities and refactor gradient handler (#455)

HELSON
2022-03-18 16:38:32 +08:00
committed by GitHub
parent af185b5519
commit 84fd7c1d4d
11 changed files with 255 additions and 125 deletions

@@ -23,13 +23,13 @@ def check_equal(A, B, atol=1e-06):
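# compares the MoE layer's forward/backward results with and without the CUDA routing kernels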
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
moe_set_seed(42)
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top2Router(1)
expert = Experts(nn.Identity, 4)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
@@ -38,7 +38,6 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
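# reference pass: run the layer with the fused CUDA path disabled to get baseline outputs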
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
@@ -53,33 +52,27 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
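# second pass: enable cuda_mode on the same tokens and check the forward output against the reference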
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
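# compare token and gate-weight gradients from the cuda_mode run against the reference run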
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, n_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.skip(reason="MoE refactoring has not finished yet")
@pytest.mark.dist
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])