Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 20:23:41 +00:00)

[autoparallel] complete gpt related module search (#2097)

commit 3af7e65dea, parent 85efb7ac2e
@@ -64,20 +64,14 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
     last_physical_output_dims = output_op_data.data.dim() - 1

     if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
-        update_partition_dim(
-            sharding_spec=input_sharding_spec,
-            dim_mapping={last_logical_input_dims: last_physical_input_dims},
-            physical_shape=input_op_data.data.shape,
-            inplace=True,
-        )
+        input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
+    else:
+        input_last_dim_mapping = {}

     if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
-        update_partition_dim(
-            sharding_spec=output_sharding_spec,
-            dim_mapping={last_logical_output_dims: last_physical_output_dims},
-            physical_shape=output_op_data.data.shape,
-            inplace=True,
-        )
+        output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
+    else:
+        output_last_dim_mapping = {}

     # get logger for debug message
     logger = get_dist_logger()

@@ -97,12 +91,18 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
             output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
             try:
                 # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
+                input_dim_mapping = {0: i}
+                input_dim_mapping.update(input_last_dim_mapping)
+
                 update_partition_dim(sharding_spec=input_sharding_spec,
-                                     dim_mapping={0: i},
+                                     dim_mapping=input_dim_mapping,
                                      physical_shape=input_op_data.data.shape,
                                      inplace=True)
+                output_dim_mapping = {0: i}
+                output_dim_mapping.update(output_last_dim_mapping)
+
                 update_partition_dim(sharding_spec=output_sharding_spec,
-                                     dim_mapping={0: i},
+                                     dim_mapping=output_dim_mapping,
                                      physical_shape=output_op_data.data.shape,
                                      inplace=True)
                 strategy_copy.name = f'{strategy.name}_{i}'

@@ -120,12 +120,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
         output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)

         # after updating, the logical shape will be replaced by the physical shape
+        input_dim_mapping = {}
+        input_dim_mapping.update(input_last_dim_mapping)
         update_partition_dim(sharding_spec=input_sharding_spec,
-                             dim_mapping={},
+                             dim_mapping=input_dim_mapping,
                              physical_shape=input_op_data.data.shape,
                              inplace=True)

+        output_dim_mapping = {}
+        output_dim_mapping.update(output_last_dim_mapping)
         update_partition_dim(sharding_spec=output_sharding_spec,
-                             dim_mapping={},
+                             dim_mapping=output_dim_mapping,
                              physical_shape=output_op_data.data.shape,
                              inplace=True)
         sharding_strategies.append(strategy_copy)
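
The refactor above stops applying the last-dimension update eagerly inside each branch; it records the mapping in input_last_dim_mapping / output_last_dim_mapping and merges it into the per-strategy dim_mapping, so update_partition_dim runs once per sharding spec with the complete mapping. A minimal sketch of the dim-mapping arithmetic, in plain torch rather than the handler's API (variable names below are illustrative only):

    import torch

    # A linear layer sees a 4-D input logically as a 2-D matrix, so logical dim 0
    # can land on any leading physical dim i while the logical last dim always
    # maps to the physical last dim; the handler enumerates one mapping per i.
    physical = torch.rand(4, 4, 4, 16)
    logical = physical.view(-1, physical.shape[-1])               # torch.Size([64, 16])

    last_dim_mapping = {logical.dim() - 1: physical.dim() - 1}    # {1: 3}
    for i in range(physical.dim() - 1):                           # candidate physical batch dims
        dim_mapping = {0: i}                                      # logical dim 0 -> physical dim i
        dim_mapping.update(last_dim_mapping)
        print(dim_mapping)                                        # {0: 0, 1: 3}, {0: 1, 1: 3}, {0: 2, 1: 3}
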
@@ -26,17 +26,20 @@ from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


-def check_linear_module_handler(rank, bias, world_size, port):
+def check_linear_module_handler(rank, bias, input_shape, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    input = torch.rand(4, 4, 4, 16).cuda()
+    input = torch.rand(input_shape).cuda()
     # the index of linear node in computation graph
     node_index = 1
     # strategy number of linear node
-    strategy_number = 24
+    if input_shape == (1, 4, 4, 16):
+        strategy_number = 19
+    else:
+        strategy_number = 24
     # construct input args
     input_args = [input]
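
The handler test is now parameterized over input_shape; for (1, 4, 4, 16) the expected strategy count drops from 24 to 19 because, as the new comment in the assertions further down puts it, the first dimension cannot be sharded when its size is 1, which removes the five `_0` strategy variants. A small illustrative check of that constraint (can_shard is a hypothetical helper, not part of ColossalAI, and assumes even divisibility is required):

    import torch

    # A dimension can only be split across a mesh axis if it is at least as large
    # as the axis and divisible by it (assumed rule for this illustration).
    def can_shard(dim_size: int, mesh_axis_size: int) -> bool:
        return dim_size >= mesh_axis_size and dim_size % mesh_axis_size == 0

    x = torch.rand(1, 4, 4, 16)
    print(can_shard(x.shape[0], 2))    # False: physical dim 0 of size 1 cannot be split
    print(can_shard(x.shape[1], 2))    # True: the size-4 dims can still be split across 2 devices
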
@@ -50,7 +53,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
                                      meta_arg_names=meta_arg_names)

     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')})
     gm = ColoGraphModule(model, graph)

     linear_mod_node = list(graph.nodes)[1]
@@ -69,9 +72,10 @@ def check_linear_module_handler(rank, bias, world_size, port):
         assert op_data.data is not None

     assert mapping['input'].name == "input_1"
-    assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
+    assert mapping['input'].data.shape == torch.Size(input_shape)
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([64, 16])
+    input_logical_shape = mapping['input'].data.view(-1, 16).shape
+    assert mapping['input'].logical_shape == input_logical_shape

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -85,28 +89,32 @@ def check_linear_module_handler(rank, bias, world_size, port):
         assert mapping['bias'].logical_shape == torch.Size([32])

     assert mapping['output'].name == "_0"
-    assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
+    output_shape = input_shape[:-1] + (32,)
+    assert mapping['output'].data.shape == torch.Size(output_shape)
     assert mapping['output'].type == OperationDataType.OUTPUT
-    assert mapping['output'].logical_shape == torch.Size([64, 32])
+    output_logical_shape = mapping['output'].data.view(-1, 32).shape
+    assert mapping['output'].logical_shape == torch.Size(output_logical_shape)

     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
-    # one strategy will be converted to different physical sharding spec
-    assert len(strategy_name_list) > 8
+    # First dimension cannot be sharded if input shape is (1, 4, 4, 16)
+    if input_shape != (1, 4, 4, 16):
+        assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+        assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+        assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+        assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+        assert 'S01R = S01R x RR_0' in strategy_name_list

     # SS = SR x RS
-    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
     assert 'S0S1 = S0R x RS1_1' in strategy_name_list
     assert 'S0S1 = S0R x RS1_2' in strategy_name_list
-    assert 'S1S0 = S1R x RS0_0' in strategy_name_list
     assert 'S1S0 = S1R x RS0_1' in strategy_name_list
     assert 'S1S0 = S1R x RS0_2' in strategy_name_list

     # SR = SS x SR
-    assert 'S0R = S0S1 x S1R_0' in strategy_name_list
     assert 'S0R = S0S1 x S1R_1' in strategy_name_list
     assert 'S0R = S0S1 x S1R_2' in strategy_name_list
-    assert 'S1R = S1S0 x S0R_0' in strategy_name_list
     assert 'S1R = S1S0 x S0R_1' in strategy_name_list
     assert 'S1R = S1S0 x S0R_2' in strategy_name_list

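
The shape assertions above no longer hard-code [64, 16] and [64, 32]; they derive the logical shapes from the data itself, since a linear op flattens all leading dimensions into a single logical batch dimension. A quick standalone check of that convention:

    import torch

    # view(-1, 16) reproduces the logical input shape used by the assertions above;
    # the physical output shape just swaps the last dim for the out_features (32).
    for shape in [(1, 4, 4, 16), (4, 4, 4, 16)]:
        x = torch.rand(shape)
        print(shape, '->', tuple(x.view(-1, 16).shape), '| output:', shape[:-1] + (32,))
    # (1, 4, 4, 16) -> (16, 16) | output: (1, 4, 4, 32)
    # (4, 4, 4, 16) -> (64, 16) | output: (4, 4, 4, 32)
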
@@ -123,7 +131,6 @@ def check_linear_module_handler(rank, bias, world_size, port):
     assert 'RS1 = RR x RS1' in strategy_name_list

     # S01R = S01R x RR
-    assert 'S01R = S01R x RR_0' in strategy_name_list
     assert 'S01R = S01R x RR_1' in strategy_name_list
     assert 'S01R = S01R x RR_2' in strategy_name_list

@@ -164,7 +171,7 @@ class LinearModel(nn.Module):
         return x


-def check_linear_function_handler(rank, bias, world_size, port):
+def check_linear_function_handler(rank, bias, input_shape, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = LinearModel().cuda()
@@ -172,11 +179,14 @@ def check_linear_function_handler(rank, bias, world_size, port):
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

-    input = torch.rand(4, 4, 4, 16).cuda()
+    input = torch.rand(input_shape).cuda()
     other = torch.rand(32, 16).cuda()
     # the index of linear node in computation graph
     node_index = 2
     # strategy number of linear node
-    strategy_number = 24
+    if input_shape == (1, 4, 4, 16):
+        strategy_number = 19
+    else:
+        strategy_number = 24
     # construct input args
     input_args = [input, other]
@@ -192,7 +202,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
     tracer = ColoTracer()
     graph = tracer.trace(model,
                          meta_args={
-                             "input": torch.rand(4, 4, 4, 16).to('meta'),
+                             "input": torch.rand(input_shape).to('meta'),
                              'others': torch.rand(32, 16).to('meta')
                          })
     gm = ColoGraphModule(model, graph)
@@ -209,9 +219,10 @@ def check_linear_function_handler(rank, bias, world_size, port):
     mapping = handler.get_operation_data_mapping()

     assert mapping['input'].name == "input_1"
-    assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
+    assert mapping['input'].data.shape == torch.Size(input_shape)
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([64, 16])
+    input_logical_shape = mapping['input'].data.view(-1, 16).shape
+    assert mapping['input'].logical_shape == torch.Size(input_logical_shape)

     assert mapping['other'].name == "others"
     assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -225,27 +236,32 @@ def check_linear_function_handler(rank, bias, world_size, port):
     assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['output'].name == "linear"
-    assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
+    output_shape = input_shape[:-1] + (32,)
+    assert mapping['output'].data.shape == torch.Size(output_shape)
     assert mapping['output'].type == OperationDataType.OUTPUT
+    output_logical_shape = mapping['output'].data.view(-1, 32).shape
+    assert mapping['output'].logical_shape == torch.Size(output_logical_shape)

     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
-    # one strategy will be converted to different physical sharding spec
-    assert len(strategy_name_list) > 8
+    # First dimension cannot be sharded if input shape is (1, 4, 4, 16)
+    if input_shape != (1, 4, 4, 16):
+        assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+        assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+        assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+        assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+        assert 'S01R = S01R x RR_0' in strategy_name_list

     # SS = SR x RS
-    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
     assert 'S0S1 = S0R x RS1_1' in strategy_name_list
     assert 'S0S1 = S0R x RS1_2' in strategy_name_list
-    assert 'S1S0 = S1R x RS0_0' in strategy_name_list
     assert 'S1S0 = S1R x RS0_1' in strategy_name_list
     assert 'S1S0 = S1R x RS0_2' in strategy_name_list

     # SR = SS x SR
-    assert 'S0R = S0S1 x S1R_0' in strategy_name_list
     assert 'S0R = S0S1 x S1R_1' in strategy_name_list
     assert 'S0R = S0S1 x S1R_2' in strategy_name_list
-    assert 'S1R = S1S0 x S0R_0' in strategy_name_list
     assert 'S1R = S1S0 x S0R_1' in strategy_name_list
     assert 'S1R = S1S0 x S0R_2' in strategy_name_list

@@ -262,7 +278,6 @@ def check_linear_function_handler(rank, bias, world_size, port):
     assert 'RS1 = RR x RS1' in strategy_name_list

     # S01R = S01R x RR
-    assert 'S01R = S01R x RR_0' in strategy_name_list
     assert 'S01R = S01R x RR_1' in strategy_name_list
     assert 'S01R = S01R x RR_2' in strategy_name_list

@@ -293,15 +308,23 @@ def check_linear_function_handler(rank, bias, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-# @parameterize('bias', [True, False])
+@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_linear_handler(bias=False):
+def test_linear_handler(input_shape, bias=False):
     world_size = 4
-    run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
+    run_func_module = partial(check_linear_module_handler,
+                              bias=bias,
+                              input_shape=input_shape,
+                              world_size=world_size,
+                              port=free_port())
     mp.spawn(run_func_module, nprocs=world_size)
-    run_func_function = partial(check_linear_function_handler, bias=bias, world_size=world_size, port=free_port())
+    run_func_function = partial(check_linear_function_handler,
+                                bias=bias,
+                                input_shape=input_shape,
+                                world_size=world_size,
+                                port=free_port())
     mp.spawn(run_func_function, nprocs=world_size)

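
The rewritten test_linear_handler binds everything except the rank with functools.partial, so mp.spawn only supplies the per-process index. A self-contained sketch of that launch pattern (the worker below is a stand-in for the real check functions, and the port is an arbitrary example value):

    from functools import partial

    import torch.multiprocessing as mp


    # mp.spawn calls the function with the process index as the first positional
    # argument; all remaining arguments are pre-bound via functools.partial.
    def worker(rank, bias, input_shape, world_size, port):
        print(f"rank {rank}/{world_size}: bias={bias}, input_shape={input_shape}, port={port}")


    if __name__ == '__main__':
        run_func = partial(worker, bias=False, input_shape=(4, 4, 4, 16), world_size=4, port=29500)
        mp.spawn(run_func, nprocs=4)
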
@@ -1,11 +1,14 @@
 from typing import Optional, Tuple, Union

 import torch
-# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 import torch.nn as nn
 import transformers
 from torch.fx import GraphModule
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+from transformers.models.gpt2.modeling_gpt2 import (
+    GPT2MLP,
+    BaseModelOutputWithPastAndCrossAttentions,
+    GPT2PreTrainedModel,
+)
 from transformers.pytorch_utils import Conv1D

 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
@@ -173,8 +176,91 @@ class GPT2Block(nn.Module):
         return outputs    # hidden_states, present, (attentions, cross_attentions)


+class GPT2Model(GPT2PreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embed_dim = config.hidden_size
+
+        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+        self.drop = nn.Dropout(config.embd_pdrop)
+        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        batch_size = input_ids.shape[0]
+
+        device = input_ids.device
+
+        token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+        past_length = 0
+        past_key_values = tuple([None] * len(self.h))
+
+        position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # GPT2Attention mask.
+        attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = attention_mask[:, None, None, :]
+        attention_mask = attention_mask.to(dtype=self.dtype)    # fp16 compatibility
+        attention_mask = (1.0 - attention_mask) * -10000.0
+
+        encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        # add_2
+        hidden_states = inputs_embeds + position_embeds
+
+        token_type_embeds = self.wte(token_type_ids)
+        hidden_states = hidden_states + token_type_embeds
+
+        # transformer_drop
+        hidden_states = self.drop(hidden_states)
+        # comment to run pipeline
+        # add_3
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = None
+        all_self_attentions = None
+        all_cross_attentions = None
+        all_hidden_states = None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
+            hidden_states = outputs[0]
+
+        hidden_states = self.ln_f(hidden_states)
+        # comment to run pipeline
+        hidden_states = hidden_states.view(output_shape)
+
+        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                     if v is not None)
+
+
 @run_on_environment_flag(name='AUTO_PARALLEL')
-@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP])
+@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
 def test_self_attention_block(model_cls):
     config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
     if model_cls == GPT2MLP:

@@ -193,11 +279,17 @@ def test_self_attention_block(model_cls):
         input_sample = {
             'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
         }
-    else:
+    elif model_cls in (GPT2Attention, GPT2Block):
         input_sample = {
             'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
             'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
         }
+    else:
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        input_sample = {k: v.to('meta') for k, v in kwargs.items()}

     graph = tracer.trace(root=model, meta_args=input_sample)
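
For the new GPT2Model branch, the test builds integer tensors for input_ids, token_type_ids and attention_mask and moves them to the meta device, since symbolic tracing only needs shapes and dtypes. A self-contained illustration of that construction (BATCH_SIZE and SEQ_LENGTH are assumed values here, standing in for the constants defined elsewhere in the test file):

    import torch

    BATCH_SIZE, SEQ_LENGTH = 1, 32    # assumed values for illustration only

    # Zero-valued int64 tensors carry shape and dtype information without data;
    # moving them to the 'meta' device matches what the tracer expects.
    kwargs = dict(input_ids=torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64),
                  token_type_ids=torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64),
                  attention_mask=torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64))
    input_sample = {k: v.to('meta') for k, v in kwargs.items()}

    print({k: (tuple(v.shape), str(v.dtype), v.device.type) for k, v in input_sample.items()})
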