mirror of https://github.com/hwchase17/langchain.git
Apply patch [skip ci]
This commit is contained in: parent 5b6e25a88f, commit 2e95b3fa71
@@ -61,41 +61,48 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     """Whether to show a progress bar."""
 
     def __init__(self, **kwargs: Any):
-        """Initialize the sentence_transformer."""
+        """Initialize the transformer model."""
         super().__init__(**kwargs)
         try:
-            import sentence_transformers  # type: ignore[import]
+            from transformers import AutoModel, AutoTokenizer
         except ImportError as exc:
             msg = (
-                "Could not import sentence_transformers python package. "
-                "Please install it with `pip install sentence-transformers`."
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
             )
             raise ImportError(msg) from exc
 
-        if self.model_kwargs.get("backend", "torch") == "ipex":
-            if not is_optimum_intel_available() or not is_ipex_available():
-                msg = f"Backend: ipex {IMPORT_ERROR.format('optimum[ipex]')}"
-                raise ImportError(msg)
+        # Extract device from model_kwargs
+        self.device = self.model_kwargs.get("device", "cpu")
+        if isinstance(self.device, str):
+            self.device = torch.device(self.device)
 
-            if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
-                msg = (
-                    f"Backend: ipex requires optimum-intel>="
-                    f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
-                    "`pip install --upgrade --upgrade-strategy eager "
-                    "`optimum[ipex]`."
-                )
-                raise ImportError(msg)
+        # Remove device from model_kwargs as it's not a valid argument for from_pretrained
+        model_kwargs = {k: v for k, v in self.model_kwargs.items() if k != "device"}
 
-            from optimum.intel import IPEXSentenceTransformer  # type: ignore[import]
-
-            model_cls = IPEXSentenceTransformer
-        else:
-            model_cls = sentence_transformers.SentenceTransformer
-
-        self._client = model_cls(
-            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
-        )
+        # Load tokenizer and model
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_name,
+            cache_dir=self.cache_folder,
+            **model_kwargs
+        )
+        self.model = AutoModel.from_pretrained(
+            self.model_name,
+            cache_dir=self.cache_folder,
+            **model_kwargs
+        )
+        self.model.to(self.device)
+        self.model.eval()
+
+        # Warn about multi-process not being supported
+        if self.multi_process:
+            import warnings
+            warnings.warn(
+                "Multi-process encoding is not supported with the transformers "
+                "implementation. This parameter will be ignored.",
+                UserWarning,
+                stacklevel=2
+            )
 
     model_config = ConfigDict(
         extra="forbid",
@@ -171,3 +178,4 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         return self._embed([text], embed_kwargs)[0]
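For reference, a hedged usage example (the model name and import path are illustrative, not taken from this commit): construction is unchanged by the patch, with the device still passed through model_kwargs, where __init__ now extracts it and converts it to a torch.device.

from langchain_huggingface import HuggingFaceEmbeddings  # import path assumed

emb = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # example model
    model_kwargs={"device": "cpu"},  # extracted in __init__, then removed
)
vector = emb.embed_query("hello world")  # delegates to self._embed([text], ...)[0]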