mirror of
https://github.com/hwchase17/langchain.git
synced 2026-05-05 11:12:11 +00:00
x
This commit is contained in:
@@ -20,8 +20,8 @@ from typing import (
|
||||
|
||||
import sqlalchemy
|
||||
from langchain_core._api import warn_deprecated
|
||||
from sqlalchemy import SQLColumnExpression, delete, func
|
||||
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
|
||||
from sqlalchemy import SQLColumnExpression, cast, delete, func
|
||||
from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
|
||||
try:
|
||||
@@ -323,9 +323,9 @@ class PGVector(VectorStore):
|
||||
self.create_tables_if_not_exists()
|
||||
self.create_collection()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if isinstance(self._bind, sqlalchemy.engine.Connection):
|
||||
self._bind.close()
|
||||
# def __del__(self) -> None:
|
||||
# if isinstance(self._bind, sqlalchemy.engine.Connection):
|
||||
# self._bind.close()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
@@ -669,8 +669,8 @@ class PGVector(VectorStore):
|
||||
native = COMPARISONS_TO_NATIVE[operator]
|
||||
return func.jsonb_path_match(
|
||||
self.EmbeddingStore.cmetadata,
|
||||
f"$.{field} {native} $value",
|
||||
json.dumps({"value": filter_value}),
|
||||
cast(f"$.{field} {native} $value", JSONPATH),
|
||||
cast({"value": filter_value}, JSONB),
|
||||
)
|
||||
elif operator == "$between":
|
||||
# Use AND with two comparisons
|
||||
@@ -678,13 +678,13 @@ class PGVector(VectorStore):
|
||||
|
||||
lower_bound = func.jsonb_path_match(
|
||||
self.EmbeddingStore.cmetadata,
|
||||
f"$.{field} >= $value",
|
||||
json.dumps({"value": low}),
|
||||
cast(f"$.{field} >= $value", JSONPATH),
|
||||
cast({"value": low}, JSONB),
|
||||
)
|
||||
upper_bound = func.jsonb_path_match(
|
||||
self.EmbeddingStore.cmetadata,
|
||||
f"$.{field} <= $value",
|
||||
json.dumps({"value": high}),
|
||||
cast(f"$.{field} <= $value", JSONPATH),
|
||||
cast({"value": high}, JSONB),
|
||||
)
|
||||
return sqlalchemy.and_(lower_bound, upper_bound)
|
||||
elif operator in {"$in", "$nin", "$like", "$ilike"}:
|
||||
|
||||
@@ -4,13 +4,14 @@ from typing import Any, Dict, Generator, List, Type, Union
|
||||
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from langchain_core.documents import Document
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langchain_postgres.vectorstores import (
|
||||
SUPPORTED_OPERATORS,
|
||||
PGVector,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Session
|
||||
from tests.integration_tests.fake_embeddings import FakeEmbeddings
|
||||
from tests.integration_tests.fixtures.filtering_test_cases import (
|
||||
DOCUMENTS,
|
||||
@@ -321,34 +322,6 @@ def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
|
||||
assert output == []
|
||||
|
||||
|
||||
def test_pgvector_max_marginal_relevance_search() -> None:
|
||||
"""Test max marginal relevance search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = PGVector.from_texts(
|
||||
texts=texts,
|
||||
collection_name="test_collection",
|
||||
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
def test_pgvector_max_marginal_relevance_search_with_score() -> None:
|
||||
"""Test max marginal relevance search with relevance scores."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = PGVector.from_texts(
|
||||
texts=texts,
|
||||
collection_name="test_collection",
|
||||
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
|
||||
assert output == [(Document(page_content="foo"), 0.0)]
|
||||
|
||||
|
||||
def test_pgvector_with_custom_connection() -> None:
|
||||
"""Test construction using a custom connection."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@@ -485,143 +458,6 @@ def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None:
|
||||
pgvector._create_filter_clause(invalid_filter)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filter,compiled",
|
||||
[
|
||||
({"id 'evil code'": 2}, ValueError),
|
||||
(
|
||||
{"id": "'evil code' == 2"},
|
||||
(
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, "
|
||||
"'$.id == $value', "
|
||||
"'{\"value\": \"''evil code'' == 2\"}')"
|
||||
),
|
||||
),
|
||||
(
|
||||
{"name": 'a"b'},
|
||||
(
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, "
|
||||
"'$.name == $value', "
|
||||
'\'{"value": "a\\\\"b"}\')'
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_evil_code(
|
||||
pgvector: PGVector, filter: Any, compiled: Union[Type[Exception], str]
|
||||
) -> None:
|
||||
"""Test evil code."""
|
||||
if isinstance(compiled, str):
|
||||
clause = pgvector._create_filter_clause(filter)
|
||||
compiled_stmt = str(
|
||||
clause.compile(
|
||||
dialect=postgresql.dialect(),
|
||||
compile_kwargs={
|
||||
# This substitutes the parameters with their actual values
|
||||
"literal_binds": True
|
||||
},
|
||||
)
|
||||
)
|
||||
assert compiled_stmt == compiled
|
||||
else:
|
||||
with pytest.raises(compiled):
|
||||
pgvector._create_filter_clause(filter)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filter,compiled",
|
||||
[
|
||||
(
|
||||
{"id": 2},
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
|
||||
"'{\"value\": 2}')",
|
||||
),
|
||||
(
|
||||
{"id": {"$eq": 2}},
|
||||
(
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
|
||||
"'{\"value\": 2}')"
|
||||
),
|
||||
),
|
||||
(
|
||||
{"name": "foo"},
|
||||
(
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, "
|
||||
"'$.name == $value', "
|
||||
'\'{"value": "foo"}\')'
|
||||
),
|
||||
),
|
||||
(
|
||||
{"id": {"$ne": 2}},
|
||||
(
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id != $value', "
|
||||
"'{\"value\": 2}')"
|
||||
),
|
||||
),
|
||||
(
|
||||
{"id": {"$gt": 2}},
|
||||
(
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id > $value', "
|
||||
"'{\"value\": 2}')"
|
||||
),
|
||||
),
|
||||
(
|
||||
{"id": {"$gte": 2}},
|
||||
(
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id >= $value', "
|
||||
"'{\"value\": 2}')"
|
||||
),
|
||||
),
|
||||
(
|
||||
{"id": {"$lt": 2}},
|
||||
(
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id < $value', "
|
||||
"'{\"value\": 2}')"
|
||||
),
|
||||
),
|
||||
(
|
||||
{"id": {"$lte": 2}},
|
||||
(
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id <= $value', "
|
||||
"'{\"value\": 2}')"
|
||||
),
|
||||
),
|
||||
(
|
||||
{"name": {"$ilike": "foo"}},
|
||||
"langchain_pg_embedding.cmetadata ->> 'name' ILIKE 'foo'",
|
||||
),
|
||||
(
|
||||
{"name": {"$like": "foo"}},
|
||||
"langchain_pg_embedding.cmetadata ->> 'name' LIKE 'foo'",
|
||||
),
|
||||
(
|
||||
{"$or": [{"id": 1}, {"id": 2}]},
|
||||
# Please note that this might not be super optimized
|
||||
# Another way to phrase the query is as
|
||||
# langchain_pg_embedding.cmetadata @@ '($.id == 1 || $.id == 2)'
|
||||
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
|
||||
"'{\"value\": 1}') OR jsonb_path_match(langchain_pg_embedding.cmetadata, "
|
||||
"'$.id == $value', '{\"value\": 2}')",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_pgvector_query_compilation(
|
||||
pgvector: PGVector, filter: Any, compiled: str
|
||||
) -> None:
|
||||
"""Test translation from IR to SQL"""
|
||||
clause = pgvector._create_filter_clause(filter)
|
||||
compiled_stmt = str(
|
||||
clause.compile(
|
||||
dialect=postgresql.dialect(),
|
||||
compile_kwargs={
|
||||
# This substitutes the parameters with their actual values
|
||||
"literal_binds": True
|
||||
},
|
||||
)
|
||||
)
|
||||
assert compiled_stmt == compiled
|
||||
|
||||
|
||||
def test_validate_operators() -> None:
|
||||
"""Verify that all operators have been categorized."""
|
||||
assert sorted(SUPPORTED_OPERATORS) == [
|
||||
|
||||
Reference in New Issue
Block a user