update some module with new api version

This commit is contained in:
FoolPlayer
2023-08-01 18:02:49 +08:00
committed by Hongxin Liu
parent 879301d0da
commit 726541afe2
7 changed files with 88 additions and 48 deletions

View File

@@ -1,12 +1,15 @@
from contextlib import nullcontext
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# This code is copied from https://github.com/huggingface/transformers
@@ -50,9 +53,13 @@ def rearrange(tensor: torch.Tensor, dim: int):
return rearanged_tensor
def check_gpt2_linear_conv_1d_col():
@parameterize('lazy_init', [False, True])
def check_linear_conv_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
n_fused=3)
@@ -61,6 +68,8 @@ def check_gpt2_linear_conv_1d_col():
assert linear.bias.shape == torch.Size([192])
assert linear_conv_col.weight.shape == torch.Size([48, 96])
assert linear_conv_col.bias.shape == torch.Size([96])
assert linear_copy.weight is linear_conv_col.weight
assert linear_copy.bias is linear_conv_col.bias
# ensure weights are reversibly loadable
linear_conv_col.load_state_dict(linear.state_dict())
@@ -80,13 +89,24 @@ def check_gpt2_linear_conv_1d_col():
assert_close(target_grad, linear_conv_col.weight.grad)
def check_gpt2_linear_conv_1d_row():
@parameterize('lazy_init', [False, True])
def check_linear_conv_1d_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
assert linear_row.bias.shape == torch.Size([192])
assert linear_copy.weight is linear_row.weight
assert linear_copy.bias is linear_row.bias
# ensure weights are reversibly loadable
linear_row.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
@@ -107,14 +127,14 @@ def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# test for linear conv
check_gpt2_linear_conv_1d_col()
check_gpt2_linear_conv_1d_row()
check_linear_conv_1d_col()
check_linear_conv_1d_row()
@rerun_if_address_is_in_use()
def test_gpt2_linearconv():
def test_linearconv():
spawn(run_dist, nprocs=2)
if __name__ == '__main__':
test_gpt2_linearconv()
test_linearconv()

View File

@@ -84,9 +84,10 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
if name == "transformers_chatglm":
sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda()
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy())
else:
sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda()
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy())
sharded_model = sharded_model.cuda()
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()