mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 02:06:33 +00:00
community: normalize bedrock embeddings (#15103)
In this PR I added a post-processing function to normalize the embeddings. This happens only if the new `normalize` flag is `True`. --------- Co-authored-by: taamedag <Davide.Menini@swisscom.com>
This commit is contained in:
parent
20fcd49348
commit
9ce177580a
@ -3,6 +3,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||||
from langchain_core.runnables.config import run_in_executor
|
from langchain_core.runnables.config import run_in_executor
|
||||||
@ -64,6 +65,9 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
|||||||
endpoint_url: Optional[str] = None
|
endpoint_url: Optional[str] = None
|
||||||
"""Needed if you don't want to default to us-east-1 endpoint"""
|
"""Needed if you don't want to default to us-east-1 endpoint"""
|
||||||
|
|
||||||
|
normalize: bool = False
|
||||||
|
"""Whether the embeddings should be normalized to unit vectors"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
@ -145,6 +149,12 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||||
|
|
||||||
|
def _normalize_vector(self, embeddings: List[float]) -> List[float]:
|
||||||
|
"""Normalize the embedding to a unit vector."""
|
||||||
|
emb = np.array(embeddings)
|
||||||
|
norm_emb = emb / np.linalg.norm(emb)
|
||||||
|
return norm_emb.tolist()
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Compute doc embeddings using a Bedrock model.
|
"""Compute doc embeddings using a Bedrock model.
|
||||||
|
|
||||||
@ -157,7 +167,12 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
|||||||
results = []
|
results = []
|
||||||
for text in texts:
|
for text in texts:
|
||||||
response = self._embedding_func(text)
|
response = self._embedding_func(text)
|
||||||
|
|
||||||
|
if self.normalize:
|
||||||
|
response = self._normalize_vector(response)
|
||||||
|
|
||||||
results.append(response)
|
results.append(response)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
@ -169,7 +184,12 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
return self._embedding_func(text)
|
embedding = self._embedding_func(text)
|
||||||
|
|
||||||
|
if self.normalize:
|
||||||
|
return self._normalize_vector(embedding)
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
async def aembed_query(self, text: str) -> List[float]:
|
async def aembed_query(self, text: str) -> List[float]:
|
||||||
"""Asynchronous compute query embeddings using a Bedrock model.
|
"""Asynchronous compute query embeddings using a Bedrock model.
|
||||||
|
Loading…
Reference in New Issue
Block a user