mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 21:22:49 +00:00
[pipeline] add pipeline support for T5Stack/T5EncoderModel (#4300)
* modify t5 policy & add test * pipeline stage distribution for t5 * complete t5 base policy * t5 stack: halfway * modify gpt2 pipeline test * complete pipeline forward for T5Stack/T5EncoderModel * fix docstring * move t5 util tests to test_pipeline
This commit is contained in:
committed by
Hongxin Liu
parent
18ebcf406a
commit
36e546b2cc
@@ -62,17 +62,15 @@ output_transform_fn = lambda x: x
|
||||
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn = lambda x: x.loss
|
||||
|
||||
config = transformers.GPT2Config(
|
||||
n_layer=2,
|
||||
n_head=4,
|
||||
#n_embd=128,
|
||||
vocab_size=50258,
|
||||
attn_pdrop=0,
|
||||
embd_pdrop=0,
|
||||
resid_pdrop=0,
|
||||
summary_first_dropout=0,
|
||||
hidden_dropout=0,
|
||||
problem_type="single_label_classification")
|
||||
config = transformers.GPT2Config(n_layer=2,
|
||||
n_head=4,
|
||||
vocab_size=50258,
|
||||
attn_pdrop=0,
|
||||
embd_pdrop=0,
|
||||
resid_pdrop=0,
|
||||
summary_first_dropout=0,
|
||||
hidden_dropout=0,
|
||||
problem_type="single_label_classification")
|
||||
|
||||
# register the following models
|
||||
model_zoo.register(name='transformers_gpt',
|
||||
|
39
tests/test_pipeline/test_policy/test_t5_pipeline_utils.py
Normal file
39
tests/test_pipeline/test_policy/test_t5_pipeline_utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from colossalai.shardformer.policies.t5 import T5BasePolicy
|
||||
|
||||
|
||||
def test_t5_pipeline_distribution():
|
||||
num_test_cases = 8
|
||||
test_dict = {
|
||||
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
|
||||
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
|
||||
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
|
||||
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
|
||||
}
|
||||
|
||||
for i in range(num_test_cases):
|
||||
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
|
||||
test_dict['num_decoder_layers'][i],
|
||||
test_dict['num_stages'][i])
|
||||
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
|
||||
|
||||
|
||||
def test_t5_pipeline_layers():
|
||||
num_test_cases = 4
|
||||
test_dict = {
|
||||
'num_encoder_layers': [2, 3, 2, 4],
|
||||
'num_decoder_layers': [2, 0, 2, 8],
|
||||
'num_stages': [2, 2, 4, 4],
|
||||
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
|
||||
[[0, 4], [0, 3], [3, 6], [6, 8]]]
|
||||
}
|
||||
|
||||
for i in range(num_test_cases):
|
||||
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
|
||||
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
|
||||
|
||||
for stage in range(test_dict['num_stages'][i]):
|
||||
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
|
||||
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
|
||||
decoder_starting_stage)
|
||||
assert start_idx == predicted_start
|
||||
assert end_idx == predicted_end
|
@@ -29,9 +29,11 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
input_ids = inputs['input_ids']
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = 768
|
||||
hidden_size = sharded_model.config.n_embd
|
||||
hidden_state_shape = (batch_size, seq_len, hidden_size)
|
||||
|
||||
if not stage_manager.is_first_stage():
|
||||
@@ -40,12 +42,12 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
sharded_model.train()
|
||||
output = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
if name != 'transformers_gpt':
|
||||
if name == 'transformers_gpt':
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
assert output.loss is not None
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape
|
||||
|
96
tests/test_shardformer/test_model/test_shard_t5_pipeline.py
Normal file
96
tests/test_shardformer/test_model/test_shard_t5_pipeline.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.t5 import T5BasePolicy
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_pipeline_model
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# TODO: add tests for forward/backward later
|
||||
pass
|
||||
|
||||
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_t5.py
|
||||
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
if name != 'transformers_t5_encoder_model':
|
||||
continue
|
||||
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
input_ids = inputs['input_ids']
|
||||
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = sharded_model.config.d_model
|
||||
num_heads = sharded_model.config.num_heads
|
||||
hidden_state_shape = (batch_size, seq_len, hidden_size)
|
||||
position_bias_shape = (batch_size, num_heads, seq_len, seq_len)
|
||||
|
||||
num_encoder_layers = len(sharded_model.encoder.block)
|
||||
decoder = sharded_model.__dict__.get('decoder', None)
|
||||
num_decoder_layers = len(decoder.block) if decoder else 0
|
||||
|
||||
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE)
|
||||
stage = stage_manager.stage
|
||||
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
|
||||
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
|
||||
|
||||
if not at_first_stage:
|
||||
# change inputs if not the first stage
|
||||
hidden_states = torch.zeros(*hidden_state_shape).cuda()
|
||||
position_bias = torch.zeros(*position_bias_shape).cuda()
|
||||
encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda()
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
inputs['position_bias'] = position_bias
|
||||
inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
|
||||
|
||||
sharded_model.train()
|
||||
output = sharded_model(**inputs)
|
||||
if at_last_stage:
|
||||
if name != 'transformers_t5_for_conditional_generation':
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
assert output.loss is not None
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape
|
||||
# position_bias information should be passed in T5
|
||||
assert 'position_bias' in output
|
||||
assert 'encoder_decoder_position_bias' in output
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_t5(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_t5_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_t5():
|
||||
spawn(check_t5, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_t5()
|
Reference in New Issue
Block a user