mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 16:01:33 +00:00
parameterized distance metrics; lint; format; tests (#4375)
# Parameterize Redis vectorstore index Redis vectorstore allows for three different distance metrics: `L2` (flat L2), `COSINE`, and `IP` (inner product). Currently, the `Redis._create_index` method hard codes the distance metric to COSINE. I've parameterized this as an argument in the `Redis.from_texts` method -- pretty simple. Fixes #4368 ## Before submitting I've added an integration test showing indexes can be instantiated with all three values in the `REDIS_DISTANCE_METRICS` literal. An example notebook seemed overkill here. Normal API documentation would be more appropriate, but no standards are in place for that yet. ## Who can review? Not sure who's responsible for the vectorstore module... Maybe @eyurtsev / @hwchase17 / @agola11 ?
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
"""Test Redis functionality."""
|
||||
import pytest
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.redis import Redis
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
@@ -7,6 +9,9 @@ TEST_INDEX_NAME = "test"
|
||||
TEST_REDIS_URL = "redis://localhost:6379"
|
||||
TEST_SINGLE_RESULT = [Document(page_content="foo")]
|
||||
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
|
||||
COSINE_SCORE = pytest.approx(0.05, abs=0.002)
|
||||
IP_SCORE = -8.0
|
||||
EUCLIDEAN_SCORE = 1.0
|
||||
|
||||
|
||||
def drop(index_name: str) -> bool:
|
||||
@@ -58,3 +63,42 @@ def test_redis_add_texts_to_existing() -> None:
|
||||
output = docsearch.similarity_search("foo", k=2)
|
||||
assert output == TEST_RESULT
|
||||
assert drop(TEST_INDEX_NAME)
|
||||
|
||||
|
||||
def test_cosine() -> None:
|
||||
"""Test cosine distance."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = Redis.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
redis_url=TEST_REDIS_URL,
|
||||
distance_metric="COSINE",
|
||||
)
|
||||
output = docsearch.similarity_search_with_score("far", k=2)
|
||||
_, score = output[1]
|
||||
assert score == COSINE_SCORE
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_l2() -> None:
|
||||
"""Test Flat L2 distance."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = Redis.from_texts(
|
||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="L2"
|
||||
)
|
||||
output = docsearch.similarity_search_with_score("far", k=2)
|
||||
_, score = output[1]
|
||||
assert score == EUCLIDEAN_SCORE
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_ip() -> None:
|
||||
"""Test inner product distance."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = Redis.from_texts(
|
||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="IP"
|
||||
)
|
||||
output = docsearch.similarity_search_with_score("far", k=2)
|
||||
_, score = output[1]
|
||||
assert score == IP_SCORE
|
||||
assert drop(docsearch.index_name)
|
||||
|
Reference in New Issue
Block a user