This commit is contained in:
Eugene Yurtsev
2024-04-05 15:09:58 -04:00
parent 0486469cf3
commit 4e2350ae7d
2 changed files with 15 additions and 179 deletions

View File

@@ -20,8 +20,8 @@ from typing import (
import sqlalchemy
from langchain_core._api import warn_deprecated
from sqlalchemy import SQLColumnExpression, delete, func
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
from sqlalchemy import SQLColumnExpression, cast, delete, func
from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID
from sqlalchemy.orm import Session, relationship
try:
@@ -323,9 +323,9 @@ class PGVector(VectorStore):
self.create_tables_if_not_exists()
self.create_collection()
def __del__(self) -> None:
    """Close the bound connection, if there is one, when the store is finalized.

    Only a bind that is a ``sqlalchemy.engine.Connection`` is closed; any
    other kind of bind is left untouched.
    """
    # __del__ is also invoked for instances whose __init__ raised before
    # `_bind` was assigned, so look the attribute up defensively instead of
    # raising AttributeError from inside the finalizer.
    bind = getattr(self, "_bind", None)
    if isinstance(bind, sqlalchemy.engine.Connection):
        bind.close()
# def __del__(self) -> None:
# if isinstance(self._bind, sqlalchemy.engine.Connection):
# self._bind.close()
@property
def embeddings(self) -> Embeddings:
@@ -669,8 +669,8 @@ class PGVector(VectorStore):
native = COMPARISONS_TO_NATIVE[operator]
return func.jsonb_path_match(
self.EmbeddingStore.cmetadata,
f"$.{field} {native} $value",
json.dumps({"value": filter_value}),
cast(f"$.{field} {native} $value", JSONPATH),
cast({"value": filter_value}, JSONB),
)
elif operator == "$between":
# Use AND with two comparisons
@@ -678,13 +678,13 @@ class PGVector(VectorStore):
lower_bound = func.jsonb_path_match(
self.EmbeddingStore.cmetadata,
f"$.{field} >= $value",
json.dumps({"value": low}),
cast(f"$.{field} >= $value", JSONPATH),
cast({"value": low}, JSONB),
)
upper_bound = func.jsonb_path_match(
self.EmbeddingStore.cmetadata,
f"$.{field} <= $value",
json.dumps({"value": high}),
cast(f"$.{field} <= $value", JSONPATH),
cast({"value": high}, JSONB),
)
return sqlalchemy.and_(lower_bound, upper_bound)
elif operator in {"$in", "$nin", "$like", "$ilike"}:

View File

@@ -4,13 +4,14 @@ from typing import Any, Dict, Generator, List, Type, Union
import pytest
import sqlalchemy
from langchain_core.documents import Document
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session
from langchain_postgres.vectorstores import (
SUPPORTED_OPERATORS,
PGVector,
)
from langchain_core.documents import Document
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session
from tests.integration_tests.fake_embeddings import FakeEmbeddings
from tests.integration_tests.fixtures.filtering_test_cases import (
DOCUMENTS,
@@ -321,34 +322,6 @@ def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
assert output == []
def test_pgvector_max_marginal_relevance_search() -> None:
    """Check that an MMR search for "foo" with k=1 returns only the "foo" document."""
    # Build a fresh collection (any pre-existing one is deleted first).
    store = PGVector.from_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    results = store.max_marginal_relevance_search("foo", k=1, fetch_k=3)
    expected = [Document(page_content="foo")]
    assert results == expected
def test_pgvector_max_marginal_relevance_search_with_score() -> None:
    """Check that a scored MMR search for "foo" with k=1 returns ("foo", 0.0)."""
    # Build a fresh collection (any pre-existing one is deleted first).
    store = PGVector.from_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    scored_results = store.max_marginal_relevance_search_with_score(
        "foo", k=1, fetch_k=3
    )
    expected = [(Document(page_content="foo"), 0.0)]
    assert scored_results == expected
def test_pgvector_with_custom_connection() -> None:
"""Test construction using a custom connection."""
texts = ["foo", "bar", "baz"]
@@ -485,143 +458,6 @@ def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None:
pgvector._create_filter_clause(invalid_filter)
@pytest.mark.parametrize(
    "filter,compiled",
    [
        # A field name carrying extra content is expected to raise ValueError.
        ({"id 'evil code'": 2}, ValueError),
        # Quotes inside the filter value are expected to come out escaped.
        (
            {"id": "'evil code' == 2"},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.id == $value', "
                "'{\"value\": \"''evil code'' == 2\"}')"
            ),
        ),
        (
            {"name": 'a"b'},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.name == $value', "
                '\'{"value": "a\\\\"b"}\')'
            ),
        ),
    ],
)
def test_evil_code(
    pgvector: PGVector, filter: Any, compiled: Union[Type[Exception], str]
) -> None:
    """Check how filters containing quote characters are handled.

    ``compiled`` is either the exact SQL text the filter clause must render
    to, or the exception type that building the clause must raise.
    """
    if not isinstance(compiled, str):
        # Expected outcome is an exception type rather than SQL text.
        with pytest.raises(compiled):
            pgvector._create_filter_clause(filter)
        return

    where_clause = pgvector._create_filter_clause(filter)
    # literal_binds substitutes the bound parameters with their actual values.
    rendered = where_clause.compile(
        dialect=postgresql.dialect(),
        compile_kwargs={"literal_binds": True},
    )
    assert str(rendered) == compiled
@pytest.mark.parametrize(
    "filter,compiled",
    [
        # Implicit and explicit equality
        (
            {"id": 2},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$eq": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"name": "foo"},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.name == $value', "
                '\'{"value": "foo"}\')'
            ),
        ),
        # Comparison operators
        (
            {"id": {"$ne": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id != $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$gt": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id > $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$gte": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id >= $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$lt": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id < $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$lte": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id <= $value', "
                "'{\"value\": 2}')"
            ),
        ),
        # Text pattern operators
        (
            {"name": {"$ilike": "foo"}},
            "langchain_pg_embedding.cmetadata ->> 'name' ILIKE 'foo'",
        ),
        (
            {"name": {"$like": "foo"}},
            "langchain_pg_embedding.cmetadata ->> 'name' LIKE 'foo'",
        ),
        # Logical OR.
        # Please note that this might not be super optimized.
        # Another way to phrase the query is as
        # langchain_pg_embedding.cmetadata @@ '($.id == 1 || $.id == 2)'
        (
            {"$or": [{"id": 1}, {"id": 2}]},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
                "'{\"value\": 1}') OR jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.id == $value', '{\"value\": 2}')"
            ),
        ),
    ],
)
def test_pgvector_query_compilation(
    pgvector: PGVector, filter: Any, compiled: str
) -> None:
    """Check that each filter renders to the expected PostgreSQL text."""
    where_clause = pgvector._create_filter_clause(filter)
    # literal_binds substitutes the bound parameters with their actual values.
    rendered = where_clause.compile(
        dialect=postgresql.dialect(),
        compile_kwargs={"literal_binds": True},
    )
    assert str(rendered) == compiled
def test_validate_operators() -> None:
"""Verify that all operators have been categorized."""
assert sorted(SUPPORTED_OPERATORS) == [