community[minor]: Revamp PGVector Filtering (#18992)

This PR makes the following updates in the pgvector database:

1. Use JSONB field for metadata instead of JSON
2. Update operator syntax to include required `$` prefix before the
operators (otherwise there will be name collisions with fields)
3. The change is non-breaking, old functionality is still the default,
but it will emit a deprecation warning
4. Previous functionality has bugs associated with comparisons due to
casting to text (so lexical ordering is used incorrectly for numeric
fields)
5. Adds a GIN index on the JSONB field for more efficient querying
This commit is contained in:
Eugene Yurtsev
2024-03-14 16:56:00 -04:00
committed by GitHub
parent e276817e1d
commit 6cdca4355d
3 changed files with 851 additions and 45 deletions

View File

@@ -0,0 +1,222 @@
"""Module contains test cases for testing filtering of documents in vector stores.
"""
from langchain_core.documents import Document
# Three sample records covering the value types a metadata filter may target:
# strings, ints, floats, booleans, lists, and nested dicts. Record 3 has a
# null "happiness" and deliberately omits "sadness" to exercise missing keys.
metadatas = [
    {
        "name": "adam",
        "date": "2021-01-01",
        "count": 1,
        "is_active": True,
        "tags": ["a", "b"],
        "location": [1.0, 2.0],
        "info": {"address": "123 main st", "phone": "123-456-7890"},
        "id": 1,
        # Float fields
        "height": 10.0,
        "happiness": 0.9,
        "sadness": 0.1,
    },
    {
        "name": "bob",
        "date": "2021-01-02",
        "count": 2,
        "is_active": False,
        "tags": ["b", "c"],
        "location": [2.0, 3.0],
        "info": {"address": "456 main st", "phone": "123-456-7890"},
        "id": 2,
        # Float fields
        "height": 5.7,
        "happiness": 0.8,
        "sadness": 0.1,
    },
    {
        "name": "jane",
        "date": "2021-01-01",
        "count": 3,
        "is_active": True,
        "tags": ["b", "d"],
        "location": [3.0, 4.0],
        "info": {"address": "789 main st", "phone": "123-456-7890"},
        "id": 3,
        # Float field; "sadness" intentionally absent
        "height": 2.4,
        "happiness": None,
    },
]
# Page content mirrors each record's id so search results can be identified.
texts = [f"id {metadata['id']}" for metadata in metadatas]
DOCUMENTS = [
    Document(page_content=text, metadata=metadata)
    for text, metadata in zip(texts, metadatas)
]
# Filters using only top-level equality; multiple keys AND together.
# Each entry is (filter, expected document ids).
TYPE_1_FILTERING_TEST_CASES = [
    ({"id": 1}, [1]),  # integer field
    ({"name": "adam"}, [1]),  # string field
    ({"is_active": True}, [1, 3]),  # boolean field
    ({"is_active": False}, [2]),
    ({"id": 1, "is_active": True}, [1]),  # AND semantics across keys
    ({"id": 1, "is_active": False}, []),
]
# Filters exercising comparison operators ($eq, $ne, $gt, $gte, $lt, $lte)
# against integer, string, boolean, and float fields.
TYPE_2_FILTERING_TEST_CASES = [
    # Integer column
    ({"id": 1}, [1]),
    ({"id": {"$ne": 1}}, [2, 3]),
    ({"id": {"$gt": 1}}, [2, 3]),
    ({"id": {"$gte": 1}}, [1, 2, 3]),
    ({"id": {"$lt": 1}}, []),
    ({"id": {"$lte": 1}}, [1]),
    # String column; ordering comparisons rely on lexicographical order
    ({"name": "adam"}, [1]),
    ({"name": "bob"}, [2]),
    ({"name": {"$eq": "adam"}}, [1]),
    ({"name": {"$ne": "adam"}}, [2, 3]),
    ({"name": {"$gt": "jane"}}, []),
    ({"name": {"$gte": "jane"}}, [3]),
    ({"name": {"$lt": "jane"}}, [1, 2]),
    ({"name": {"$lte": "jane"}}, [1, 2, 3]),
    # Boolean column
    ({"is_active": {"$eq": True}}, [1, 3]),
    ({"is_active": {"$ne": True}}, [2]),
    # Float column (heights: 10.0, 5.7, 2.4)
    ({"height": {"$gt": 5.0}}, [1, 2]),
    ({"height": {"$gte": 5.0}}, [1, 2]),
    ({"height": {"$lt": 5.0}}, [3]),
    ({"height": {"$lte": 5.8}}, [2, 3]),
]
# Filters combining clauses with the $and / $or logical operators.
TYPE_3_FILTERING_TEST_CASES = [
    ({"$or": [{"id": 1}, {"id": 2}]}, [1, 2]),
    ({"$or": [{"id": 1}, {"name": "bob"}]}, [1, 2]),
    ({"$and": [{"id": 1}, {"id": 2}]}, []),  # contradictory AND matches nothing
    ({"$or": [{"id": 1}, {"id": 2}, {"id": 3}]}, [1, 2, 3]),
]
# Filters using the membership/range operators ($in, $nin, $between).
TYPE_4_FILTERING_TEST_CASES = [
    # $between bounds appear inclusive on both ends per the expected ids
    ({"id": {"$between": (1, 2)}}, [1, 2]),
    ({"id": {"$between": (1, 1)}}, [1]),
    ({"name": {"$in": ["adam", "bob"]}}, [1, 2]),
]
# Filters using pattern operators ($like, $ilike) that only some
# databases support.
TYPE_5_FILTERING_TEST_CASES = [
    ({"name": {"$like": "a%"}}, [1]),  # prefix match: "adam"
    ({"name": {"$like": "%a%"}}, [1, 3]),  # substring match: "adam" and "jane"
]

View File

@@ -1,13 +1,26 @@
"""Test PGVector functionality."""
import os
from typing import List
from typing import Any, Dict, Generator, List, Type, Union
import pytest
import sqlalchemy
from langchain_core.documents import Document
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session
from langchain_community.vectorstores.pgvector import PGVector
from langchain_community.vectorstores.pgvector import (
SUPPORTED_OPERATORS,
PGVector,
)
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
DOCUMENTS,
TYPE_1_FILTERING_TEST_CASES,
TYPE_2_FILTERING_TEST_CASES,
TYPE_3_FILTERING_TEST_CASES,
TYPE_4_FILTERING_TEST_CASES,
TYPE_5_FILTERING_TEST_CASES,
)
# The connection string matches the default settings in the docker-compose file
# located in the root of the repository: [root]/docker/docker-compose.yml
@@ -42,7 +55,7 @@ class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
def test_pgvector() -> None:
def test_pgvector(pgvector: PGVector) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = PGVector.from_texts(
@@ -375,3 +388,255 @@ def test_pgvector_with_custom_engine_args() -> None:
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
# We should reuse this test-case across other integrations
# Add database fixture using pytest
@pytest.fixture
def pgvector() -> Generator[PGVector, None, None]:
    """Yield a PGVector store pre-populated with the shared test documents.

    The collection is recreated for each test (``pre_delete_collection=True``)
    and its tables are dropped afterwards, so tests never see leftover state.
    """
    vectorstore = PGVector.from_documents(
        documents=DOCUMENTS,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
        relevance_score_fn=lambda score: score * 0,
        use_jsonb=True,
    )
    try:
        yield vectorstore
    finally:
        # Always clean up, regardless of how the test finished.
        vectorstore.drop_tables()
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_1(
    pgvector: PGVector,
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Equality-only filters must return exactly the expected ids, in order."""
    results = pgvector.similarity_search("meow", k=5, filter=test_filter)
    returned_ids = [doc.metadata["id"] for doc in results]
    assert returned_ids == expected_ids, test_filter
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_2(
    pgvector: PGVector,
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Comparison-operator filters must return exactly the expected ids."""
    results = pgvector.similarity_search("meow", k=5, filter=test_filter)
    returned_ids = [doc.metadata["id"] for doc in results]
    assert returned_ids == expected_ids, test_filter
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_3(
    pgvector: PGVector,
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """$and / $or filters must return exactly the expected ids."""
    results = pgvector.similarity_search("meow", k=5, filter=test_filter)
    returned_ids = [doc.metadata["id"] for doc in results]
    assert returned_ids == expected_ids, test_filter
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_4(
    pgvector: PGVector,
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """$in / $nin / $between filters must return exactly the expected ids."""
    results = pgvector.similarity_search("meow", k=5, filter=test_filter)
    returned_ids = [doc.metadata["id"] for doc in results]
    assert returned_ids == expected_ids, test_filter
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_5(
    pgvector: PGVector,
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """$like / $ilike filters must return exactly the expected ids."""
    results = pgvector.similarity_search("meow", k=5, filter=test_filter)
    returned_ids = [doc.metadata["id"] for doc in results]
    assert returned_ids == expected_ids, test_filter
@pytest.mark.parametrize(
    "invalid_filter",
    [
        ["hello"],  # top-level filter must be a dict, not a list
        {
            "id": 2,
            "$name": "foo",  # "$"-prefixed keys collide with operator names
        },
        {"$or": {}},  # logical operators expect a list of clauses
        {"$and": {}},
        {"$between": {}},  # comparison operators cannot appear at top level
        {"$eq": {}},
    ],
)
def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None:
    """Malformed filters must be rejected with a ValueError."""
    with pytest.raises(ValueError):
        pgvector._create_filter_clause(invalid_filter)
# Injection-safety cases: each entry pairs a hostile filter with either the
# exact SQL it must compile to, or the exception type it must raise.
@pytest.mark.parametrize(
    "filter,compiled",
    [
        # A field name containing quotes/spaces is rejected outright rather
        # than being interpolated into a JSONPath expression.
        ({"id 'evil code'": 2}, ValueError),
        # Values travel as a bound JSON document ($value), so quoting inside
        # the value cannot escape the jsonb_path_match expression.
        (
            {"id": "'evil code' == 2"},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.id == $value', "
                "'{\"value\": \"''evil code'' == 2\"}')"
            ),
        ),
        # Embedded double quotes in values are escaped inside the JSON blob.
        (
            {"name": 'a"b'},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.name == $value', "
                '\'{"value": "a\\\\"b"}\')'
            ),
        ),
    ],
)
def test_evil_code(
    pgvector: PGVector, filter: Any, compiled: Union[Type[Exception], str]
) -> None:
    """Test evil code."""
    # When `compiled` is a string, compare the literal-bound SQL text;
    # otherwise expect _create_filter_clause to raise that exception type.
    if isinstance(compiled, str):
        clause = pgvector._create_filter_clause(filter)
        compiled_stmt = str(
            clause.compile(
                dialect=postgresql.dialect(),
                compile_kwargs={
                    # This substitutes the parameters with their actual values
                    "literal_binds": True
                },
            )
        )
        assert compiled_stmt == compiled
    else:
        with pytest.raises(compiled):
            pgvector._create_filter_clause(filter)
# Each case maps a filter dict to the exact SQL produced after literal
# binding; comparisons compile to jsonb_path_match so values keep their
# JSONB types instead of being cast to text.
@pytest.mark.parametrize(
    "filter,compiled",
    [
        # Implicit equality on a numeric field.
        (
            {"id": 2},
            "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
            "'{\"value\": 2}')",
        ),
        # Explicit $eq compiles identically to the implicit form above.
        (
            {"id": {"$eq": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
                "'{\"value\": 2}')"
            ),
        ),
        # Implicit equality on a string field.
        (
            {"name": "foo"},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.name == $value', "
                '\'{"value": "foo"}\')'
            ),
        ),
        # Comparison operators map to their JSONPath counterparts.
        (
            {"id": {"$ne": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id != $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$gt": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id > $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$gte": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id >= $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$lt": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id < $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$lte": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id <= $value', "
                "'{\"value\": 2}')"
            ),
        ),
        # Pattern operators use the ->> text extractor rather than JSONPath.
        (
            {"name": {"$ilike": "foo"}},
            "langchain_pg_embedding.cmetadata ->> 'name' ILIKE 'foo'",
        ),
        (
            {"name": {"$like": "foo"}},
            "langchain_pg_embedding.cmetadata ->> 'name' LIKE 'foo'",
        ),
        # $or becomes SQL OR over the individually compiled clauses.
        (
            {"$or": [{"id": 1}, {"id": 2}]},
            # Please note that this might not be super optimized
            # Another way to phrase the query is as
            # langchain_pg_embedding.cmetadata @@ '($.id == 1 || $.id == 2)'
            "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
            "'{\"value\": 1}') OR jsonb_path_match(langchain_pg_embedding.cmetadata, "
            "'$.id == $value', '{\"value\": 2}')",
        ),
    ],
)
def test_pgvector_query_compilation(
    pgvector: PGVector, filter: Any, compiled: str
) -> None:
    """Test translation from IR to SQL"""
    # Compile with literal_binds so expected SQL can be compared as text.
    clause = pgvector._create_filter_clause(filter)
    compiled_stmt = str(
        clause.compile(
            dialect=postgresql.dialect(),
            compile_kwargs={
                # This substitutes the parameters with their actual values
                "literal_binds": True
            },
        )
    )
    assert compiled_stmt == compiled
def test_validate_operators() -> None:
    """Verify that all operators have been categorized."""
    # Sorting makes the comparison order-independent while still detecting
    # duplicates or missing operators.
    expected_operators = [
        "$and",
        "$between",
        "$eq",
        "$gt",
        "$gte",
        "$ilike",
        "$in",
        "$like",
        "$lt",
        "$lte",
        "$ne",
        "$nin",
        "$or",
    ]
    assert sorted(SUPPORTED_OPERATORS) == expected_operators