diff --git a/libs/community/langchain_community/embeddings/bedrock.py b/libs/community/langchain_community/embeddings/bedrock.py index 529809fb911..7ab94df4dcb 100644 --- a/libs/community/langchain_community/embeddings/bedrock.py +++ b/libs/community/langchain_community/embeddings/bedrock.py @@ -3,6 +3,7 @@ import json import os from typing import Any, Dict, List, Optional +import numpy as np from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_core.runnables.config import run_in_executor @@ -64,6 +65,9 @@ class BedrockEmbeddings(BaseModel, Embeddings): endpoint_url: Optional[str] = None """Needed if you don't want to default to us-east-1 endpoint""" + normalize: bool = False + """Whether the embeddings should be normalized to unit vectors""" + class Config: """Configuration for this pydantic object.""" @@ -145,6 +149,12 @@ class BedrockEmbeddings(BaseModel, Embeddings): except Exception as e: raise ValueError(f"Error raised by inference endpoint: {e}") + def _normalize_vector(self, embeddings: List[float]) -> List[float]: + """Normalize the embedding to a unit vector.""" + emb = np.array(embeddings) + norm_emb = emb / np.linalg.norm(emb) + return norm_emb.tolist() + def embed_documents(self, texts: List[str]) -> List[List[float]]: """Compute doc embeddings using a Bedrock model. @@ -157,7 +167,12 @@ class BedrockEmbeddings(BaseModel, Embeddings): results = [] for text in texts: response = self._embedding_func(text) + + if self.normalize: + response = self._normalize_vector(response) + results.append(response) + return results def embed_query(self, text: str) -> List[float]: @@ -169,7 +184,12 @@ class BedrockEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - return self._embedding_func(text) + embedding = self._embedding_func(text) + + if self.normalize: + return self._normalize_vector(embedding) + + return embedding async def aembed_query(self, text: str) -> List[float]: """Asynchronous compute query embeddings using a Bedrock model.