mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 19:11:33 +00:00
parent
53b14de636
commit
b2564a6391
@ -13,7 +13,10 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
|||||||
X = np.array(X)
|
X = np.array(X)
|
||||||
Y = np.array(Y)
|
Y = np.array(Y)
|
||||||
if X.shape[1] != Y.shape[1]:
|
if X.shape[1] != Y.shape[1]:
|
||||||
raise ValueError("Number of columns in X and Y must be the same.")
|
raise ValueError(
|
||||||
|
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
|
||||||
|
f"and Y has shape {Y.shape}."
|
||||||
|
)
|
||||||
|
|
||||||
X_norm = np.linalg.norm(X, axis=1)
|
X_norm = np.linalg.norm(X, axis=1)
|
||||||
Y_norm = np.linalg.norm(Y, axis=1)
|
Y_norm = np.linalg.norm(Y, axis=1)
|
||||||
|
@ -16,7 +16,9 @@ def maximal_marginal_relevance(
|
|||||||
"""Calculate maximal marginal relevance."""
|
"""Calculate maximal marginal relevance."""
|
||||||
if min(k, len(embedding_list)) <= 0:
|
if min(k, len(embedding_list)) <= 0:
|
||||||
return []
|
return []
|
||||||
similarity_to_query = cosine_similarity([query_embedding], embedding_list)[0]
|
if query_embedding.ndim == 1:
|
||||||
|
query_embedding = np.expand_dims(query_embedding, axis=0)
|
||||||
|
similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
|
||||||
most_similar = int(np.argmax(similarity_to_query))
|
most_similar = int(np.argmax(similarity_to_query))
|
||||||
idxs = [most_similar]
|
idxs = [most_similar]
|
||||||
selected = np.array([embedding_list[most_similar]])
|
selected = np.array([embedding_list[most_similar]])
|
||||||
|
0
tests/unit_tests/vectorstores/__init__.py
Normal file
0
tests/unit_tests/vectorstores/__init__.py
Normal file
54
tests/unit_tests/vectorstores/test_utils.py
Normal file
54
tests/unit_tests/vectorstores/test_utils.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
"""Test vector store utility functions."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
|
|
||||||
|
def test_maximal_marginal_relevance_lambda_zero() -> None:
|
||||||
|
query_embedding = np.random.random(size=5)
|
||||||
|
embedding_list = [query_embedding, query_embedding, np.zeros(5)]
|
||||||
|
expected = [0, 2]
|
||||||
|
actual = maximal_marginal_relevance(
|
||||||
|
query_embedding, embedding_list, lambda_mult=0, k=2
|
||||||
|
)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_maximal_marginal_relevance_lambda_one() -> None:
|
||||||
|
query_embedding = np.random.random(size=5)
|
||||||
|
embedding_list = [query_embedding, query_embedding, np.zeros(5)]
|
||||||
|
expected = [0, 1]
|
||||||
|
actual = maximal_marginal_relevance(
|
||||||
|
query_embedding, embedding_list, lambda_mult=1, k=2
|
||||||
|
)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_maximal_marginal_relevance() -> None:
|
||||||
|
query_embedding = np.array([1, 0])
|
||||||
|
# Vectors that are 30, 45 and 75 degrees from query vector (cosine similarity of
|
||||||
|
# 0.87, 0.71, 0.26) and the latter two are 15 and 60 degree from the first
|
||||||
|
# (cosine similarity 0.97 and 0.71). So for 3rd vector be chosen, must be case that
|
||||||
|
# 0.71lambda - 0.97(1 - lambda) < 0.26lambda - 0.71(1-lambda)
|
||||||
|
# -> lambda ~< .26 / .71
|
||||||
|
embedding_list = [[3**0.5, 1], [1, 1], [1, 2 + (3**0.5)]]
|
||||||
|
expected = [0, 2]
|
||||||
|
actual = maximal_marginal_relevance(
|
||||||
|
query_embedding, embedding_list, lambda_mult=(25 / 71), k=2
|
||||||
|
)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
expected = [0, 1]
|
||||||
|
actual = maximal_marginal_relevance(
|
||||||
|
query_embedding, embedding_list, lambda_mult=(27 / 71), k=2
|
||||||
|
)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_maximal_marginal_relevance_query_dim() -> None:
|
||||||
|
query_embedding = np.random.random(size=5)
|
||||||
|
query_embedding_2d = query_embedding.reshape((1, 5))
|
||||||
|
embedding_list = np.random.random(size=(4, 5)).tolist()
|
||||||
|
first = maximal_marginal_relevance(query_embedding, embedding_list)
|
||||||
|
second = maximal_marginal_relevance(query_embedding_2d, embedding_list)
|
||||||
|
assert first == second
|
Loading…
Reference in New Issue
Block a user