From 083d7da33d11d0bea17e4f05cdaf102ddb981ea0 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Tue, 25 Jul 2023 14:45:33 +0800
Subject: [PATCH] [pipeline] add pipeline support for all T5 models (#4310)

* complete policy for T5Model & T5ForConditionalGeneration

* modify function signature in forwards

* add forward for T5model

* add forward for T5ForConditionalGeneration

* fix a bug

* fix hidden_states transporting in decoder

* fix the passing of encoder_outputs
---
 colossalai/shardformer/modeling/t5.py         | 324 +++++++++++++++++-
 colossalai/shardformer/policies/t5.py         |  64 +++-
 .../test_model/test_shard_t5_pipeline.py      |  19 +-
 3 files changed, 388 insertions(+), 19 deletions(-)

diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index cc270d582..7eb4d1792 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -1,11 +1,15 @@
-from functools import partial
-from types import MethodType
-from typing import Callable, Dict, List, Optional, Tuple, Union
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss
 from torch.utils.checkpoint import checkpoint
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+)
 from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
 from transformers.utils import logging
 
@@ -198,14 +202,13 @@ class T5PipelineForwards:
             if use_cache is False or use_cache is None:
                 layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
             hidden_states, present_key_value_state = layer_outputs[:2]
-            # print(stage, len(layer_outputs), present_key_value_state.shape)
 
             # We share the position biases between the layers - the first layer store them
             # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
             # (cross-attention position bias), (cross-attention weights)
             position_bias = layer_outputs[2]
-            if self.is_decoder and encoder_hidden_states is not None:
+            if in_decoder and encoder_hidden_states is not None:
                 encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
             # append next layer key value states
             if use_cache:
@@ -238,6 +241,313 @@ class T5PipelineForwards:
             'encoder_decoder_position_bias': encoder_decoder_position_bias
         }
 
+    @staticmethod
+    def t5_model_forward(
+        self: T5Model,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        decoder_inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        position_bias: Optional[torch.Tensor] = None,
+        encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+        stage_index: Optional[List[int]] = None,
+        decoder_starting_stage: Optional[int] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
+
+        # This function is adapted from transformers.models.t5.modeling_t5.T5Model.forward.
+        # Please refer to the original code in transformers for more details.
+
+        __HEAD_MASK_WARNING_MSG = """
+        The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+        `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+        If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+        num_heads)`.
+        """
+
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        logger = logging.get_logger(__name__)
+
+        # TODO: recording of the key/value tensors is left as () or None for now; this feature may be added in the future.
+        if past_key_values:
+            logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+            past_key_values = None
+        if output_attentions:
+            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            output_hidden_states = False
+        if use_cache:
+            logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+            use_cache = False
+
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
+        in_decoder = stage_manager.stage >= decoder_starting_stage
+
+        # If the current stage is an encoder stage, directly return the output of t5_stack_forward.
+        if not in_decoder:
+            encoder_outputs = T5PipelineForwards.t5_stack_forward(
+                self.encoder,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                stage_manager=stage_manager,
+                hidden_states=hidden_states,
+                position_bias=position_bias,
+                encoder_decoder_position_bias=encoder_decoder_position_bias,
+                stage_index=stage_index,
+                decoder_starting_stage=decoder_starting_stage)
+            if stage_manager.stage == decoder_starting_stage - 1:
+                # last stage of encoder
+                return {'encoder_outputs': encoder_outputs}
+            else:
+                return encoder_outputs
+
+        at_last_decoder_stage = stage_manager.is_last_stage()
+        at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
+
+        if encoder_outputs is None:
+            raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
+
+        encoder_hidden_states = encoder_outputs[0]
+        if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # The current stage is a decoder stage; the outputs of the last encoder stage are expected to be passed in.
+        if not at_first_decoder_stage and hidden_states is None:
+            raise ValueError("If not at the first decoder stage, non-empty hidden_states must be provided.")
+
+        # Decode
+        decoder_outputs = T5PipelineForwards.t5_stack_forward(
+            self.decoder,
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            hidden_states=hidden_states,
+            position_bias=position_bias,
+            encoder_decoder_position_bias=encoder_decoder_position_bias,
+            stage_index=stage_index,
+            decoder_starting_stage=decoder_starting_stage)
+
+        # Directly return outputs of the overloaded T5Stack forward if not at the last stage.
+        if not at_last_decoder_stage:
+            decoder_outputs['encoder_outputs'] = encoder_outputs    # encoder_outputs should be passed to the next stage
+            return decoder_outputs
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    @staticmethod
+    def t5_for_conditional_generation_forward(
+        self: T5ForConditionalGeneration,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        position_bias: Optional[torch.Tensor] = None,
+        encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+        stage_index: Optional[List[int]] = None,
+        decoder_starting_stage: Optional[int] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+
+        # This function is adapted from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward.
+        # Please refer to the original code in transformers for more details.
+
+        __HEAD_MASK_WARNING_MSG = """
+        The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+        `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+        If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+        num_heads)`.
+        """
+
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        logger = logging.get_logger(__name__)
+
+        # TODO: recording of the key/value tensors is left as () or None for now; this feature may be added in the future.
+        if past_key_values:
+            logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+            past_key_values = None
+        if output_attentions:
+            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            output_hidden_states = False
+        if use_cache:
+            logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+            use_cache = False
+
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
+        in_decoder = stage_manager.stage >= decoder_starting_stage
+
+        # If the current stage is an encoder stage, directly return the output of t5_stack_forward.
+        if not in_decoder:
+            encoder_outputs = T5PipelineForwards.t5_stack_forward(
+                self.encoder,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                stage_manager=stage_manager,
+                hidden_states=hidden_states,
+                position_bias=position_bias,
+                encoder_decoder_position_bias=encoder_decoder_position_bias,
+                stage_index=stage_index,
+                decoder_starting_stage=decoder_starting_stage)
+            if stage_manager.stage == decoder_starting_stage - 1:
+                # last stage of encoder
+                return {'encoder_outputs': encoder_outputs}
+            else:
+                return encoder_outputs
+
+        at_last_decoder_stage = stage_manager.is_last_stage()
+        at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
+
+        if encoder_outputs is None:
+            raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
+
+        encoder_hidden_states = encoder_outputs[0]
+        if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # The current stage is a decoder stage; the outputs of the last encoder stage are expected to be passed in.
+        if not at_first_decoder_stage and hidden_states is None:
+            raise ValueError("If not at the first decoder stage, non-empty hidden_states must be provided.")
+
+        # Decode
+        decoder_outputs = T5PipelineForwards.t5_stack_forward(
+            self.decoder,
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            hidden_states=hidden_states,
+            position_bias=position_bias,
+            encoder_decoder_position_bias=encoder_decoder_position_bias,
+            stage_index=stage_index,
+            decoder_starting_stage=decoder_starting_stage)
+
+        # Directly return outputs of the overloaded T5Stack forward if not at the last stage.
+        if not at_last_decoder_stage:
+            decoder_outputs['encoder_outputs'] = encoder_outputs    # encoder_outputs should be passed to the next stage
+            return decoder_outputs
+
+        sequence_output = decoder_outputs[0]
+
+        if self.config.tie_word_embeddings:
+            # Rescale output before projecting on vocab
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            sequence_output = sequence_output * (self.model_dim**-0.5)
+
+        lm_logits = self.lm_head(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss(ignore_index=-100)
+            # move labels to the correct device to enable pipeline parallelism
+            labels = labels.to(lm_logits.device)
+            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+            return ((loss,) + output) if loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
     @staticmethod
     def t5_encoder_model_forward(
         self: T5EncoderModel,
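
Note on the loss computed at the last decoder stage above: the snippet below is a standalone sketch, not part of the patch, and only uses PyTorch with arbitrary toy shapes. It shows why CrossEntropyLoss is built with ignore_index=-100 (padded label positions are skipped) and why labels are moved to the logits' device (a no-op on a single device, but required when the LM head lives on a different pipeline rank).

# Standalone sketch of the loss step in t5_for_conditional_generation_forward.
import torch
from torch.nn import CrossEntropyLoss

vocab_size, batch_size, seq_len = 32, 2, 5
lm_logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -2:] = -100    # padded positions are ignored by the loss

loss_fct = CrossEntropyLoss(ignore_index=-100)
labels = labels.to(lm_logits.device)    # no-op on CPU; matters when stages span devices
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
print(loss.item())
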
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 1846c5873..0ee18d6c4 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -293,21 +293,42 @@ class T5BasePolicy(Policy):
 
 class T5ModelPolicy(T5BasePolicy):
 
+    def __init__(self) -> None:
+        super().__init__()
+
     def module_policy(self):
         from transformers import T5Model
 
-        base_policy = super().module_policy()
+        policy = super().module_policy()
 
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                 suffix="shared",
                 target_module=VocabParallelEmbedding1D,
             ),
-                                                        policy=base_policy,
+                                                        policy=policy,
                                                         target_key=T5Model)
 
-        return base_policy
+        if self.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy)
+
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        return super().get_held_layers()
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager is not None and stage_manager.num_stages > 1:
+            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
+                                                                          len(module.decoder.block),
+                                                                          stage_manager.num_stages)
+
+            if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
+                return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
+        return []
 
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
             binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
             for k, v in binding_map.items():
                 src = getattr_(self.model, k)
@@ -318,6 +339,9 @@ class T5ModelPolicy(T5BasePolicy):
 
 class T5ForConditionalGenerationPolicy(T5BasePolicy):
 
+    def __init__(self) -> None:
+        super().__init__()
+
     def module_policy(self):
         from transformers import T5ForConditionalGeneration
 
@@ -335,8 +359,38 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
             ],
                                                         policy=policy,
                                                         target_key=T5ForConditionalGeneration)
+
+        if self.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=T5ForConditionalGeneration,
+                                      new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
+                                      policy=policy)
         return policy
 
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager is not None and stage_manager.num_stages > 1:
+            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
+                                                                          len(module.decoder.block),
+                                                                          stage_manager.num_stages)
+
+            shared_params = []
+            if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
+                shared_params.append({
+                    0: module.shared.weight,
+                    decoder_starting_stage: module.decoder.embed_tokens.weight
+                })
+            if id(module.lm_head.weight) == id(module.shared.weight):
+                shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight})
+            return shared_params
+        return []
+
     def postprocess(self):
         super().postprocess()
         if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
@@ -382,7 +436,7 @@ class T5EncoderPolicy(T5BasePolicy):
         return []
 
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
             binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
             for k, v in binding_map.items():
                 src = getattr_(self.model, k)
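
Why get_shared_params() reports these parameters: with T5's default tie_word_embeddings=True, the encoder embedding, decoder embedding and (for T5ForConditionalGeneration) the LM head all alias shared.weight, so when pipeline parallelism places them on different stages their gradients must be kept in sync; the returned {stage_index: parameter} dicts describe exactly those pairs. The snippet below is a standalone check, not part of the patch; it uses the real transformers T5 classes, and the tiny config values are arbitrary.

# Standalone check of the weight tying that get_shared_params() relies on.
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(vocab_size=128, d_model=32, d_kv=8, d_ff=64,
                  num_layers=2, num_decoder_layers=2, num_heads=2)
model = T5ForConditionalGeneration(config)

print(id(model.encoder.embed_tokens.weight) == id(model.shared.weight))    # True
print(id(model.decoder.embed_tokens.weight) == id(model.shared.weight))    # True
print(id(model.lm_head.weight) == id(model.shared.weight))                 # True when embeddings are tied
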
diff --git a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py
index 3662aa8ac..7f3a5f2ea 100644
--- a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py
+++ b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py
@@ -28,8 +28,6 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
 
     for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-        if name != 'transformers_t5_encoder_model':
-            continue
         inputs = data_gen_fn()
         inputs = {k: v.cuda() for k, v in inputs.items()}
 
@@ -52,6 +50,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
         stage = stage_manager.stage
         at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
         at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
+        in_decoder = stage >= decoder_starting_stage
 
         if not at_first_stage:
             # change inputs if not the first stage
@@ -62,19 +61,25 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
             inputs['hidden_states'] = hidden_states
             inputs['position_bias'] = position_bias
             inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
+            if in_decoder:
+                encoder_output_states = torch.zeros(*hidden_state_shape).cuda()
+                inputs['encoder_outputs'] = (encoder_output_states,)
 
         sharded_model.train()
         output = sharded_model(**inputs)
         if at_last_stage:
-            if name != 'transformers_t5_for_conditional_generation':
-                assert output[0].shape == hidden_state_shape
-            else:
+            if name == 'transformers_t5_for_conditional_generation' and in_decoder:
                 assert output.loss is not None
+            else:
+                if name != 'transformers_t5_encoder_model' and not in_decoder:
+                    output = output['encoder_outputs']
+                assert output[0].shape == hidden_state_shape
         else:
             assert output['hidden_states'].shape == hidden_state_shape
             # position_bias information should be passed in T5
-            assert 'position_bias' in output
-            assert 'encoder_decoder_position_bias' in output
+            assert output['position_bias'].shape == position_bias_shape
+            if in_decoder:
+                assert output['encoder_decoder_position_bias'].shape == position_bias_shape
 
     torch.cuda.empty_cache()
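
The forwards and policies above hinge on decoder_starting_stage, the first pipeline stage that holds decoder blocks, as computed by T5BasePolicy.distribute_t5_layers. That function's implementation is not shown in this patch, so the snippet below is only a hypothetical, self-contained sketch of the idea (toy_distribute_t5_layers and its splitting rule are made up for illustration; it assumes num_stages >= 2).

# Hypothetical sketch: split encoder and decoder blocks over pipeline stages and
# derive decoder_starting_stage, the quantity the pipeline forwards compare against.
from typing import List, Tuple


def toy_distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
                             num_stages: int) -> Tuple[List[int], int]:
    """Assign stages to encoder and decoder roughly in proportion to their depth."""
    total = num_encoder_layers + num_decoder_layers
    # Give the encoder a share of stages proportional to its layer count, keeping
    # at least one stage each for the encoder and the decoder.
    num_encoder_stages = min(max(round(num_stages * num_encoder_layers / total), 1), num_stages - 1)
    num_decoder_stages = num_stages - num_encoder_stages

    def split(num_layers: int, stages: int) -> List[int]:
        base, rem = divmod(num_layers, stages)
        return [base + (1 if i < rem else 0) for i in range(stages)]

    layers_per_stage = split(num_encoder_layers, num_encoder_stages) + split(num_decoder_layers, num_decoder_stages)
    decoder_starting_stage = num_encoder_stages    # first stage that holds decoder blocks
    return layers_per_stage, decoder_starting_stage


# Example: 12 encoder and 12 decoder blocks on 4 pipeline stages gives
# ([6, 6, 6, 6], 2), so in_decoder = stage_manager.stage >= 2 in the forwards above.
print(toy_distribute_t5_layers(12, 12, 4))
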