mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[pipeline] add pipeline support for all T5 models (#4310)
* complete policy for T5Model & T5ForConditionalGeneration
* modify function signature in forwards
* add forward for T5Model
* add forward for T5ForConditionalGeneration
* fix a bug
* fix hidden_states transporting in decoder
* fix the passing of encoder_outputs
This commit is contained in:
committed by
Hongxin Liu
parent
d0807122e2
commit
083d7da33d
@@ -28,8 +28,6 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
if name != 'transformers_t5_encoder_model':
|
||||
continue
|
||||
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
@@ -52,6 +50,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
||||
stage = stage_manager.stage
|
||||
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
|
||||
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
|
||||
in_decoder = stage >= decoder_starting_stage
|
||||
|
||||
if not at_first_stage:
|
||||
# change inputs if not the first stage
|
||||
@@ -62,19 +61,25 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
||||
inputs['hidden_states'] = hidden_states
|
||||
inputs['position_bias'] = position_bias
|
||||
inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
|
||||
if in_decoder:
|
||||
encoder_output_states = torch.zeros(*hidden_state_shape).cuda()
|
||||
inputs['encoder_outputs'] = (encoder_output_states,)
|
||||
|
||||
sharded_model.train()
|
||||
output = sharded_model(**inputs)
|
||||
if at_last_stage:
|
||||
if name != 'transformers_t5_for_conditional_generation':
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
if name == 'transformers_t5_for_conditional_generation' and in_decoder:
|
||||
assert output.loss is not None
|
||||
else:
|
||||
if name != 'transformers_t5_encoder_model' and not in_decoder:
|
||||
output = output['encoder_outputs']
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape
|
||||
# position_bias information should be passed in T5
|
||||
assert 'position_bias' in output
|
||||
assert 'encoder_decoder_position_bias' in output
|
||||
assert output['position_bias'].shape == position_bias_shape
|
||||
if in_decoder:
|
||||
assert output['encoder_decoder_position_bias'].shape == position_bias_shape
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
Reference in New Issue
Block a user