[pipeline] fix return_dict/fix pure_pipeline_test (#4331)

This commit is contained in:
Baizhou Zhang
2023-07-27 14:53:20 +08:00
committed by Hongxin Liu
parent 411cf1d2db
commit da3cef27ad
5 changed files with 29 additions and 53 deletions

View File

@@ -8,6 +8,18 @@ import torch
import torch.nn as nn
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from transformers.models.opt.modeling_opt import (
OPTForCausalLM,
OPTForQuestionAnswering,
OPTForSequenceClassification,
OPTModel,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
@@ -317,7 +329,7 @@ class OPTPipelineForwards:
@staticmethod
def opt_model_forward(
self: 'OPTModel',
self: OPTModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
@@ -330,7 +342,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'BaseModelOutputWithPast']:
) -> Union[Tuple, BaseModelOutputWithPast]:
'''
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
'''
@@ -506,7 +518,7 @@ class OPTPipelineForwards:
@staticmethod
def opt_for_causal_lm_forward(
self: 'OPTForCausalLM',
self: OPTForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
@@ -520,7 +532,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'CausalLMOutputWithPast']:
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -646,7 +658,7 @@ class OPTPipelineForwards:
@staticmethod
def opt_for_sequence_classification_forward(
self: 'OPTForSequenceClassification',
self: OPTForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -660,7 +672,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -746,7 +758,7 @@ class OPTPipelineForwards:
@staticmethod
def opt_for_question_answering_forward(
self: 'OPTForQuestionAnswering',
self: OPTForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -761,7 +773,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.