[shardformer] Support the T5ForTokenClassification model (#5816)

* t5 token, still pytest fail

* Resolve T5 Pytest Failure

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Guangyao Zhang
2024-06-27 16:40:38 +08:00
committed by GitHub
parent 5dfbcd7746
commit d9d5e7ea1f
5 changed files with 166 additions and 11 deletions

View File

@@ -8,8 +8,15 @@ from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
TokenClassifierOutput,
)
from transformers.models.t5.modeling_t5 import (
T5EncoderModel,
T5ForConditionalGeneration,
T5ForTokenClassification,
T5Model,
T5Stack,
)
from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -582,6 +589,71 @@ class T5PipelineForwards:
return outputs
@staticmethod
def t5_for_token_classification_forward(
self: T5ForTokenClassification,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
backward_tensor_keys: Optional[List[str]] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
r"""
This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForTokenClassification.forward.
Please refer to original code of transformers for more details.
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = T5PipelineForwards.t5_stack_forward(
self.transformer.encoder,
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return outputs
def get_t5_flash_attention_forward():
from transformers.models.t5.modeling_t5 import T5Attention