[misc] resolve code factor issues (#4433)

This commit is contained in:
Hongxin Liu
2023-08-14 17:43:33 +08:00
parent 328a791d10
commit 172f7fa3cf
20 changed files with 31 additions and 205 deletions

View File

@@ -57,7 +57,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
stage_index: Optional[List[int]] = None,
):
# TODO: add explaination of the output here.
# TODO(jianghai): add explaination of the output here.
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -113,7 +113,7 @@ class BertPipelineForwards:
batch_size, seq_length = input_shape
device = hidden_states.device
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -272,7 +272,7 @@ class BertPipelineForwards:
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
# TODO(jianghai) left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -534,7 +534,7 @@ class BertPipelineForwards:
stage_index: Optional[List[int]] = None,
**kwargs,
):
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
# -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair