community: normalize bedrock embeddings (#15103)

In this PR I added a post-processing function to normalize the
embeddings. This happens only if the new `normalize` flag is `True`.

---------

Co-authored-by: taamedag <Davide.Menini@swisscom.com>
This commit is contained in:
Davide Menini 2024-01-24 02:05:24 +01:00 committed by GitHub
parent 20fcd49348
commit 9ce177580a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,6 +3,7 @@ import json
import os
from typing import Any, Dict, List, Optional
import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.runnables.config import run_in_executor
@ -64,6 +65,9 @@ class BedrockEmbeddings(BaseModel, Embeddings):
endpoint_url: Optional[str] = None
"""Needed if you don't want to default to us-east-1 endpoint"""
normalize: bool = False
"""Whether the embeddings should be normalized to unit vectors"""
class Config:
"""Configuration for this pydantic object."""
@ -145,6 +149,12 @@ class BedrockEmbeddings(BaseModel, Embeddings):
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
def _normalize_vector(self, embeddings: List[float]) -> List[float]:
"""Normalize the embedding to a unit vector."""
emb = np.array(embeddings)
norm_emb = emb / np.linalg.norm(emb)
return norm_emb.tolist()
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a Bedrock model.
@ -157,7 +167,12 @@ class BedrockEmbeddings(BaseModel, Embeddings):
results = []
for text in texts:
response = self._embedding_func(text)
if self.normalize:
response = self._normalize_vector(response)
results.append(response)
return results
def embed_query(self, text: str) -> List[float]:
@ -169,7 +184,12 @@ class BedrockEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
return self._embedding_func(text)
embedding = self._embedding_func(text)
if self.normalize:
return self._normalize_vector(embedding)
return embedding
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous compute query embeddings using a Bedrock model.