mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-16 16:11:02 +00:00
Apply patch [skip ci]
This commit is contained in:
parent
5b6e25a88f
commit
2e95b3fa71
@ -61,41 +61,48 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Whether to show a progress bar."""
|
"""Whether to show a progress bar."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any):
|
def __init__(self, **kwargs: Any):
|
||||||
"""Initialize the sentence_transformer."""
|
"""Initialize the transformer model."""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
try:
|
try:
|
||||||
import sentence_transformers # type: ignore[import]
|
from transformers import AutoModel, AutoTokenizer
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
msg = (
|
msg = (
|
||||||
"Could not import sentence_transformers python package. "
|
"Could not import transformers python package. "
|
||||||
"Please install it with `pip install sentence-transformers`."
|
"Please install it with `pip install transformers`."
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from exc
|
raise ImportError(msg) from exc
|
||||||
|
|
||||||
if self.model_kwargs.get("backend", "torch") == "ipex":
|
# Extract device from model_kwargs
|
||||||
if not is_optimum_intel_available() or not is_ipex_available():
|
self.device = self.model_kwargs.get("device", "cpu")
|
||||||
msg = f"Backend: ipex {IMPORT_ERROR.format('optimum[ipex]')}"
|
if isinstance(self.device, str):
|
||||||
raise ImportError(msg)
|
self.device = torch.device(self.device)
|
||||||
|
|
||||||
if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
|
# Remove device from model_kwargs as it's not a valid argument for from_pretrained
|
||||||
msg = (
|
model_kwargs = {k: v for k, v in self.model_kwargs.items() if k != "device"}
|
||||||
f"Backend: ipex requires optimum-intel>="
|
|
||||||
f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
|
|
||||||
"`pip install --upgrade --upgrade-strategy eager "
|
|
||||||
"`optimum[ipex]`."
|
|
||||||
)
|
|
||||||
raise ImportError(msg)
|
|
||||||
|
|
||||||
from optimum.intel import IPEXSentenceTransformer # type: ignore[import]
|
# Load tokenizer and model
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_cls = IPEXSentenceTransformer
|
self.model_name,
|
||||||
|
cache_dir=self.cache_folder,
|
||||||
else:
|
**model_kwargs
|
||||||
model_cls = sentence_transformers.SentenceTransformer
|
|
||||||
|
|
||||||
self._client = model_cls(
|
|
||||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
|
||||||
)
|
)
|
||||||
|
self.model = AutoModel.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
cache_dir=self.cache_folder,
|
||||||
|
**model_kwargs
|
||||||
|
)
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# Warn about multi-process not being supported
|
||||||
|
if self.multi_process:
|
||||||
|
import warnings
|
||||||
|
warnings.warn(
|
||||||
|
"Multi-process encoding is not supported with the transformers "
|
||||||
|
"implementation. This parameter will be ignored.",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2
|
||||||
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
extra="forbid",
|
extra="forbid",
|
||||||
@ -171,3 +178,4 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
return self._embed([text], embed_kwargs)[0]
|
return self._embed([text], embed_kwargs)[0]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user