From d2fc0672313969c1fb864c96c44b4b31cf6bf3fe Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Mon, 24 Oct 2022 13:13:38 +0800
Subject: [PATCH] [autoparallel] fix param hook issue in transform pass (#1755)

---
 .../adding_shape_consistency_pass_v2.py       | 10 ++-
 .../test_resnet_block_runtime.py              | 86 +++++++++++++++++--
 2 files changed, 85 insertions(+), 11 deletions(-)

diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py
index 2776b38b8..2e735a25d 100644
--- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py
+++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py
@@ -70,10 +70,14 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
                     comm_spec_to_use = comm_action.comm_spec
                     if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
 
-                        def hook_fn(grad):
-                            _all_reduce(grad, comm_spec_to_use)
+                        def wrapper(param, comm_spec):
 
-                        param_sharded.register_hook(hook_fn)
+                            def hook_fn(grad):
+                                _all_reduce(grad, comm_spec)
+
+                            param.register_hook(hook_fn)
+
+                        wrapper(param_sharded, comm_spec_to_use)
 
         sharded_buffer_dict = {}
         for name, buffer in target_module.named_buffers():
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py b/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py
index bb109840b..1f753522c 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py
@@ -171,22 +171,92 @@ def check_apply_bottleneck(rank, world_size, port):
     torch.cuda.set_rng_state(cuda_rng_state)
     origin_output.sum().backward()
     if rank == 0:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        print(
+            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum()}"
+        )
+        print(
+            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
+        )
+        print(
+            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
+        )
+        print(
+            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
+        )
+        print(
+            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 0, 1)).abs().sum()}"
+        )
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
+
     if rank == 1:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        print(
+            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum()}"
+        )
+        print(
+            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
+        )
+        print(
+            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
+        )
+        print(
+            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
+        )
+        print(
+            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 1, 1)).abs().sum()}"
+        )
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
+
     if rank == 2:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        print(
+            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum()}"
+        )
+        print(
+            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
+        )
+        print(
+            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
+        )
+        print(
+            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
+        )
+        print(
+            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 2, 1)).abs().sum()}"
+        )
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
     if rank == 3:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        print(
+            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum()}"
+        )
+        print(
+            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
+        )
+        print(
+            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
+        )
+        print(
+            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
+        )
+        print(
+            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 3, 1)).abs().sum()}"
+        )
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
 
 
 @run_on_environment_flag(name='AUTO_PARALLEL')
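Why the wrapper is needed: the old pass defined hook_fn directly inside the loop that walks the module's parameters, and the hook closed over the loop variable comm_spec_to_use. Python closures capture variables by reference, not by value, so when autograd fired the hooks during backward(), every hook read whatever comm_spec_to_use held after the final loop iteration, and all parameter gradients were all-reduced with the same spec. Routing the registration through wrapper(param, comm_spec) rebinds the spec as a function argument at registration time, so each hook keeps its own spec. Below is a minimal, self-contained sketch of the pitfall and the fix; the names (make_hooks_broken, make_hooks_fixed, specs) are illustrative only and are not part of the ColossalAI codebase.

    def make_hooks_broken(specs):
        """Pre-patch pattern: every hook closes over the loop variable."""
        hooks = []
        for spec in specs:

            def hook_fn(grad):
                # `spec` is looked up when the hook runs, so all hooks
                # see the value from the last loop iteration.
                return (grad, spec)

            hooks.append(hook_fn)
        return hooks


    def make_hooks_fixed(specs):
        """Patched pattern: freeze the spec as a wrapper argument."""
        hooks = []
        for spec in specs:

            def wrapper(captured_spec):

                def hook_fn(grad):
                    # `captured_spec` was bound when `wrapper` was called,
                    # so each hook carries its own spec.
                    return (grad, captured_spec)

                return hook_fn

            hooks.append(wrapper(spec))
        return hooks


    assert [h(None)[1] for h in make_hooks_broken(["a", "b"])] == ["b", "b"]
    assert [h(None)[1] for h in make_hooks_fixed(["a", "b"])] == ["a", "b"]

The patch's wrapper registers the hook on the parameter instead of returning it, but the effect is the same: the communication spec each hook uses is fixed per parameter at registration time rather than resolved at backward time. The test changes verify exactly this: each rank now prints the gradient difference for every sharded layer (bn1-bn3, conv1-conv3) against the matching narrow() slice of the unsharded reference model and asserts that the conv gradient sums agree, a check that would fail if all hooks shared the last loop iteration's spec.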