[autoparallel] fix param hook issue in transform pass (#1755)

Authored by YuliangLiu0306 on 2022-10-24 13:13:38 +08:00; committed by GitHub
parent 262652c8bc
commit d2fc067231
2 changed files with 85 additions and 11 deletions


@@ -171,22 +171,92 @@ def check_apply_bottleneck(rank, world_size, port):
     torch.cuda.set_rng_state(cuda_rng_state)
     origin_output.sum().backward()
     if rank == 0:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        print(
+            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum()}"
+        )
+        print(
+            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
+        )
+        print(
+            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
+        )
+        print(
+            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
+        )
+        print(
+            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 0, 1)).abs().sum()}"
+        )
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
     if rank == 1:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        print(
+            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum()}"
+        )
+        print(
+            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
+        )
+        print(
+            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
+        )
+        print(
+            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
+        )
+        print(
+            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 1, 1)).abs().sum()}"
+        )
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
     if rank == 2:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        print(
+            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum()}"
+        )
+        print(
+            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
+        )
+        print(
+            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
+        )
+        print(
+            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
+        )
+        print(
+            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 2, 1)).abs().sum()}"
+        )
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
     if rank == 3:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        print(
+            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum()}"
+        )
+        print(
+            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
+        )
+        print(
+            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
+        )
+        print(
+            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
+        )
+        print(
+            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 3, 1)).abs().sum()}"
+        )
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
 @run_on_environment_flag(name='AUTO_PARALLEL')
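
The checks added above all follow one pattern: on the 4-GPU mesh, each rank owns a shard of a weight along dim 0, so its local gradient is compared against the matching narrow(dim, start, length) slice of the unsharded test_model gradient, first as a printed diff sum and then as a loose equality on the sums. Below is a minimal runnable sketch of that pattern, assuming an even 1-D shard across four ranks; full_grad and shard_grads are hypothetical stand-ins for test_model.*.weight.grad and the per-rank gm.*.weight.grad, and the torch.allclose tolerance stands in for assert_close_loose:

    import torch

    # Hypothetical stand-ins: full_grad plays the role of the unsharded
    # test_model gradient; shard_grads[r] plays gm's local gradient on rank r.
    world_size = 4
    full_grad = torch.arange(16.0)                  # gradient of the full weight, dim 0 size 16
    shard_len = full_grad.size(0) // world_size     # 4 rows per rank
    shard_grads = [full_grad.narrow(0, r * shard_len, shard_len).clone()
                   for r in range(world_size)]

    for rank in range(world_size):
        # slice of the full gradient owned by this rank
        expected = full_grad.narrow(0, rank * shard_len, shard_len)
        print(f"diff sum in rank {rank}: {(shard_grads[rank] - expected).abs().sum()}")
        # loose sum comparison, in the spirit of assert_close_loose above
        assert torch.allclose(shard_grads[rank].sum(), expected.sum(), rtol=1e-3, atol=1e-3)

Note that the offsets in the real test are not uniform per rank (e.g. conv3 uses narrow(0, 0, 8) on ranks 0-1 and narrow(0, 8, 8) on ranks 2-3), because the shard layout depends on how each layer is partitioned over the device mesh; the sketch simplifies this to one row block per rank.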