fix(model): Fix device MPS load llama error (#1033)

commit 186b6a5668 (parent d8393a9b32)
Author: Fangyin Cheng
Date: 2024-01-05 14:19:13 +08:00
Committed by: GitHub

2 changed files with 16 additions and 4 deletions

@@ -34,6 +34,7 @@ def forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
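
For context: newer transformers releases began passing a padding_mask keyword into the llama attention forward, so the patched replacement used on MPS has to accept it to stay signature-compatible; that is all this one-line change does. The helper named in the second hunk plausibly installs the patch along these lines (a minimal sketch, assuming the patched forward above lives in the same module; not the exact DB-GPT code):

    # Sketch: how replace_llama_attn_with_non_inplace_operations might wire
    # in the patched forward (assumption: forward is defined in this module).
    import transformers

    def replace_llama_attn_with_non_inplace_operations():
        # Work around bugs in the mps backend by swapping
        # LlamaAttention.forward for a version without in-place tensor ops.
        transformers.models.llama.modeling_llama.LlamaAttention.forward = forward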

@@ -166,11 +166,22 @@ def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters
     elif device == "mps":
         kwargs = {"torch_dtype": torch.float16}
-        from dbgpt.model.llm.monkey_patch import (
-            replace_llama_attn_with_non_inplace_operations,
-        )
-        replace_llama_attn_with_non_inplace_operations()
+        import transformers
+
+        version = tuple(int(v) for v in transformers.__version__.split("."))
+        if version < (4, 35, 0):
+            from dbgpt.model.llm.monkey_patch import (
+                replace_llama_attn_with_non_inplace_operations,
+            )
+
+            # NOTE: Recent transformers library seems to fix the mps issue, also
+            # it has made some changes causing compatibility issues with our
+            # original patch. So we only apply the patch for older versions.
+            # Avoid bugs in mps backend by not using in-place operations.
+            replace_llama_attn_with_non_inplace_operations()
     else:
         raise ValueError(f"Invalid device: {device}")
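
A caveat with the new version gate: tuple(int(v) for v in transformers.__version__.split(".")) raises ValueError on pre-release builds such as "4.36.0.dev0", whose components are not all integers. A more robust comparison is sketched below, assuming the packaging library (already a dependency of transformers) is available:

    # Sketch: version gate that also tolerates dev/rc builds of transformers.
    import transformers
    from packaging.version import Version

    if Version(transformers.__version__) < Version("4.35.0"):
        from dbgpt.model.llm.monkey_patch import (
            replace_llama_attn_with_non_inplace_operations,
        )

        # Only patch older transformers; newer releases fix the mps issue.
        replace_llama_attn_with_non_inplace_operations()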