fix(model): Fix device MPS load llama error (#1033)
@@ -34,6 +34,7 @@ def forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
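The added parameter matters because newer transformers releases pass padding_mask into the attention forward as a keyword, so a monkey-patched forward without it fails with a TypeError. A minimal sketch of the failure mode, with hypothetical function names rather than DB-GPT's actual code:

from typing import Optional

import torch

def patched_forward_before(hidden_states: torch.Tensor, use_cache: bool = False):
    # Old patched signature: rejects the padding_mask keyword.
    return hidden_states

def patched_forward_after(
    hidden_states: torch.Tensor,
    use_cache: bool = False,
    # Accept (even if unused) the keyword that newer callers pass.
    padding_mask: Optional[torch.LongTensor] = None,
):
    return hidden_states

x = torch.zeros(1, 4, 8)
patched_forward_after(x, padding_mask=None)    # OK
# patched_forward_before(x, padding_mask=None) # TypeError: unexpected keyword

Keeping the extra keyword in the patched signature is the usual way to stay compatible with both old and new callers.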
@@ -166,11 +166,22 @@ def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParamete
     elif device == "mps":
         kwargs = {"torch_dtype": torch.float16}
-        from dbgpt.model.llm.monkey_patch import (
-            replace_llama_attn_with_non_inplace_operations,
-        )
-
-        replace_llama_attn_with_non_inplace_operations()
+        import transformers
+
+        version = tuple(int(v) for v in transformers.__version__.split("."))
+        if version < (4, 35, 0):
+            from dbgpt.model.llm.monkey_patch import (
+                replace_llama_attn_with_non_inplace_operations,
+            )
+
+            # NOTE: Recent transformers library seems to fix the mps issue, also
+            # it has made some changes causing compatibility issues with our
+            # original patch. So we only apply the patch for older versions.
+
+            # Avoid bugs in mps backend by not using in-place operations.
+            replace_llama_attn_with_non_inplace_operations()
     else:
         raise ValueError(f"Invalid device: {device}")
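The gate above parses transformers.__version__ by splitting on dots, which assumes purely numeric components; a string like "4.35.0.dev0" would make int() raise. A sketch of the same gate written with the packaging library (an assumption on my part, though packaging is a transformers dependency; needs_mps_llama_patch is a hypothetical name):

import transformers
from packaging import version

def needs_mps_llama_patch() -> bool:
    # True only for transformers releases older than 4.35.0, where the
    # mps issue and the original patch still apply. version.parse()
    # also tolerates suffixes such as "4.35.0.dev0".
    return version.parse(transformers.__version__) < version.parse("4.35.0")

if needs_mps_llama_patch():
    from dbgpt.model.llm.monkey_patch import (
        replace_llama_attn_with_non_inplace_operations,
    )
    replace_llama_attn_with_non_inplace_operations()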
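As for what replace_llama_attn_with_non_inplace_operations does, the diff's comment only says it avoids in-place operations on the mps backend. A generic illustration of that rewrite pattern (my sketch, not the patch's actual body):

import torch

def scale_in_place(attn_weights: torch.Tensor, scale: float) -> torch.Tensor:
    # In-place: mutates the input buffer. This is the style the patch
    # removes, since some in-place ops were buggy on the mps backend.
    return attn_weights.mul_(scale)

def scale_out_of_place(attn_weights: torch.Tensor, scale: float) -> torch.Tensor:
    # Out-of-place: allocates a fresh result tensor, sidestepping those bugs.
    return attn_weights * scale

w = torch.ones(2, 2)
assert torch.equal(scale_out_of_place(w, 0.5), torch.full((2, 2), 0.5))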