[shardformer] supported fused qkv checkpoint (#4073)

This commit is contained in:
Frank Lee
2023-06-23 16:07:09 +08:00
parent 0803a61412
commit 70c58cfd4f
10 changed files with 420 additions and 88 deletions

View File

@@ -27,8 +27,13 @@ def check_linear_1d_col():
# check computation correctness
x = torch.rand(4, 32).cuda()
out = linear(x)
gather_out = linear_col(x)
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
out = linear(x_for_unshard)
gather_out = linear_col(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
@@ -39,6 +44,11 @@ def check_linear_1d_col():
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
assert_close(target_grad, linear_col.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_1d_row():
linear = nn.Linear(32, 128).cuda()
@@ -49,8 +59,14 @@ def check_linear_1d_row():
# check computation correctness
x = torch.rand(4, 32).cuda()
out = linear(x)
gather_out = linear_row(x)
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
out = linear(x_for_unshard)
gather_out = linear_row(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
@@ -61,11 +77,49 @@ def check_linear_1d_row():
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
assert_close(target_grad, linear_row.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_col_plus_row():
linear_1 = nn.Linear(32, 128).cuda()
linear_2 = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
# check computation correctness
x = torch.rand(4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
unshard_out = linear_2(linear_1(x_for_unshard))
shard_out = linear_row(linear_col(x_for_shard))
assert_close(unshard_out, shard_out)
# check backward correctness
unshard_out.sum().backward()
shard_out.sum().backward()
rank = dist.get_rank()
target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank]
assert_close(target_1_grad, linear_col.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_1d_col()
# check_linear_1d_row()
check_linear_1d_row()
check_linear_col_plus_row()
@rerun_if_address_is_in_use()