fix beam_width

yuehuayingxueluo
2024-01-04 16:48:53 +08:00
committed by FrankLeeeee
parent b2eb9cd186
commit 3ad1f3b78b
2 changed files with 6 additions and 3 deletions

@@ -176,8 +176,12 @@ def llama_attn_forward(
def generate_padding_position_id(input_ids: torch.Tensor) -> torch.Tensor:
    # TODO: obtain padding_id in a more flexible way instead of hard-coding it here.
    padding_id = 2
    attention_mask = input_ids.ne(padding_id).long()
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids
# def unpad_inputs(input_ids: torch.Tensor):
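
For reference, a minimal standalone sketch (not part of the commit) of what generate_padding_position_id produces for a left-padded batch, assuming the hard-coded padding_id = 2 from the hunk above:

    import torch

    def generate_padding_position_id(input_ids: torch.Tensor) -> torch.Tensor:
        # Mirrors the function in the diff above; padding_id is still hard-coded.
        padding_id = 2
        attention_mask = input_ids.ne(padding_id).long()
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        return position_ids

    # Two sequences; the first is left-padded with token id 2 (the assumed padding id).
    input_ids = torch.tensor([
        [2, 2, 5, 6, 7],   # 2 padding tokens, then 3 real tokens
        [9, 8, 7, 6, 5],   # no padding
    ])
    print(generate_padding_position_id(input_ids))
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])

Real tokens get positions counted from 0 at the first non-padding token, while padding positions are filled with 1, mirroring the common pattern used in Hugging Face transformers' generation utilities.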