mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
add moe context, moe utilities and refactor gradient handler (#455)
This commit is contained in:
@@ -23,13 +23,13 @@ def check_equal(A, B, atol=1e-06):
|
||||
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
moe_set_seed(42)
|
||||
# torch.set_printoptions(precision=30)
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
|
||||
torch.manual_seed(rs + local_rank)
|
||||
moe_env.reset_loss()
|
||||
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
|
||||
# print(f"tokens:\n{tokens}")
|
||||
|
||||
router = Top2Router(1)
|
||||
expert = Experts(nn.Identity, 4)
|
||||
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
|
||||
@@ -38,7 +38,6 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
|
||||
layer.cuda_mode = False
|
||||
|
||||
old_out = layer(tokens)
|
||||
# print(f"old output:\n{old_out}")
|
||||
|
||||
ech = old_out.shape
|
||||
grad = torch.randn(ech, device=get_current_device())
|
||||
@@ -53,33 +52,27 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
|
||||
layer.cuda_mode = True
|
||||
new_out = layer(tokens)
|
||||
|
||||
# print(torch.max(torch.abs(old_out - new_out)))
|
||||
if data_type == torch.float32:
|
||||
check_equal(old_out, new_out)
|
||||
else:
|
||||
check_equal(old_out, new_out, 1e-2)
|
||||
# print(f"forward functions passed")
|
||||
|
||||
# print(f"new output:\n{new_out}")
|
||||
new_out.backward(grad)
|
||||
n_tk_grad = tokens.grad.data.clone()
|
||||
n_gt_grad = layer.gate.weight.grad.data.clone()
|
||||
|
||||
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
|
||||
if data_type == torch.float32:
|
||||
check_equal(o_tk_grad, n_tk_grad)
|
||||
else:
|
||||
check_equal(o_tk_grad, o_tk_grad, 1e-2)
|
||||
# print(f"tokens gradient passed")
|
||||
|
||||
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
|
||||
if data_type == torch.float32:
|
||||
check_equal(o_gt_grad, n_gt_grad, 5e-05)
|
||||
else:
|
||||
check_equal(o_gt_grad, n_gt_grad, 2e-01)
|
||||
# print(f"linear weight gradient passed")
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="MoE refactoring has not finished yet")
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("rs", [131])
|
||||
@pytest.mark.parametrize("hidden_size", [32, 144])
|
||||
|
Reference in New Issue
Block a user