fix(model): Fix llama model load error on MPS device (#1033)

This commit is contained in:
Fangyin Cheng
2024-01-05 14:19:13 +08:00
committed by GitHub
parent d8393a9b32
commit 186b6a5668
2 changed files with 16 additions and 4 deletions

View File

@@ -34,6 +34,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()