mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[pipeline] fix return_dict/fix pure_pipeline_test (#4331)
This commit is contained in:
committed by
Hongxin Liu
parent
411cf1d2db
commit
da3cef27ad
@@ -1,3 +1,4 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -277,9 +278,6 @@ class BertPipelineForwards:
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
|
||||
outputs = BertPipelineForwards.bert_model_forward(
|
||||
self.bert,
|
||||
@@ -387,9 +385,6 @@ class BertPipelineForwards:
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
|
||||
outputs = BertPipelineForwards.bert_model_forward(
|
||||
self.bert,
|
||||
@@ -478,9 +473,6 @@ class BertPipelineForwards:
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
|
||||
outputs = BertPipelineForwards.bert_model_forward(
|
||||
self.bert,
|
||||
@@ -579,16 +571,15 @@ class BertPipelineForwards:
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("next_sentence_label")
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = BertPipelineForwards.bert_model_forward(self.bert,
|
||||
input_ids,
|
||||
@@ -661,10 +652,6 @@ class BertPipelineForwards:
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = BertPipelineForwards.bert_model_forward(self.bert,
|
||||
input_ids,
|
||||
@@ -753,10 +740,6 @@ class BertPipelineForwards:
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = BertPipelineForwards.bert_model_forward(
|
||||
self.bert,
|
||||
@@ -832,10 +815,6 @@ class BertPipelineForwards:
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# in our pipeline design,input ids are copied for every stage and shouldn't be none
|
||||
# the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
|
||||
@@ -928,10 +907,6 @@ class BertPipelineForwards:
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = BertPipelineForwards.bert_model_forward(
|
||||
self.bert,
|
||||
|
Reference in New Issue
Block a user