[pipeline] add pipeline support for T5Stack/T5EncoderModel (#4300)

* modify t5 policy & add test

* pipeline stage distribution for t5

* complete t5 base policy

* t5 stack: halfway

* modify gpt2 pipeline test

* complete pipeline forward for T5Stack/T5EncoderModel

* fix docstring

* move t5 util tests to test_pipeline

Author: Baizhou Zhang
Date: 2023-07-21 16:23:04 +08:00
Committed by: Hongxin Liu
Parent: 18ebcf406a
Commit: 36e546b2cc
6 changed files with 604 additions and 21 deletions


@@ -62,17 +62,15 @@ output_transform_fn = lambda x: x
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
loss_fn = lambda x: x.loss
-config = transformers.GPT2Config(
-n_layer=2,
-n_head=4,
-#n_embd=128,
-vocab_size=50258,
-attn_pdrop=0,
-embd_pdrop=0,
-resid_pdrop=0,
-summary_first_dropout=0,
-hidden_dropout=0,
-problem_type="single_label_classification")
+config = transformers.GPT2Config(n_layer=2,
+n_head=4,
+vocab_size=50258,
+attn_pdrop=0,
+embd_pdrop=0,
+resid_pdrop=0,
+summary_first_dropout=0,
+hidden_dropout=0,
+problem_type="single_label_classification")
# register the following models
model_zoo.register(name='transformers_gpt',
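
The rewritten config above also drops the commented-out n_embd=128 line, so the registered GPT-2 keeps GPT2Config's default hidden size of 768 both before and after this change; the pipeline test further down relies on that when it swaps the hardcoded 768 for sharded_model.config.n_embd. A quick standalone check of the default (plain Hugging Face, nothing ColossalAI-specific):

import transformers

# With no n_embd override, GPT2Config falls back to its default hidden size,
# which is the value the pipeline test reads back via config.n_embd.
cfg = transformers.GPT2Config(n_layer=2, n_head=4, vocab_size=50258)
print(cfg.n_embd)  # 768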


@@ -0,0 +1,39 @@
from colossalai.shardformer.policies.t5 import T5BasePolicy
def test_t5_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_t5_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end
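
The two tests above pin down the contract of T5BasePolicy.distribute_t5_layers and T5BasePolicy.get_t5_stage_index: encoder and decoder blocks are spread over the pipeline, decoder_starting_stage marks where the decoder's stages begin, and each stage's (start, end) indices are local to the encoder or decoder stack. The sketch below is only an illustration that reproduces the expected numbers in these tests; the helper names (split_evenly, distribute_t5_layers_sketch, stage_index_sketch) and the exact balancing rule are assumptions, not the actual ColossalAI implementation.

import math
from typing import List, Tuple

def split_evenly(num_layers: int, num_stages: int) -> List[int]:
    # Even split; the remainder is handed out around the middle stages, which
    # is what puts the extra encoder layer on the later stage in [[0, 1], [1, 3]].
    counts = [num_layers // num_stages] * num_stages
    remainder = num_layers % num_stages
    start = num_stages // 2 - remainder // 2
    for i in range(start, start + remainder):
        counts[i] += 1
    return counts

def distribute_t5_layers_sketch(num_encoder_layers: int, num_decoder_layers: int,
                                num_stages: int) -> Tuple[List[int], int]:
    if num_decoder_layers == 0:
        # Encoder-only models (T5EncoderModel) occupy every stage.
        return split_evenly(num_encoder_layers, num_stages), num_stages

    def heaviest_stage(encoder_stages: int) -> int:
        # Largest per-stage layer count if the encoder gets `encoder_stages` stages.
        return max(math.ceil(num_encoder_layers / encoder_stages),
                   math.ceil(num_decoder_layers / (num_stages - encoder_stages)))

    # Give the encoder the stage count that keeps the heaviest stage smallest;
    # ties go to the smaller encoder share.
    decoder_starting_stage = min(range(1, num_stages), key=heaviest_stage)
    counts = (split_evenly(num_encoder_layers, decoder_starting_stage)
              + split_evenly(num_decoder_layers, num_stages - decoder_starting_stage))
    return counts, decoder_starting_stage

def stage_index_sketch(counts: List[int], stage: int, decoder_starting_stage: int) -> Tuple[int, int]:
    # Indices restart at the decoder's first stage, matching the expected layouts above.
    base = 0 if stage < decoder_starting_stage else decoder_starting_stage
    start = sum(counts[base:stage])
    return start, start + counts[stage]

# Matches the expectations above, e.g. 4 encoder / 8 decoder layers on 4 stages:
counts, dec_start = distribute_t5_layers_sketch(4, 8, 4)
assert dec_start == 1
assert [stage_index_sketch(counts, s, dec_start) for s in range(4)] == [(0, 4), (0, 3), (3, 6), (6, 8)]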


@@ -29,9 +29,11 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
-input_ids, _ = inputs['input_ids'], inputs['attention_mask']
+_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+enable_tensor_parallelism, use_lazy_init)
+input_ids = inputs['input_ids']
batch_size, seq_len = input_ids.shape
-hidden_size = 768
+hidden_size = sharded_model.config.n_embd
hidden_state_shape = (batch_size, seq_len, hidden_size)
if not stage_manager.is_first_stage():
@@ -40,12 +42,12 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
-_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
-if name != 'transformers_gpt':
+if name == 'transformers_gpt':
assert output[0].shape == hidden_state_shape
else:
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape
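
The flipped name check above matters because of how the registered GPT-2 variants behave: the shape assertion only fits the bare GPT2Model (which 'transformers_gpt' presumably maps to), whose final stage emits hidden states, whereas the head models return a loss. A standalone illustration with plain Hugging Face models (not part of the commit):

import torch
import transformers

cfg = transformers.GPT2Config(n_layer=2, n_head=4, vocab_size=50258)
ids = torch.randint(0, cfg.vocab_size, (2, 8))

# Bare GPT2Model: the first output is the (batch, seq, n_embd) hidden-state tensor,
# which is what the 'transformers_gpt' branch checks against hidden_state_shape.
hidden = transformers.GPT2Model(cfg)(ids)[0]
assert hidden.shape == (2, 8, cfg.n_embd)

# Head models such as GPT2LMHeadModel produce a loss once labels are supplied,
# which is what the other branch asserts on.
out = transformers.GPT2LMHeadModel(cfg)(ids, labels=ids)
assert out.loss is not None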


@@ -0,0 +1,96 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.t5 import T5BasePolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# TODO: add tests for forward/backward later
pass
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_t5.py
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
if name != 'transformers_t5_encoder_model':
continue
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids = inputs['input_ids']
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
batch_size, seq_len = input_ids.shape
hidden_size = sharded_model.config.d_model
num_heads = sharded_model.config.num_heads
hidden_state_shape = (batch_size, seq_len, hidden_size)
position_bias_shape = (batch_size, num_heads, seq_len, seq_len)
num_encoder_layers = len(sharded_model.encoder.block)
decoder = sharded_model.__dict__.get('decoder', None)
num_decoder_layers = len(decoder.block) if decoder else 0
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE)
stage = stage_manager.stage
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
if not at_first_stage:
# change inputs if not the first stage
hidden_states = torch.zeros(*hidden_state_shape).cuda()
position_bias = torch.zeros(*position_bias_shape).cuda()
encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
inputs['position_bias'] = position_bias
inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
sharded_model.train()
output = sharded_model(**inputs)
if at_last_stage:
if name != 'transformers_t5_for_conditional_generation':
assert output[0].shape == hidden_state_shape
else:
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape
# position_bias information should be passed in T5
assert 'position_bias' in output
assert 'encoder_decoder_position_bias' in output
torch.cuda.empty_cache()
def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_t5_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_t5():
spawn(check_t5, 4)
if __name__ == "__main__":
test_t5()
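
One detail behind the position_bias assertions in the test above: in the Hugging Face T5 implementation, only the first block of a T5Stack owns the relative-attention-bias table, and every later block reuses the bias tensor computed by block 0. Splitting the stack across pipeline stages therefore has to ship position_bias (and encoder_decoder_position_bias for the decoder's cross-attention) from stage to stage, which is exactly what the inter-stage outputs are checked for. A quick standalone check (not part of the commit):

import transformers
from transformers.models.t5.modeling_t5 import T5Stack

cfg = transformers.T5Config(num_layers=4, d_model=64, d_ff=128, num_heads=4, vocab_size=128)
stack = T5Stack(cfg)

# Only block 0 holds the relative-attention-bias embedding; the remaining blocks
# expect the computed (batch, num_heads, seq, seq) bias to be passed in, which is
# why the pipeline forward keeps position_bias in the outputs handed to the next stage.
flags = [block.layer[0].SelfAttention.has_relative_attention_bias for block in stack.block]
assert flags == [True, False, False, False]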