mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code
This commit is contained in:
@@ -4,5 +4,6 @@ from .bloom import *
|
||||
from .gpt import *
|
||||
from .llama import *
|
||||
from .opt import *
|
||||
from .sam import *
|
||||
from .t5 import *
|
||||
from .vit import *
|
||||
|
52
tests/kit/model_zoo/transformers/sam.py
Normal file
52
tests/kit/model_zoo/transformers/sam.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
# ===============================
|
||||
# Register single-image SAM
|
||||
# ===============================
|
||||
|
||||
|
||||
# define data gen function
|
||||
def data_gen():
|
||||
# Generated from following code snippet
|
||||
#
|
||||
# from PIL import Image
|
||||
# import requests
|
||||
# from transformers import SamModel, SamProcessor
|
||||
#
|
||||
# model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
# processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
||||
#
|
||||
# img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||
# raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
# input_points = [[[450, 600]]] # 2D localization of a window
|
||||
# inputs = processor(raw_image, input_points=input_points, return_tensors="pt")
|
||||
|
||||
pixel_values = torch.rand(1, 3, 1024, 1024, dtype=torch.float32)
|
||||
original_sizes = torch.tensor([[1764, 2646]], dtype=torch.int64)
|
||||
reshaped_input_sizes = torch.tensor([[683, 1024]], dtype=torch.int64)
|
||||
input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64)
|
||||
return dict(pixel_values=pixel_values,
|
||||
original_sizes=original_sizes,
|
||||
reshaped_input_sizes=reshaped_input_sizes,
|
||||
input_points=input_points)
|
||||
|
||||
|
||||
# define output transform function
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss funciton
|
||||
loss_fn = lambda x: x.iou_scores.mean()
|
||||
|
||||
config = transformers.SamConfig()
|
||||
config.vision_config.num_hidden_layers = 2
|
||||
|
||||
# register the BERT variants
|
||||
model_zoo.register(name='transformers_sam',
|
||||
model_fn=lambda: transformers.SamModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
# This code is copied from https://github.com/huggingface/transformers
|
||||
class Conv1D(nn.Module):
|
||||
"""
|
||||
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
|
||||
|
||||
Basically works like a linear layer but the weights are transposed.
|
||||
|
||||
Args:
|
||||
nf (`int`): The number of output features.
|
||||
nx (`int`): The number of input features.
|
||||
"""
|
||||
|
||||
def __init__(self, nf, nx):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.weight = nn.Parameter(torch.empty(nx, nf))
|
||||
self.bias = nn.Parameter(torch.zeros(nf))
|
||||
nn.init.normal_(self.weight, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
size_out = x.size()[:-1] + (self.nf,)
|
||||
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
||||
x = x.view(size_out)
|
||||
return x
|
||||
|
||||
|
||||
def rearrange(tensor: torch.Tensor, dim: int):
|
||||
tensor = tensor.clone()
|
||||
world_size = 2
|
||||
order = torch.arange(world_size * 3)
|
||||
new_order = []
|
||||
for i in range(world_size):
|
||||
new_order.append(order[i::world_size])
|
||||
new_order = torch.cat(new_order)
|
||||
|
||||
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
|
||||
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
|
||||
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
|
||||
return rearanged_tensor
|
||||
|
||||
|
||||
def check_gpt2_linear_conv_1d_col():
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
|
||||
process_group=None,
|
||||
gather_output=True,
|
||||
n_fused=3)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear.bias.shape == torch.Size([192])
|
||||
assert linear_conv_col.weight.shape == torch.Size([48, 96])
|
||||
assert linear_conv_col.bias.shape == torch.Size([96])
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_conv_col.load_state_dict(linear.state_dict())
|
||||
linear.load_state_dict(linear_conv_col.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 48).cuda()
|
||||
out = linear(x)
|
||||
gather_out = linear_conv_col(x)
|
||||
assert_close(rearrange(out, 1), gather_out)
|
||||
|
||||
# check backward correctness
|
||||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
|
||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||
|
||||
|
||||
def check_gpt2_linear_conv_1d_row():
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear_row.weight.shape == torch.Size([24, 192])
|
||||
assert linear_row.bias.shape == torch.Size([192])
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 48).cuda()
|
||||
out = linear(x)
|
||||
gather_out = linear_row(x)
|
||||
assert_close(out, gather_out)
|
||||
|
||||
# check backward correctness
|
||||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
rank = dist.get_rank()
|
||||
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
|
||||
assert_close(target_grad, linear_row.weight.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# test for linear conv
|
||||
check_gpt2_linear_conv_1d_col()
|
||||
check_gpt2_linear_conv_1d_row()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gpt2_linearconv():
|
||||
spawn(run_dist, nprocs=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt2_linearconv()
|
92
tests/test_shardformer/test_model/test_shard_sam.py
Normal file
92
tests/test_shardformer/test_model/test_shard_sam.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['pred_masks'])
|
||||
|
||||
# do backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
|
||||
# check grad
|
||||
|
||||
sam = org_model
|
||||
sharded_sam = sharded_model
|
||||
|
||||
# compare mask decoder grad
|
||||
|
||||
org_grad = sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight.grad
|
||||
shard_grad = sharded_sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight.grad
|
||||
shard_weight = sharded_sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
# compare vision_encoder grad
|
||||
org_grad = sam.vision_encoder.layers[0].mlp.lin1.weight.grad
|
||||
shard_grad = sharded_sam.vision_encoder.layers[0].mlp.lin1.weight.grad
|
||||
shard_weight = sharded_sam.vision_encoder.layers[0].mlp.lin1.weight
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
def run_sam_test(enable_fused_normalization, enable_tensor_parallelism):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_sam')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_sam(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_sam_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_sam():
|
||||
spawn(check_sam, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_sam()
|
Reference in New Issue
Block a user