[shardformer] fix modeling of bloom and falcon (#5796)

This commit is contained in:
Hongxin Liu
2024-06-11 17:43:50 +08:00
committed by GitHub
parent 587bbf4c6d
commit aa125bcc91
2 changed files with 12 additions and 7 deletions

View File

@@ -475,7 +475,10 @@ class BloomPipelineForwards:
sequence_lengths = -1
else:
if input_ids is not None:
-sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+sequence_lengths = sequence_lengths % input_ids.shape[-1]
+sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning(