mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
community: normalize bedrock embeddings (#15103)
This PR adds a post-processing function that normalizes the embeddings to unit vectors. Normalization is applied only when the new `normalize` flag is `True`. --------- Co-authored-by: taamedag <Davide.Menini@swisscom.com>
This commit is contained in:
parent
20fcd49348
commit
9ce177580a
@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
@ -64,6 +65,9 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
||||
endpoint_url: Optional[str] = None
|
||||
"""Needed if you don't want to default to us-east-1 endpoint"""
|
||||
|
||||
normalize: bool = False
|
||||
"""Whether the embeddings should be normalized to unit vectors"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@ -145,6 +149,12 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
def _normalize_vector(self, embeddings: List[float]) -> List[float]:
|
||||
"""Normalize the embedding to a unit vector."""
|
||||
emb = np.array(embeddings)
|
||||
norm_emb = emb / np.linalg.norm(emb)
|
||||
return norm_emb.tolist()
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a Bedrock model.
|
||||
|
||||
@ -157,7 +167,12 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
||||
results = []
|
||||
for text in texts:
|
||||
response = self._embedding_func(text)
|
||||
|
||||
if self.normalize:
|
||||
response = self._normalize_vector(response)
|
||||
|
||||
results.append(response)
|
||||
|
||||
return results
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
@ -169,7 +184,12 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self._embedding_func(text)
|
||||
embedding = self._embedding_func(text)
|
||||
|
||||
if self.normalize:
|
||||
return self._normalize_vector(embedding)
|
||||
|
||||
return embedding
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronous compute query embeddings using a Bedrock model.
|
||||
|
Loading…
Reference in New Issue
Block a user