[autoparallel] complete gpt related module search (#2097)

This commit is contained in:
YuliangLiu0306
2022-12-08 10:04:09 +08:00
committed by GitHub
parent 85efb7ac2e
commit 3af7e65dea
3 changed files with 173 additions and 53 deletions

View File

@@ -26,18 +26,21 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_linear_module_handler(rank, bias, world_size, port):
def check_linear_module_handler(rank, bias, input_shape, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(4, 4, 4, 16).cuda()
input = torch.rand(input_shape).cuda()
# the index of linear node in computation graph
node_index = 1
# strategy number of linear node
strategy_number = 24
if input_shape == (1, 4, 4, 16):
strategy_number = 19
else:
strategy_number = 24
# construct input args
input_args = [input]
# construct meta arg names
@@ -50,7 +53,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')})
graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')})
gm = ColoGraphModule(model, graph)
linear_mod_node = list(graph.nodes)[1]
@@ -69,9 +72,10 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
assert mapping['input'].data.shape == torch.Size(input_shape)
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([64, 16])
input_logical_shape = mapping['input'].data.view(-1, 16).shape
assert mapping['input'].logical_shape == input_logical_shape
assert mapping['other'].name == "weight"
assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -85,28 +89,32 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert mapping['bias'].logical_shape == torch.Size([32])
assert mapping['output'].name == "_0"
assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
output_shape = input_shape[:-1] + (32,)
assert mapping['output'].data.shape == torch.Size(output_shape)
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size([64, 32])
output_logical_shape = mapping['output'].data.view(-1, 32).shape
assert mapping['output'].logical_shape == torch.Size(output_logical_shape)
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# one strategy will be converted to different physical sharding spec
assert len(strategy_name_list) > 8
# First dimension cannot be shard if input shape is (1, 4, 4, 16)
if input_shape != (1, 4, 4, 16):
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S01R = S01R x RR_0' in strategy_name_list
# SS = SR x RS
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
assert 'S1S0 = S1R x RS0_2' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
assert 'S1R = S1S0 x S0R_2' in strategy_name_list
@@ -123,7 +131,6 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert 'RS1 = RR x RS1' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01R x RR_0' in strategy_name_list
assert 'S01R = S01R x RR_1' in strategy_name_list
assert 'S01R = S01R x RR_2' in strategy_name_list
@@ -164,7 +171,7 @@ class LinearModel(nn.Module):
return x
def check_linear_function_handler(rank, bias, world_size, port):
def check_linear_function_handler(rank, bias, input_shape, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModel().cuda()
@@ -172,12 +179,15 @@ def check_linear_function_handler(rank, bias, world_size, port):
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(4, 4, 4, 16).cuda()
input = torch.rand(input_shape).cuda()
other = torch.rand(32, 16).cuda()
# the index of linear node in computation graph
node_index = 2
# strategy number of linear node
strategy_number = 24
if input_shape == (1, 4, 4, 16):
strategy_number = 19
else:
strategy_number = 24
# construct input args
input_args = [input, other]
# construct meta arg names
@@ -192,7 +202,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 4, 16).to('meta'),
"input": torch.rand(input_shape).to('meta'),
'others': torch.rand(32, 16).to('meta')
})
gm = ColoGraphModule(model, graph)
@@ -209,9 +219,10 @@ def check_linear_function_handler(rank, bias, world_size, port):
mapping = handler.get_operation_data_mapping()
assert mapping['input'].name == "input_1"
assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
assert mapping['input'].data.shape == torch.Size(input_shape)
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([64, 16])
input_logical_shape = mapping['input'].data.view(-1, 16).shape
assert mapping['input'].logical_shape == torch.Size(input_logical_shape)
assert mapping['other'].name == "others"
assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -225,27 +236,32 @@ def check_linear_function_handler(rank, bias, world_size, port):
assert mapping['other'].logical_shape == torch.Size([16, 32])
assert mapping['output'].name == "linear"
assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
output_shape = input_shape[:-1] + (32,)
assert mapping['output'].data.shape == torch.Size(output_shape)
assert mapping['output'].type == OperationDataType.OUTPUT
output_logical_shape = mapping['output'].data.view(-1, 32).shape
assert mapping['output'].logical_shape == torch.Size(output_logical_shape)
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# one strategy will be converted to different physical sharding spec
assert len(strategy_name_list) > 8
# First dimension cannot be shard if input shape is (1, 4, 4, 16)
if input_shape != (1, 4, 4, 16):
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S01R = S01R x RR_0' in strategy_name_list
# SS = SR x RS
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
assert 'S1S0 = S1R x RS0_2' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
assert 'S1R = S1S0 x S0R_2' in strategy_name_list
@@ -262,7 +278,6 @@ def check_linear_function_handler(rank, bias, world_size, port):
assert 'RS1 = RR x RS1' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01R x RR_0' in strategy_name_list
assert 'S01R = S01R x RR_1' in strategy_name_list
assert 'S01R = S01R x RR_2' in strategy_name_list
@@ -293,15 +308,23 @@ def check_linear_function_handler(rank, bias, world_size, port):
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
# @parameterize('bias', [True, False])
@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=False):
def test_linear_handler(input_shape, bias=False):
world_size = 4
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
run_func_module = partial(check_linear_module_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
run_func_function = partial(check_linear_function_handler, bias=bias, world_size=world_size, port=free_port())
run_func_function = partial(check_linear_function_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_function, nprocs=world_size)

View File

@@ -1,11 +1,14 @@
from typing import Optional, Tuple, Union
import torch
# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
import torch.nn as nn
import transformers
from torch.fx import GraphModule
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from transformers.models.gpt2.modeling_gpt2 import (
GPT2MLP,
BaseModelOutputWithPastAndCrossAttentions,
GPT2PreTrainedModel,
)
from transformers.pytorch_utils import Conv1D
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
@@ -173,8 +176,91 @@ class GPT2Block(nn.Module):
return outputs # hidden_states, present, (attentions, cross_attentions)
class GPT2Model(GPT2PreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
super().__init__(config)
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
device = input_ids.device
token_type_ids = token_type_ids.view(-1, input_shape[-1])
past_length = 0
past_key_values = tuple([None] * len(self.h))
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# GPT2Attention mask.
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
# add_2
hidden_states = inputs_embeds + position_embeds
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
# transformer_drop
hidden_states = self.drop(hidden_states)
# comment to run pipeline
# add_3
output_shape = input_shape + (hidden_states.size(-1),)
presents = None
all_self_attentions = None
all_cross_attentions = None
all_hidden_states = None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
hidden_states = outputs[0]
hidden_states = self.ln_f(hidden_states)
# comment to run pipeline
hidden_states = hidden_states.view(output_shape)
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None)
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP])
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP:
@@ -193,11 +279,17 @@ def test_self_attention_block(model_cls):
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
}
else:
elif model_cls in (GPT2Attention, GPT2Block):
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
}
else:
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
input_sample = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=input_sample)