mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 02:51:07 +00:00
fix(model): Fix device MPS load llama error (#1033)
This commit is contained in:
@@ -34,6 +34,7 @@ def forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
@@ -166,11 +166,22 @@ def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParamete
|
|||||||
|
|
||||||
elif device == "mps":
|
elif device == "mps":
|
||||||
kwargs = {"torch_dtype": torch.float16}
|
kwargs = {"torch_dtype": torch.float16}
|
||||||
from dbgpt.model.llm.monkey_patch import (
|
|
||||||
replace_llama_attn_with_non_inplace_operations,
|
|
||||||
)
|
|
||||||
|
|
||||||
replace_llama_attn_with_non_inplace_operations()
|
import transformers
|
||||||
|
|
||||||
|
version = tuple(int(v) for v in transformers.__version__.split("."))
|
||||||
|
if version < (4, 35, 0):
|
||||||
|
from dbgpt.model.llm.monkey_patch import (
|
||||||
|
replace_llama_attn_with_non_inplace_operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: Recent transformers library seems to fix the mps issue, also
|
||||||
|
# it has made some changes causing compatibility issues with our
|
||||||
|
# original patch. So we only apply the patch for older versions.
|
||||||
|
|
||||||
|
# Avoid bugs in mps backend by not using in-place operations.
|
||||||
|
replace_llama_attn_with_non_inplace_operations()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid device: {device}")
|
raise ValueError(f"Invalid device: {device}")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user