[pipeline] add pipeline support for all T5 models (#4310)

* complete policy for T5Model & T5ForConditionalGeneration

* modify function signature in forwards

* add forward for T5model

* add forward for T5ForConditionalGeneration

* fix a bug

* fix hidden_states transporting in decoder

* fix the passing of encoder_outputs
This commit is contained in:
Baizhou Zhang
2023-07-25 14:45:33 +08:00
committed by Hongxin Liu
parent d0807122e2
commit 083d7da33d
3 changed files with 388 additions and 19 deletions

View File

@@ -28,8 +28,6 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
if name != 'transformers_t5_encoder_model':
continue
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
@@ -52,6 +50,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
stage = stage_manager.stage
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
in_decoder = stage >= decoder_starting_stage
if not at_first_stage:
# change inputs if not the first stage
@@ -62,19 +61,25 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
inputs['hidden_states'] = hidden_states
inputs['position_bias'] = position_bias
inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
if in_decoder:
encoder_output_states = torch.zeros(*hidden_state_shape).cuda()
inputs['encoder_outputs'] = (encoder_output_states,)
sharded_model.train()
output = sharded_model(**inputs)
if at_last_stage:
if name != 'transformers_t5_for_conditional_generation':
assert output[0].shape == hidden_state_shape
else:
if name == 'transformers_t5_for_conditional_generation' and in_decoder:
assert output.loss is not None
else:
if name != 'transformers_t5_encoder_model' and not in_decoder:
output = output['encoder_outputs']
assert output[0].shape == hidden_state_shape
else:
assert output['hidden_states'].shape == hidden_state_shape
# position_bias information should be passed in T5
assert 'position_bias' in output
assert 'encoder_decoder_position_bias' in output
assert output['position_bias'].shape == position_bias_shape
if in_decoder:
assert output['encoder_decoder_position_bias'].shape == position_bias_shape
torch.cuda.empty_cache()