Apply patch [skip ci]

open-swe[bot] 2025-07-31 00:14:20 +00:00
parent 5b6e25a88f
commit 2e95b3fa71


@@ -61,41 +61,48 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     """Whether to show a progress bar."""

     def __init__(self, **kwargs: Any):
-        """Initialize the sentence_transformer."""
+        """Initialize the transformer model."""
         super().__init__(**kwargs)
         try:
-            import sentence_transformers  # type: ignore[import]
+            from transformers import AutoModel, AutoTokenizer
         except ImportError as exc:
             msg = (
-                "Could not import sentence_transformers python package. "
-                "Please install it with `pip install sentence-transformers`."
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
             )
             raise ImportError(msg) from exc
-        if self.model_kwargs.get("backend", "torch") == "ipex":
-            if not is_optimum_intel_available() or not is_ipex_available():
-                msg = f"Backend: ipex {IMPORT_ERROR.format('optimum[ipex]')}"
-                raise ImportError(msg)
+        # Extract device from model_kwargs
+        self.device = self.model_kwargs.get("device", "cpu")
+        if isinstance(self.device, str):
+            self.device = torch.device(self.device)
-            if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
-                msg = (
-                    f"Backend: ipex requires optimum-intel>="
-                    f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
-                    "`pip install --upgrade --upgrade-strategy eager "
-                    "`optimum[ipex]`."
-                )
-                raise ImportError(msg)
+        # Remove device from model_kwargs as it's not a valid argument for from_pretrained
+        model_kwargs = {k: v for k, v in self.model_kwargs.items() if k != "device"}
-            from optimum.intel import IPEXSentenceTransformer  # type: ignore[import]
-            model_cls = IPEXSentenceTransformer
-        else:
-            model_cls = sentence_transformers.SentenceTransformer
-        self._client = model_cls(
-            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
+        # Load tokenizer and model
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_name,
+            cache_dir=self.cache_folder,
+            **model_kwargs
         )
+        self.model = AutoModel.from_pretrained(
+            self.model_name,
+            cache_dir=self.cache_folder,
+            **model_kwargs
+        )
+        self.model.to(self.device)
+        self.model.eval()
+        # Warn about multi-process not being supported
+        if self.multi_process:
+            import warnings
+            warnings.warn(
+                "Multi-process encoding is not supported with the transformers "
+                "implementation. This parameter will be ignored.",
+                UserWarning,
+                stacklevel=2
+            )

     model_config = ConfigDict(
         extra="forbid",
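Note that the added lines call `torch.device` but the hunk adds no `import torch`, so the patch presumably relies on a module-level torch import elsewhere in the file. A minimal sketch of how the patched initializer would be exercised follows; the `langchain_huggingface` import path and the model name are assumptions for illustration, not part of this patch:

    # Sketch only: import path and model name are assumptions, not part of the patch.
    from langchain_huggingface import HuggingFaceEmbeddings

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},  # read into self.device, then stripped before from_pretrained
    )
    # Passing multi_process=True would now emit a UserWarning and the flag is ignored.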
@@ -171,3 +178,4 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         return self._embed([text], embed_kwargs)[0]
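The second hunk shows only the tail of `embed_query`; the `_embed` helper it calls lies outside the diff. For orientation, a rough sketch of what a transformers-based `_embed` commonly does (an assumption about its shape, not the patch's actual body): tokenize the batch, run the model under `torch.no_grad()`, and mean-pool the last hidden state against the attention mask:

    import torch

    def mean_pool_embed(texts, tokenizer, model, device):
        # Hypothetical helper mirroring a typical transformers-based _embed.
        batch = tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            hidden = model(**batch).last_hidden_state         # (batch, seq, dim)
        mask = batch["attention_mask"].unsqueeze(-1).float()  # (batch, seq, 1)
        summed = (hidden * mask).sum(dim=1)                   # zero out padding tokens
        counts = mask.sum(dim=1).clamp(min=1e-9)              # non-pad tokens per sequence
        return (summed / counts).tolist()                     # mean-pooled vectors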