I made a change to how GPU support is implemented in `FastEmbedEmbeddings`, to be more consistent with the existing `langchain-qdrant` sparse embeddings implementation. It directly enables providing the list of ONNX execution providers: https://github.com/langchain-ai/langchain/blob/master/libs/partners/qdrant/langchain_qdrant/fastembed_sparse.py#L15

It is a bit less obvious for a user who just wants to enable GPU, but it gives more flexibility to work with execution providers other than `CUDAExecutionProvider`, and it is more future-proof.

Sorry for the disturbance @ccurme

> Nice to see you just moved to `uv`! It is so much nicer to run format/lint/test! No need to manually rerun `poetry install` with all the required extras now.
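To make the intent concrete, here is a minimal sketch of how a user would enable GPU with the new parameter (this assumes `fastembed-gpu` is installed and a CUDA runtime is available; the model name shown is just the class default):

```python
from langchain_community.embeddings import FastEmbedEmbeddings

# Requires `pip install fastembed-gpu` (instead of `fastembed`);
# passing CUDAExecutionProvider routes ONNX inference to the GPU.
embeddings = FastEmbedEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    providers=["CUDAExecutionProvider"],
)

vector = embeddings.embed_query("How do I enable GPU in FastEmbed?")
```

Any other ONNX execution provider supported by the installed onnxruntime can be passed the same way, which is the flexibility argument above.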
		
			
				
	
	
		
153 lines · 4.9 KiB · Python
import importlib
import importlib.metadata
from typing import Any, Dict, List, Literal, Optional, Sequence, cast

import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.utils import pre_init
from pydantic import BaseModel, ConfigDict

MIN_VERSION = "0.2.0"


class FastEmbedEmbeddings(BaseModel, Embeddings):
    """Qdrant FastEmbedding models.

    FastEmbed is a lightweight, fast Python library built for embedding generation.
    See more documentation at:
    * https://github.com/qdrant/fastembed/
    * https://qdrant.github.io/fastembed/

    To use this class, you must install the `fastembed` Python package.

    `pip install fastembed`
    Example:
        from langchain_community.embeddings import FastEmbedEmbeddings
        fastembed = FastEmbedEmbeddings()
    """

    model_name: str = "BAAI/bge-small-en-v1.5"
    """Name of the FastEmbedding model to use.
    Defaults to "BAAI/bge-small-en-v1.5".
    Find the list of supported models at
    https://qdrant.github.io/fastembed/examples/Supported_Models/
    """

    max_length: int = 512
    """The maximum number of tokens. Defaults to 512.
    Unknown behavior for values > 512.
    """

    cache_dir: Optional[str] = None
    """The path to the cache directory.
    Defaults to `local_cache` in the parent directory.
    """

    threads: Optional[int] = None
    """The number of threads a single onnxruntime session can use.
    Defaults to None.
    """

    doc_embed_type: Literal["default", "passage"] = "default"
    """Type of embedding to use for documents.
    The available options are: "default" and "passage".
    """

    batch_size: int = 256
    """Batch size for encoding. Higher values will use more memory, but will be faster.
    Defaults to 256.
    """

    parallel: Optional[int] = None
    """If `>1`, parallel encoding is used, recommended for encoding of large datasets.
    If `0`, use all available cores.
    If `None`, don't use data-parallel processing, use default onnxruntime threading.
    Defaults to `None`.
    """

    providers: Optional[Sequence[Any]] = None
    """List of ONNX execution providers. Use `["CUDAExecutionProvider"]` to enable the
    use of GPU when generating embeddings. This requires installing `fastembed-gpu`
    instead of `fastembed`. See https://qdrant.github.io/fastembed/examples/FastEmbed_GPU
    for more details.
    Defaults to `None`.
    """

    model: Any = None  # : :meta private:

    model_config = ConfigDict(extra="allow", protected_namespaces=())

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that FastEmbed has been installed."""
        model_name = values.get("model_name")
        max_length = values.get("max_length")
        cache_dir = values.get("cache_dir")
        threads = values.get("threads")
        providers = values.get("providers")
        pkg_to_install = (
            "fastembed-gpu"
            if providers and "CUDAExecutionProvider" in providers
            else "fastembed"
        )

        try:
            fastembed = importlib.import_module("fastembed")

        except ModuleNotFoundError:
            raise ImportError(
                "Could not import 'fastembed' Python package. "
                f"Please install it with `pip install {pkg_to_install}`."
            )

        if importlib.metadata.version(pkg_to_install) < MIN_VERSION:
            raise ImportError(
                f"FastEmbedEmbeddings requires "
                f'`pip install -U "{pkg_to_install}>={MIN_VERSION}"`.'
            )

        values["model"] = fastembed.TextEmbedding(
            model_name=model_name,
            max_length=max_length,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
        )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for documents using FastEmbed.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings: List[np.ndarray]
        if self.doc_embed_type == "passage":
            embeddings = self.model.passage_embed(
                texts, batch_size=self.batch_size, parallel=self.parallel
            )
        else:
            embeddings = self.model.embed(
                texts, batch_size=self.batch_size, parallel=self.parallel
            )
        return [cast(List[float], e.tolist()) for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Generate query embeddings using FastEmbed.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        query_embeddings: np.ndarray = next(
            self.model.query_embed(
                text, batch_size=self.batch_size, parallel=self.parallel
            )
        )
        return cast(List[float], query_embeddings.tolist())