[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)

* [tp] hotfix linear row

* [tp] support uneven split for fused linear

* [tp] support sp for fused linear

* [tp] fix gpt2 mlp policy

* [tp] fix gather fused and add fused linear row
This commit is contained in:
Hongxin Liu
2024-10-10 14:34:45 +08:00
committed by GitHub
parent f4daf04270
commit 646b3c5a90
10 changed files with 399 additions and 157 deletions

View File

@@ -41,21 +41,6 @@ class Conv1D(nn.Module):
return x
def rearrange(tensor: torch.Tensor, dim: int):
tensor = tensor.clone()
world_size = 2
order = torch.arange(world_size * 3)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
return rearanged_tensor
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
@@ -66,7 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
process_group=None,
gather_output=True,
seq_parallel_mode=seq_parallel_mode,
n_fused=3,
split_sizes=[64] * 3,
overlap=overlap,
)
@@ -88,13 +73,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
)
gather_out = linear_conv_col(x_for_shard)
assert_close(rearrange(out, -1), gather_out)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
assert_close(target_grad, linear_conv_col.weight.grad)

View File

@@ -2,13 +2,12 @@ import os
from contextlib import nullcontext
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
class Conv1D(nn.Module):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
Basically works like a linear layer but the weights are transposed.
Args:
nf (`int`): The number of output features.
nx (`int`): The number of input features.
"""
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
self.weight = nn.Parameter(torch.empty(nx, nf))
self.bias = nn.Parameter(torch.zeros(nf))
nn.init.normal_(self.weight, std=0.02)
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(size_out)
return x
def rearrange(tensor: torch.Tensor, dim: int):
tensor = tensor.clone()
world_size = 2
order = torch.arange(world_size * 3)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
return rearanged_tensor
@parameterize("lazy_init", [False, True])
def check_linear_conv_1d_col(lazy_init: bool):
def check_linear_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
linear = nn.Linear(8, 80).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
linear_copy, process_group=None, gather_output=True, n_fused=3
linear_copy = nn.Linear(8, 80).cuda()
linear_col = FusedLinear1D_Col.from_native_module(
linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
)
assert linear.weight.shape == torch.Size([48, 192])
assert linear.bias.shape == torch.Size([192])
assert linear_conv_col.weight.shape == torch.Size([48, 96])
assert linear_conv_col.bias.shape == torch.Size([96])
assert linear_copy.weight is linear_conv_col.weight
assert linear_copy.bias is linear_conv_col.bias
assert linear.weight.shape == torch.Size([80, 8])
assert linear.bias.shape == torch.Size([80])
assert linear_col.weight.shape == torch.Size([40, 8])
assert linear_col.bias.shape == torch.Size([40])
assert linear_copy.weight is linear_col.weight
assert linear_copy.bias is linear_col.bias
# ensure weights are reversibly loadable
linear_conv_col.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_conv_col.state_dict())
linear_col.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_col.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(4, 8).cuda()
out = linear(x)
gather_out = linear_conv_col(x)
assert_close(rearrange(out, 1), gather_out)
gather_out = linear_col(x)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
assert_close(target_grad, linear_conv_col.weight.grad)
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
assert_close(target_grad, linear_col.weight.grad)
@parameterize("lazy_init", [False, True])
def check_linear_conv_1d_row(lazy_init: bool):
def check_linear_1d_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
linear = nn.Linear(80, 8).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
linear_copy = nn.Linear(80, 8).cuda()
linear_row = FusedLinear1D_Row.from_native_module(
linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
assert linear_row.bias.shape == torch.Size([192])
assert linear.weight.shape == torch.Size([8, 80])
assert linear_row.weight.shape == torch.Size([8, 40])
assert linear_row.bias.shape == torch.Size([8])
assert linear_copy.weight is linear_row.weight
assert linear_copy.bias is linear_row.bias
@@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(4, 80).cuda()
out = linear(x)
gather_out = linear_row(x)
assert_close(out, gather_out)
@@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
out.sum().backward()
gather_out.sum().backward()
rank = dist.get_rank()
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
assert_close(target_grad, linear_row.weight.grad)
@parameterize("lazy_init", [False, True])
def check_linear_1d_col_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear1 = nn.Linear(8, 80).cuda()
linear2 = nn.Linear(80, 8).cuda()
with ctx:
linear1_copy = nn.Linear(8, 80).cuda()
linear2_copy = nn.Linear(80, 8).cuda()
linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
linear_row = FusedLinear1D_Row.from_native_module(
linear2_copy,
process_group=None,
split_sizes=[32, 32, 16],
)
# ensure weights are reversibly loadable
linear_col.load_state_dict(linear1.state_dict())
linear_row.load_state_dict(linear2.state_dict())
# check computation correctness
x = torch.rand(4, 8).cuda()
target_out = linear2(linear1(x))
out = linear_row(linear_col(x))
assert_close(out, target_out)
# check backward correctness
target_out.sum().backward()
out.sum().backward()
target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
assert_close(target_grad1, linear_col.weight.grad)
target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
assert_close(target_grad2, linear_row.weight.grad)
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# test for linear conv
check_linear_conv_1d_col()
check_linear_conv_1d_row()
check_linear_1d_col()
check_linear_1d_row()
check_linear_1d_col_row()
@rerun_if_address_is_in_use()