mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[shardformer] supported fused qkv checkpoint (#4073)
This commit is contained in:
@@ -27,8 +27,13 @@ def check_linear_1d_col():
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 32).cuda()
|
||||
out = linear(x)
|
||||
gather_out = linear_col(x)
|
||||
x_for_unshard = x.expand_as(x.clone())
|
||||
x_for_unshard.requires_grad_(True)
|
||||
x_for_shard = x.expand_as(x.clone())
|
||||
x_for_shard.requires_grad_(True)
|
||||
|
||||
out = linear(x_for_unshard)
|
||||
gather_out = linear_col(x_for_shard)
|
||||
assert_close(out, gather_out)
|
||||
|
||||
# check backward correctness
|
||||
@@ -39,6 +44,11 @@ def check_linear_1d_col():
|
||||
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
|
||||
assert_close(target_grad, linear_col.weight.grad)
|
||||
|
||||
# check the input gradients
|
||||
assert x_for_shard.grad is not None
|
||||
assert x_for_unshard.grad is not None
|
||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||
|
||||
|
||||
def check_linear_1d_row():
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
@@ -49,8 +59,14 @@ def check_linear_1d_row():
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 32).cuda()
|
||||
out = linear(x)
|
||||
gather_out = linear_row(x)
|
||||
x_for_unshard = x.expand_as(x.clone())
|
||||
x_for_unshard.requires_grad_(True)
|
||||
x_for_shard = x.expand_as(x.clone())
|
||||
x_for_shard.requires_grad_(True)
|
||||
|
||||
# run forward
|
||||
out = linear(x_for_unshard)
|
||||
gather_out = linear_row(x_for_shard)
|
||||
assert_close(out, gather_out)
|
||||
|
||||
# check backward correctness
|
||||
@@ -61,11 +77,49 @@ def check_linear_1d_row():
|
||||
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
|
||||
assert_close(target_grad, linear_row.weight.grad)
|
||||
|
||||
# check the input gradients
|
||||
assert x_for_shard.grad is not None
|
||||
assert x_for_unshard.grad is not None
|
||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||
|
||||
|
||||
def check_linear_col_plus_row():
|
||||
linear_1 = nn.Linear(32, 128).cuda()
|
||||
linear_2 = nn.Linear(128, 32).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
|
||||
linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 32).cuda()
|
||||
x_for_unshard = x.expand_as(x.clone())
|
||||
x_for_unshard.requires_grad_(True)
|
||||
x_for_shard = x.expand_as(x.clone())
|
||||
x_for_shard.requires_grad_(True)
|
||||
|
||||
# run forward
|
||||
unshard_out = linear_2(linear_1(x_for_unshard))
|
||||
shard_out = linear_row(linear_col(x_for_shard))
|
||||
assert_close(unshard_out, shard_out)
|
||||
|
||||
# check backward correctness
|
||||
unshard_out.sum().backward()
|
||||
shard_out.sum().backward()
|
||||
|
||||
rank = dist.get_rank()
|
||||
target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank]
|
||||
assert_close(target_1_grad, linear_col.weight.grad)
|
||||
|
||||
# check the input gradients
|
||||
assert x_for_shard.grad is not None
|
||||
assert x_for_unshard.grad is not None
|
||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_linear_1d_col()
|
||||
# check_linear_1d_row()
|
||||
check_linear_1d_row()
|
||||
check_linear_col_plus_row()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
|
@@ -5,6 +5,7 @@ from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row
|
||||
from colossalai.shardformer.layer.linear_conv import split_fused_qkv
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
@@ -53,9 +54,15 @@ def check_linear_conv_1d_col():
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3)
|
||||
|
||||
assert linear_conv_col.weight.shape == torch.Size([96, 48])
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear.bias.shape == torch.Size([192])
|
||||
assert linear_conv_col.weight.shape == torch.Size([48, 96])
|
||||
assert linear_conv_col.bias.shape == torch.Size([96])
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_conv_col.load_state_dict(linear.state_dict())
|
||||
linear.load_state_dict(linear_conv_col.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 48).cuda()
|
||||
out = linear(x)
|
||||
@@ -66,16 +73,16 @@ def check_linear_conv_1d_col():
|
||||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
rank = dist.get_rank()
|
||||
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
|
||||
assert_close(target_grad.transpose(0, 1).contiguous(), linear_conv_col.weight.grad)
|
||||
target_grad = split_fused_qkv(linear.weight.grad, 3, None)
|
||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||
|
||||
|
||||
def check_linear_1d_row():
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
|
||||
assert linear_row.weight.shape == torch.Size([192, 24])
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear_row.weight.shape == torch.Size([24, 192])
|
||||
assert linear_row.bias.shape == torch.Size([192])
|
||||
|
||||
# check computation correctness
|
||||
@@ -89,13 +96,14 @@ def check_linear_1d_row():
|
||||
gather_out.sum().backward()
|
||||
|
||||
rank = dist.get_rank()
|
||||
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
|
||||
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
|
||||
assert_close(target_grad, linear_row.weight.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_linear_conv_1d_col()
|
||||
check_linear_1d_row()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
|
@@ -20,20 +20,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
|
||||
# check grad equality
|
||||
if org_model.__class__.__name__ == 'GPT2Model':
|
||||
org_grad = org_model.h[0].attn.c_attn.weight.grad
|
||||
shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous()
|
||||
org_grad = org_model.h[0].mlp.c_fc.weight.grad
|
||||
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
|
||||
else:
|
||||
org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
|
||||
shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous()
|
||||
shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad
|
||||
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=1)
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
|
||||
assert torch.allclose(
|
||||
org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
|
||||
def check_gpt2(rank, world_size, port):
|
||||
|
Reference in New Issue
Block a user