langchain/libs/community/langchain_community/embeddings/ascend.py
cold-eye 7c175e3fda
Update ascend.py (#30060)
Add batch_size to fix OOM when embedding a large number of texts.
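A minimal sketch of the batching pattern this change introduces in
embed_documents (identifiers match the file below; peak NPU memory now
scales with batch_size rather than with the total number of input texts):

    embedding_list = []
    for i in range(0, len(texts), self.batch_size):
        batch = texts[i : i + self.batch_size]  # at most batch_size texts per forward pass
        embedding_list.append(
            self.encode([self.document_instruction + t for t in batch])
        )
    return np.concatenate(embedding_list)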

2025-03-01 14:10:41 -05:00


import os
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, model_validator


class AscendEmbeddings(Embeddings, BaseModel):
    """Ascend NPU accelerated embedding model.

    Please ensure that you have installed CANN and torch_npu.

    Example:
        from langchain_community.embeddings import AscendEmbeddings
        model = AscendEmbeddings(
            model_path=<path_to_model>,
            device_id=0,
            query_instruction="Represent this sentence for searching relevant passages: ",
        )
    """
"""model path"""
model_path: str
"""Ascend NPU device id."""
device_id: int = 0
"""Unstruntion to used for embedding query."""
query_instruction: str = ""
"""Unstruntion to used for embedding document."""
document_instruction: str = ""
use_fp16: bool = True
pooling_method: Optional[str] = "cls"
batch_size: int = 32
model: Any
tokenizer: Any
model_config = ConfigDict(protected_namespaces=())
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        try:
            from transformers import AutoModel, AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "Unable to import transformers, please install with "
                "`pip install -U transformers`."
            ) from e
        try:
            # Load the model onto the NPU in inference mode.
            self.model = AutoModel.from_pretrained(self.model_path).npu().eval()
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        except Exception as e:
            raise Exception(
                f"Failed to load model [{self.model_path}], "
                f"due to following error: {e}"
            ) from e
        if self.use_fp16:
            self.model.half()
        # Warm up the NPU so the first real embedding call is not slowed
        # down by one-time initialization.
        self.encode([f"warmup {i} times" for i in range(10)])

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        if "model_path" not in values:
            raise ValueError("model_path is required")
        if not os.access(values["model_path"], os.F_OK):
            raise FileNotFoundError(
                f"Unable to find valid model path in [{values['model_path']}]"
            )
        try:
            import torch_npu
        except ImportError as e:
            raise ModuleNotFoundError(
                "torch_npu not found, please install torch_npu"
            ) from e
        try:
            torch_npu.npu.set_device(values["device_id"])
        except Exception as e:
            raise Exception(f"set device failed due to {e}") from e
        return values

    def encode(self, sentences: Any) -> Any:
        inputs = self.tokenizer(
            sentences,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        try:
            import torch
        except ImportError as e:
            raise ImportError(
                "Unable to import torch, please install with `pip install -U torch`."
            ) from e
        # Forward pass on the NPU, then pool and L2-normalize the embeddings.
        last_hidden_state = self.model(
            inputs.input_ids.npu(), inputs.attention_mask.npu(), return_dict=True
        ).last_hidden_state
        tmp = self.pooling(last_hidden_state, inputs["attention_mask"].npu())
        embeddings = torch.nn.functional.normalize(tmp, dim=-1)
        return embeddings.cpu().detach().numpy()

    def pooling(self, last_hidden_state: Any, attention_mask: Any = None) -> Any:
        try:
            import torch
        except ImportError as e:
            raise ImportError(
                "Unable to import torch, please install with `pip install -U torch`."
            ) from e
        if self.pooling_method == "cls":
            # Use the hidden state of the [CLS] token.
            return last_hidden_state[:, 0]
        elif self.pooling_method == "mean":
            # Mean over the sequence dimension (dim=1), masking out padding
            # tokens so they do not contribute to the average.
            s = torch.sum(
                last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1
            )
            d = attention_mask.sum(dim=1, keepdim=True).float()
            return s / d
        else:
            raise NotImplementedError(
                f"Pooling method [{self.pooling_method}] not implemented"
            )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        try:
            import numpy as np
        except ImportError as e:
            raise ImportError(
                "Unable to import numpy, please install with `pip install -U numpy`."
            ) from e
        embedding_list = []
        # Encode in batch_size-sized chunks so only one batch of activations
        # is resident on the NPU at a time (avoids OOM on large inputs).
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i : i + self.batch_size]
            emb = self.encode([self.document_instruction + text for text in batch])
            embedding_list.append(emb)
        return np.concatenate(embedding_list)

    def embed_query(self, text: str) -> List[float]:
        return self.encode([self.query_instruction + text])[0]
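
# A minimal usage sketch (hypothetical model path; assumes CANN and
# torch_npu are installed and an Ascend NPU is visible to this process):
#
#     embedder = AscendEmbeddings(
#         model_path="/path/to/local/model",  # hypothetical local checkpoint
#         device_id=0,
#         batch_size=8,  # lower this if the NPU still runs out of memory
#         query_instruction=(
#             "Represent this sentence for searching relevant passages: "
#         ),
#     )
#     doc_vectors = embedder.embed_documents(["first doc", "second doc"])
#     query_vector = embedder.embed_query("a search query")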