diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py
index 299e5a52137..596ea36cf0c 100644
--- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py
+++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py
@@ -110,6 +110,89 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         populate_by_name=True,
     )
 
+    def _mean_pooling(
+        self, model_output: Any, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Apply mean pooling to the model output."""
+        # Extract token embeddings from the model output
+        token_embeddings = model_output[0]
+
+        # Expand the attention mask for broadcasting
+        input_mask_expanded = (
+            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        )
+
+        # Average token embeddings, counting only non-padding positions
+        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+        return sum_embeddings / sum_mask
+
+    def encode(
+        self,
+        texts: list[str],
+        batch_size: int = 32,
+        show_progress_bar: Optional[bool] = None,
+        normalize_embeddings: bool = False,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        """Encode texts into embeddings.
+
+        Args:
+            texts: The list of texts to embed.
+            batch_size: Batch size for encoding.
+            show_progress_bar: Whether to show a progress bar.
+            normalize_embeddings: Whether to L2-normalize the embeddings.
+            **kwargs: Additional keyword arguments passed to the tokenizer.
+
+        Returns:
+            Array of embeddings, one row per input text.
+        """
+        if show_progress_bar is None:
+            show_progress_bar = self.show_progress
+
+        all_embeddings = []
+
+        # Process in batches, optionally behind a tqdm progress bar
+        if show_progress_bar:
+            from tqdm.auto import trange
+
+            batch_iter = trange(0, len(texts), batch_size, desc="Batches")
+        else:
+            batch_iter = range(0, len(texts), batch_size)
+
+        for i in batch_iter:
+            batch_texts = texts[i : i + batch_size]
+
+            # Tokenize the batch
+            encoded_input = self.tokenizer(
+                batch_texts,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+                **kwargs,
+            )
+
+            # Move inputs to the model's device
+            encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
+
+            # Run the model without gradient tracking
+            with torch.no_grad():
+                model_output = self.model(**encoded_input)
+
+            # Pool token embeddings into one vector per text
+            embeddings = self._mean_pooling(
+                model_output, encoded_input["attention_mask"]
+            )
+
+            # Normalize if requested
+            if normalize_embeddings:
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+            all_embeddings.append(embeddings.cpu().numpy())
+
+        # Concatenate all batches into a single array
+        return np.vstack(all_embeddings)
+
     def _embed(
         self, texts: list[str], encode_kwargs: dict[str, Any]
     ) -> list[list[float]]:
@@ -118,35 +201,19 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         Args:
             texts: The list of texts to embed.
            encode_kwargs: Keyword arguments to pass when calling the
-                `encode` method for the documents of the SentenceTransformer
                 encode method.
 
         Returns:
             List of embeddings, one for each text.
         """
-        import sentence_transformers  # type: ignore[import]
-
         texts = [x.replace("\n", " ") for x in texts]
-        if self.multi_process:
-            pool = self._client.start_multi_process_pool()
-            embeddings = self._client.encode_multi_process(texts, pool)
-            sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
-        else:
-            embeddings = self._client.encode(
-                texts,
-                show_progress_bar=self.show_progress,
-                **encode_kwargs,
-            )
-
-        if isinstance(embeddings, list):
-            msg = (
-                "Expected embeddings to be a Tensor or a numpy array, "
-                "got a list instead."
-            )
-            raise TypeError(msg)
-
-        return embeddings.tolist()  # type: ignore[return-type]
+        embeddings = self.encode(
+            texts,
+            show_progress_bar=self.show_progress,
+            **encode_kwargs,
+        )
+        return embeddings.tolist()
 
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.
 
@@ -179,3 +246,4 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
+
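For review purposes, a minimal usage sketch of the new code path follows. This is an assumption-laden sketch, not part of the patch: it presumes the rest of the PR initializes `self.tokenizer`, `self.model`, and `self.device` (for example via `AutoTokenizer`/`AutoModel`) and adds the `torch`/`numpy`/`tqdm` imports that the hunks above rely on; the model name is only an example.

```python
# Sketch only: how the reworked class would be exercised once the PR lands.
# Assumes the PR wires up self.tokenizer/self.model/self.device elsewhere.
import numpy as np

from langchain_huggingface import HuggingFaceEmbeddings

embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"  # example model
)

# embed_documents -> _embed -> encode (batched tokenize / forward / mean-pool)
vectors = embedder.embed_documents(["first document", "second document"])
print(len(vectors), len(vectors[0]))  # 2 texts x hidden_size floats

# The new encode() can also be called directly; normalize_embeddings applies
# L2 normalization, so dot products become cosine similarities.
arr = embedder.encode(["a query"], normalize_embeddings=True)
print(np.linalg.norm(arr[0]))  # approximately 1.0
```

With normalization enabled, downstream similarity search can use plain dot products instead of a separate cosine step.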
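Since `_mean_pooling` is the numerical core of the change, a tiny standalone check of the masking arithmetic may also help reviewers; it uses made-up tensors and mirrors the helper's body rather than importing it.

```python
import torch

# Toy example: 1 sequence, 3 tokens, hidden size 2; the third token is
# padding (mask = 0) and must not affect the mean.
token_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

# Same steps as _mean_pooling: expand mask, masked sum, divide by token count
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
summed = torch.sum(token_embeddings * mask, 1)
counts = torch.clamp(mask.sum(1), min=1e-9)
print(summed / counts)  # tensor([[2., 3.]]) -- the padding row is ignored
```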