mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
community[minor]: Revamp PGVector Filtering (#18992)
This PR makes the following updates in the pgvector database: 1. Use JSONB field for metadata instead of JSON 2. Update operator syntax to include required `$` prefix before the operators (otherwise there will be name collisions with fields) 3. The change is non-breaking, old functionality is still the default, but it will emit a deprecation warning 4. Previous functionality has bugs associated with comparisons due to casting to text (so lexical ordering is used incorrectly for numeric fields) 5. Adds an a GIN index on the JSONB field for more efficient querying
This commit is contained in:
parent
e276817e1d
commit
6cdca4355d
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import enum
|
import enum
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -18,8 +19,9 @@ from typing import (
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import delete
|
from langchain_core._api import warn_deprecated
|
||||||
from sqlalchemy.dialects.postgresql import JSON, UUID
|
from sqlalchemy import SQLColumnExpression, delete, func
|
||||||
|
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
|
||||||
from sqlalchemy.orm import Session, relationship
|
from sqlalchemy.orm import Session, relationship
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -61,8 +63,39 @@ class BaseModel(Base):
|
|||||||
|
|
||||||
_classes: Any = None
|
_classes: Any = None
|
||||||
|
|
||||||
|
COMPARISONS_TO_NATIVE = {
|
||||||
|
"$eq": "==",
|
||||||
|
"$ne": "!=",
|
||||||
|
"$lt": "<",
|
||||||
|
"$lte": "<=",
|
||||||
|
"$gt": ">",
|
||||||
|
"$gte": ">=",
|
||||||
|
}
|
||||||
|
|
||||||
def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any:
|
SPECIAL_CASED_OPERATORS = {
|
||||||
|
"$in",
|
||||||
|
"$nin",
|
||||||
|
"$between",
|
||||||
|
}
|
||||||
|
|
||||||
|
TEXT_OPERATORS = {
|
||||||
|
"$like",
|
||||||
|
"$ilike",
|
||||||
|
}
|
||||||
|
|
||||||
|
LOGICAL_OPERATORS = {"$and", "$or"}
|
||||||
|
|
||||||
|
SUPPORTED_OPERATORS = (
|
||||||
|
set(COMPARISONS_TO_NATIVE)
|
||||||
|
.union(TEXT_OPERATORS)
|
||||||
|
.union(LOGICAL_OPERATORS)
|
||||||
|
.union(SPECIAL_CASED_OPERATORS)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_embedding_collection_store(
|
||||||
|
vector_dimension: Optional[int] = None, *, use_jsonb: bool = True
|
||||||
|
) -> Any:
|
||||||
global _classes
|
global _classes
|
||||||
if _classes is not None:
|
if _classes is not None:
|
||||||
return _classes
|
return _classes
|
||||||
@ -111,26 +144,60 @@ def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> A
|
|||||||
created = True
|
created = True
|
||||||
return collection, created
|
return collection, created
|
||||||
|
|
||||||
class EmbeddingStore(BaseModel):
|
if use_jsonb:
|
||||||
"""Embedding store."""
|
# TODO(PRIOR TO LANDING): Create a gin index on the cmetadata field
|
||||||
|
class EmbeddingStore(BaseModel):
|
||||||
|
"""Embedding store."""
|
||||||
|
|
||||||
__tablename__ = "langchain_pg_embedding"
|
__tablename__ = "langchain_pg_embedding"
|
||||||
|
|
||||||
collection_id = sqlalchemy.Column(
|
collection_id = sqlalchemy.Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
sqlalchemy.ForeignKey(
|
sqlalchemy.ForeignKey(
|
||||||
f"{CollectionStore.__tablename__}.uuid",
|
f"{CollectionStore.__tablename__}.uuid",
|
||||||
ondelete="CASCADE",
|
ondelete="CASCADE",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
collection = relationship(CollectionStore, back_populates="embeddings")
|
collection = relationship(CollectionStore, back_populates="embeddings")
|
||||||
|
|
||||||
embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
|
embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
|
||||||
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
cmetadata = sqlalchemy.Column(JSONB, nullable=True)
|
||||||
|
|
||||||
# custom_id : any user defined id
|
# custom_id : any user defined id
|
||||||
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
sqlalchemy.Index(
|
||||||
|
"ix_cmetadata_gin",
|
||||||
|
"cmetadata",
|
||||||
|
postgresql_using="gin",
|
||||||
|
postgresql_ops={"cmetadata": "jsonb_path_ops"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For backwards comaptibilty with older versions of pgvector
|
||||||
|
# This should be removed in the future (remove during migration)
|
||||||
|
class EmbeddingStore(BaseModel): # type: ignore[no-redef]
|
||||||
|
"""Embedding store."""
|
||||||
|
|
||||||
|
__tablename__ = "langchain_pg_embedding"
|
||||||
|
|
||||||
|
collection_id = sqlalchemy.Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
sqlalchemy.ForeignKey(
|
||||||
|
f"{CollectionStore.__tablename__}.uuid",
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
collection = relationship(CollectionStore, back_populates="embeddings")
|
||||||
|
|
||||||
|
embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
|
||||||
|
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
|
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
||||||
|
|
||||||
|
# custom_id : any user defined id
|
||||||
|
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
|
|
||||||
_classes = (EmbeddingStore, CollectionStore)
|
_classes = (EmbeddingStore, CollectionStore)
|
||||||
|
|
||||||
@ -163,6 +230,11 @@ class PGVector(VectorStore):
|
|||||||
pre_delete_collection: If True, will delete the collection if it exists.
|
pre_delete_collection: If True, will delete the collection if it exists.
|
||||||
(default: False). Useful for testing.
|
(default: False). Useful for testing.
|
||||||
engine_args: SQLAlchemy's create engine arguments.
|
engine_args: SQLAlchemy's create engine arguments.
|
||||||
|
use_jsonb: Use JSONB instead of JSON for metadata. (default: True)
|
||||||
|
Strongly discouraged from using JSON as it's not as efficient
|
||||||
|
for querying.
|
||||||
|
It's provided here for backwards compatibility with older versions,
|
||||||
|
and will be removed in the future.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -178,9 +250,8 @@ class PGVector(VectorStore):
|
|||||||
documents=docs,
|
documents=docs,
|
||||||
collection_name=COLLECTION_NAME,
|
collection_name=COLLECTION_NAME,
|
||||||
connection_string=CONNECTION_STRING,
|
connection_string=CONNECTION_STRING,
|
||||||
|
use_jsonb=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -197,7 +268,9 @@ class PGVector(VectorStore):
|
|||||||
*,
|
*,
|
||||||
connection: Optional[sqlalchemy.engine.Connection] = None,
|
connection: Optional[sqlalchemy.engine.Connection] = None,
|
||||||
engine_args: Optional[dict[str, Any]] = None,
|
engine_args: Optional[dict[str, Any]] = None,
|
||||||
|
use_jsonb: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the PGVector store."""
|
||||||
self.connection_string = connection_string
|
self.connection_string = connection_string
|
||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
self._embedding_length = embedding_length
|
self._embedding_length = embedding_length
|
||||||
@ -209,6 +282,29 @@ class PGVector(VectorStore):
|
|||||||
self.override_relevance_score_fn = relevance_score_fn
|
self.override_relevance_score_fn = relevance_score_fn
|
||||||
self.engine_args = engine_args or {}
|
self.engine_args = engine_args or {}
|
||||||
self._bind = connection if connection else self._create_engine()
|
self._bind = connection if connection else self._create_engine()
|
||||||
|
self.use_jsonb = use_jsonb
|
||||||
|
|
||||||
|
if not use_jsonb:
|
||||||
|
# Replace with a deprecation warning.
|
||||||
|
warn_deprecated(
|
||||||
|
"0.0.29",
|
||||||
|
pending=True,
|
||||||
|
message=(
|
||||||
|
"Please use JSONB instead of JSON for metadata. "
|
||||||
|
"This change will allow for more efficient querying that "
|
||||||
|
"involves filtering based on metadata."
|
||||||
|
"Please note that filtering operators have been changed "
|
||||||
|
"when using JSOB metadata to be prefixed with a $ sign "
|
||||||
|
"to avoid name collisions with columns. "
|
||||||
|
"If you're using an existing database, you will need to create a"
|
||||||
|
"db migration for your metadata column to be JSONB and update your "
|
||||||
|
"queries to use the new operators. "
|
||||||
|
),
|
||||||
|
alternative=(
|
||||||
|
"Instantiate with use_jsonb=True to use JSONB instead "
|
||||||
|
"of JSON for metadata."
|
||||||
|
),
|
||||||
|
)
|
||||||
self.__post_init__()
|
self.__post_init__()
|
||||||
|
|
||||||
def __post_init__(
|
def __post_init__(
|
||||||
@ -218,7 +314,7 @@ class PGVector(VectorStore):
|
|||||||
self.create_vector_extension()
|
self.create_vector_extension()
|
||||||
|
|
||||||
EmbeddingStore, CollectionStore = _get_embedding_collection_store(
|
EmbeddingStore, CollectionStore = _get_embedding_collection_store(
|
||||||
self._embedding_length
|
self._embedding_length, use_jsonb=self.use_jsonb
|
||||||
)
|
)
|
||||||
self.CollectionStore = CollectionStore
|
self.CollectionStore = CollectionStore
|
||||||
self.EmbeddingStore = EmbeddingStore
|
self.EmbeddingStore = EmbeddingStore
|
||||||
@ -336,6 +432,8 @@ class PGVector(VectorStore):
|
|||||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
||||||
connection_string: Optional[str] = None,
|
connection_string: Optional[str] = None,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
|
*,
|
||||||
|
use_jsonb: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> PGVector:
|
) -> PGVector:
|
||||||
if ids is None:
|
if ids is None:
|
||||||
@ -352,6 +450,7 @@ class PGVector(VectorStore):
|
|||||||
embedding_function=embedding,
|
embedding_function=embedding,
|
||||||
distance_strategy=distance_strategy,
|
distance_strategy=distance_strategy,
|
||||||
pre_delete_collection=pre_delete_collection,
|
pre_delete_collection=pre_delete_collection,
|
||||||
|
use_jsonb=use_jsonb,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -508,7 +607,117 @@ class PGVector(VectorStore):
|
|||||||
]
|
]
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def _create_filter_clause(self, key, value): # type: ignore[no-untyped-def]
|
def _handle_field_filter(
|
||||||
|
self,
|
||||||
|
field: str,
|
||||||
|
value: Any,
|
||||||
|
) -> SQLColumnExpression:
|
||||||
|
"""Create a filter for a specific field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: name of field
|
||||||
|
value: value to filter
|
||||||
|
If provided as is then this will be an equality filter
|
||||||
|
If provided as a dictionary then this will be a filter, the key
|
||||||
|
will be the operator and the value will be the value to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sqlalchemy expression
|
||||||
|
"""
|
||||||
|
if not isinstance(field, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"field should be a string but got: {type(field)} with value: {field}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if field.startswith("$"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid filter condition. Expected a field but got an operator: "
|
||||||
|
f"{field}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
|
||||||
|
if not field.isidentifier():
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid field name: {field}. Expected a valid identifier."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# This is a filter specification
|
||||||
|
if len(value) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid filter condition. Expected a value which "
|
||||||
|
"is a dictionary with a single key that corresponds to an operator "
|
||||||
|
f"but got a dictionary with {len(value)} keys. The first few "
|
||||||
|
f"keys are: {list(value.keys())[:3]}"
|
||||||
|
)
|
||||||
|
operator, filter_value = list(value.items())[0]
|
||||||
|
# Verify that that operator is an operator
|
||||||
|
if operator not in SUPPORTED_OPERATORS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid operator: {operator}. "
|
||||||
|
f"Expected one of {SUPPORTED_OPERATORS}"
|
||||||
|
)
|
||||||
|
else: # Then we assume an equality operator
|
||||||
|
operator = "$eq"
|
||||||
|
filter_value = value
|
||||||
|
|
||||||
|
if operator in COMPARISONS_TO_NATIVE:
|
||||||
|
# Then we implement an equality filter
|
||||||
|
# native is trusted input
|
||||||
|
native = COMPARISONS_TO_NATIVE[operator]
|
||||||
|
return func.jsonb_path_match(
|
||||||
|
self.EmbeddingStore.cmetadata,
|
||||||
|
f"$.{field} {native} $value",
|
||||||
|
json.dumps({"value": filter_value}),
|
||||||
|
)
|
||||||
|
elif operator == "$between":
|
||||||
|
# Use AND with two comparisons
|
||||||
|
low, high = filter_value
|
||||||
|
|
||||||
|
lower_bound = func.jsonb_path_match(
|
||||||
|
self.EmbeddingStore.cmetadata,
|
||||||
|
f"$.{field} >= $value",
|
||||||
|
json.dumps({"value": low}),
|
||||||
|
)
|
||||||
|
upper_bound = func.jsonb_path_match(
|
||||||
|
self.EmbeddingStore.cmetadata,
|
||||||
|
f"$.{field} <= $value",
|
||||||
|
json.dumps({"value": high}),
|
||||||
|
)
|
||||||
|
return sqlalchemy.and_(lower_bound, upper_bound)
|
||||||
|
elif operator in {"$in", "$nin", "$like", "$ilike"}:
|
||||||
|
# We'll do force coercion to text
|
||||||
|
if operator in {"$in", "$nin"}:
|
||||||
|
for val in filter_value:
|
||||||
|
if not isinstance(val, (str, int, float)):
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Unsupported type: {type(val)} for value: {val}"
|
||||||
|
)
|
||||||
|
|
||||||
|
queried_field = self.EmbeddingStore.cmetadata[field].astext
|
||||||
|
|
||||||
|
if operator in {"$in"}:
|
||||||
|
return queried_field.in_([str(val) for val in filter_value])
|
||||||
|
elif operator in {"$nin"}:
|
||||||
|
return queried_field.nin_([str(val) for val in filter_value])
|
||||||
|
elif operator in {"$like"}:
|
||||||
|
return queried_field.like(filter_value)
|
||||||
|
elif operator in {"$ilike"}:
|
||||||
|
return queried_field.ilike(filter_value)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def]
|
||||||
|
"""Deprecated functionality.
|
||||||
|
|
||||||
|
This is for backwards compatibility with the JSON based schema for metadata.
|
||||||
|
It uses incorrect operator syntax (operators are not prefixed with $).
|
||||||
|
|
||||||
|
This implementation is not efficient, and has bugs associated with
|
||||||
|
the way that it handles numeric filter clauses.
|
||||||
|
"""
|
||||||
IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
|
IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
|
||||||
EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"
|
EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"
|
||||||
|
|
||||||
@ -568,6 +777,117 @@ class PGVector(VectorStore):
|
|||||||
|
|
||||||
return filter_by_metadata
|
return filter_by_metadata
|
||||||
|
|
||||||
|
def _create_filter_clause_json_deprecated(
|
||||||
|
self, filter: Any
|
||||||
|
) -> List[SQLColumnExpression]:
|
||||||
|
"""Convert filters from IR to SQL clauses.
|
||||||
|
|
||||||
|
**DEPRECATED** This functionality will be deprecated in the future.
|
||||||
|
|
||||||
|
It implements translation of filters for a schema that uses JSON
|
||||||
|
for metadata rather than the JSONB field which is more efficient
|
||||||
|
for querying.
|
||||||
|
"""
|
||||||
|
filter_clauses = []
|
||||||
|
for key, value in filter.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
filter_by_metadata = self._create_filter_clause_deprecated(key, value)
|
||||||
|
|
||||||
|
if filter_by_metadata is not None:
|
||||||
|
filter_clauses.append(filter_by_metadata)
|
||||||
|
else:
|
||||||
|
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(
|
||||||
|
value
|
||||||
|
)
|
||||||
|
filter_clauses.append(filter_by_metadata)
|
||||||
|
return filter_clauses
|
||||||
|
|
||||||
|
def _create_filter_clause(self, filters: Any) -> Any:
|
||||||
|
"""Convert LangChain IR filter representation to matching SQLAlchemy clauses.
|
||||||
|
|
||||||
|
At the top level, we still don't know if we're working with a field
|
||||||
|
or an operator for the keys. After we've determined that we can
|
||||||
|
call the appropriate logic to handle filter creation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filters: Dictionary of filters to apply to the query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQLAlchemy clause to apply to the query.
|
||||||
|
"""
|
||||||
|
if isinstance(filters, dict):
|
||||||
|
if len(filters) == 1:
|
||||||
|
# The only operators allowed at the top level are $AND and $OR
|
||||||
|
# First check if an operator or a field
|
||||||
|
key, value = list(filters.items())[0]
|
||||||
|
if key.startswith("$"):
|
||||||
|
# Then it's an operator
|
||||||
|
if key.lower() not in ["$and", "$or"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid filter condition. Expected $and or $or "
|
||||||
|
f"but got: {key}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Then it's a field
|
||||||
|
return self._handle_field_filter(key, filters[key])
|
||||||
|
|
||||||
|
# Here we handle the $and and $or operators
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected a list, but got {type(value)} for value: {value}"
|
||||||
|
)
|
||||||
|
if key.lower() == "$and":
|
||||||
|
and_ = [self._create_filter_clause(el) for el in value]
|
||||||
|
if len(and_) > 1:
|
||||||
|
return sqlalchemy.and_(*and_)
|
||||||
|
elif len(and_) == 1:
|
||||||
|
return and_[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid filter condition. Expected a dictionary "
|
||||||
|
"but got an empty dictionary"
|
||||||
|
)
|
||||||
|
elif key.lower() == "$or":
|
||||||
|
or_ = [self._create_filter_clause(el) for el in value]
|
||||||
|
if len(or_) > 1:
|
||||||
|
return sqlalchemy.or_(*or_)
|
||||||
|
elif len(or_) == 1:
|
||||||
|
return or_[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid filter condition. Expected a dictionary "
|
||||||
|
"but got an empty dictionary"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid filter condition. Expected $and or $or "
|
||||||
|
f"but got: {key}"
|
||||||
|
)
|
||||||
|
elif len(filters) > 1:
|
||||||
|
# Then all keys have to be fields (they cannot be operators)
|
||||||
|
for key in filters.keys():
|
||||||
|
if key.startswith("$"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid filter condition. Expected a field but got: {key}"
|
||||||
|
)
|
||||||
|
# These should all be fields and combined using an $and operator
|
||||||
|
and_ = [self._handle_field_filter(k, v) for k, v in filters.items()]
|
||||||
|
if len(and_) > 1:
|
||||||
|
return sqlalchemy.and_(*and_)
|
||||||
|
elif len(and_) == 1:
|
||||||
|
return and_[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid filter condition. Expected a dictionary "
|
||||||
|
"but got an empty dictionary"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Got an empty dictionary for filters.")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid type: Expected a dictionary but got type: {type(filters)}"
|
||||||
|
)
|
||||||
|
|
||||||
def __query_collection(
|
def __query_collection(
|
||||||
self,
|
self,
|
||||||
embedding: List[float],
|
embedding: List[float],
|
||||||
@ -580,24 +900,16 @@ class PGVector(VectorStore):
|
|||||||
if not collection:
|
if not collection:
|
||||||
raise ValueError("Collection not found")
|
raise ValueError("Collection not found")
|
||||||
|
|
||||||
filter_by = self.EmbeddingStore.collection_id == collection.uuid
|
filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
|
||||||
|
if filter:
|
||||||
if filter is not None:
|
if self.use_jsonb:
|
||||||
filter_clauses = []
|
filter_clauses = self._create_filter_clause(filter)
|
||||||
|
if filter_clauses is not None:
|
||||||
for key, value in filter.items():
|
filter_by.append(filter_clauses)
|
||||||
if isinstance(value, dict):
|
else:
|
||||||
filter_by_metadata = self._create_filter_clause(key, value)
|
# Old way of doing things
|
||||||
|
filter_clauses = self._create_filter_clause_json_deprecated(filter)
|
||||||
if filter_by_metadata is not None:
|
filter_by.extend(filter_clauses)
|
||||||
filter_clauses.append(filter_by_metadata)
|
|
||||||
else:
|
|
||||||
filter_by_metadata = self.EmbeddingStore.cmetadata[
|
|
||||||
key
|
|
||||||
].astext == str(value)
|
|
||||||
filter_clauses.append(filter_by_metadata)
|
|
||||||
|
|
||||||
filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
|
|
||||||
|
|
||||||
_type = self.EmbeddingStore
|
_type = self.EmbeddingStore
|
||||||
|
|
||||||
@ -606,7 +918,7 @@ class PGVector(VectorStore):
|
|||||||
self.EmbeddingStore,
|
self.EmbeddingStore,
|
||||||
self.distance_strategy(embedding).label("distance"), # type: ignore
|
self.distance_strategy(embedding).label("distance"), # type: ignore
|
||||||
)
|
)
|
||||||
.filter(filter_by)
|
.filter(*filter_by)
|
||||||
.order_by(sqlalchemy.asc("distance"))
|
.order_by(sqlalchemy.asc("distance"))
|
||||||
.join(
|
.join(
|
||||||
self.CollectionStore,
|
self.CollectionStore,
|
||||||
@ -615,6 +927,7 @@ class PGVector(VectorStore):
|
|||||||
.limit(k)
|
.limit(k)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def similarity_search_by_vector(
|
def similarity_search_by_vector(
|
||||||
@ -649,6 +962,8 @@ class PGVector(VectorStore):
|
|||||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
|
*,
|
||||||
|
use_jsonb: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> PGVector:
|
) -> PGVector:
|
||||||
"""
|
"""
|
||||||
@ -668,6 +983,7 @@ class PGVector(VectorStore):
|
|||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
distance_strategy=distance_strategy,
|
distance_strategy=distance_strategy,
|
||||||
pre_delete_collection=pre_delete_collection,
|
pre_delete_collection=pre_delete_collection,
|
||||||
|
use_jsonb=use_jsonb,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -769,6 +1085,8 @@ class PGVector(VectorStore):
|
|||||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
|
*,
|
||||||
|
use_jsonb: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> PGVector:
|
) -> PGVector:
|
||||||
"""
|
"""
|
||||||
@ -792,6 +1110,7 @@ class PGVector(VectorStore):
|
|||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
ids=ids,
|
ids=ids,
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
|
use_jsonb=use_jsonb,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -0,0 +1,222 @@
|
|||||||
|
"""Module contains test cases for testing filtering of documents in vector stores.
|
||||||
|
"""
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
metadatas = [
|
||||||
|
{
|
||||||
|
"name": "adam",
|
||||||
|
"date": "2021-01-01",
|
||||||
|
"count": 1,
|
||||||
|
"is_active": True,
|
||||||
|
"tags": ["a", "b"],
|
||||||
|
"location": [1.0, 2.0],
|
||||||
|
"info": {"address": "123 main st", "phone": "123-456-7890"},
|
||||||
|
"id": 1,
|
||||||
|
"height": 10.0, # Float column
|
||||||
|
"happiness": 0.9, # Float column
|
||||||
|
"sadness": 0.1, # Float column
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "bob",
|
||||||
|
"date": "2021-01-02",
|
||||||
|
"count": 2,
|
||||||
|
"is_active": False,
|
||||||
|
"tags": ["b", "c"],
|
||||||
|
"location": [2.0, 3.0],
|
||||||
|
"info": {"address": "456 main st", "phone": "123-456-7890"},
|
||||||
|
"id": 2,
|
||||||
|
"height": 5.7, # Float column
|
||||||
|
"happiness": 0.8, # Float column
|
||||||
|
"sadness": 0.1, # Float column
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "jane",
|
||||||
|
"date": "2021-01-01",
|
||||||
|
"count": 3,
|
||||||
|
"is_active": True,
|
||||||
|
"tags": ["b", "d"],
|
||||||
|
"location": [3.0, 4.0],
|
||||||
|
"info": {"address": "789 main st", "phone": "123-456-7890"},
|
||||||
|
"id": 3,
|
||||||
|
"height": 2.4, # Float column
|
||||||
|
"happiness": None,
|
||||||
|
# Sadness missing intentionally
|
||||||
|
},
|
||||||
|
]
|
||||||
|
texts = ["id {id}".format(id=metadata["id"]) for metadata in metadatas]
|
||||||
|
|
||||||
|
DOCUMENTS = [
|
||||||
|
Document(page_content=text, metadata=metadata)
|
||||||
|
for text, metadata in zip(texts, metadatas)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
TYPE_1_FILTERING_TEST_CASES = [
|
||||||
|
# These tests only involve equality checks
|
||||||
|
(
|
||||||
|
{"id": 1},
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
# String field
|
||||||
|
(
|
||||||
|
# check name
|
||||||
|
{"name": "adam"},
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
# Boolean fields
|
||||||
|
(
|
||||||
|
{"is_active": True},
|
||||||
|
[1, 3],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"is_active": False},
|
||||||
|
[2],
|
||||||
|
),
|
||||||
|
# And semantics for top level filtering
|
||||||
|
(
|
||||||
|
{"id": 1, "is_active": True},
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": 1, "is_active": False},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
TYPE_2_FILTERING_TEST_CASES = [
|
||||||
|
# These involve equality checks and other operators
|
||||||
|
# like $ne, $gt, $gte, $lt, $lte, $not
|
||||||
|
(
|
||||||
|
{"id": 1},
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$ne": 1}},
|
||||||
|
[2, 3],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$gt": 1}},
|
||||||
|
[2, 3],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$gte": 1}},
|
||||||
|
[1, 2, 3],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$lt": 1}},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$lte": 1}},
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
# Repeat all the same tests with name (string column)
|
||||||
|
(
|
||||||
|
{"name": "adam"},
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": "bob"},
|
||||||
|
[2],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": {"$eq": "adam"}},
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": {"$ne": "adam"}},
|
||||||
|
[2, 3],
|
||||||
|
),
|
||||||
|
# And also gt, gte, lt, lte relying on lexicographical ordering
|
||||||
|
(
|
||||||
|
{"name": {"$gt": "jane"}},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": {"$gte": "jane"}},
|
||||||
|
[3],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": {"$lt": "jane"}},
|
||||||
|
[1, 2],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": {"$lte": "jane"}},
|
||||||
|
[1, 2, 3],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"is_active": {"$eq": True}},
|
||||||
|
[1, 3],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"is_active": {"$ne": True}},
|
||||||
|
[2],
|
||||||
|
),
|
||||||
|
# Test float column.
|
||||||
|
(
|
||||||
|
{"height": {"$gt": 5.0}},
|
||||||
|
[1, 2],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"height": {"$gte": 5.0}},
|
||||||
|
[1, 2],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"height": {"$lt": 5.0}},
|
||||||
|
[3],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"height": {"$lte": 5.8}},
|
||||||
|
[2, 3],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
TYPE_3_FILTERING_TEST_CASES = [
|
||||||
|
# These involve usage of AND and OR operators
|
||||||
|
(
|
||||||
|
{"$or": [{"id": 1}, {"id": 2}]},
|
||||||
|
[1, 2],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"$or": [{"id": 1}, {"name": "bob"}]},
|
||||||
|
[1, 2],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"$and": [{"id": 1}, {"id": 2}]},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"$or": [{"id": 1}, {"id": 2}, {"id": 3}]},
|
||||||
|
[1, 2, 3],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
TYPE_4_FILTERING_TEST_CASES = [
|
||||||
|
# These involve special operators like $in, $nin, $between
|
||||||
|
# Test between
|
||||||
|
(
|
||||||
|
{"id": {"$between": (1, 2)}},
|
||||||
|
[1, 2],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$between": (1, 1)}},
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": {"$in": ["adam", "bob"]}},
|
||||||
|
[1, 2],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
TYPE_5_FILTERING_TEST_CASES = [
|
||||||
|
# These involve special operators like $like, $ilike that
|
||||||
|
# may be specified to certain databases.
|
||||||
|
(
|
||||||
|
{"name": {"$like": "a%"}},
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": {"$like": "%a%"}}, # adam and jane
|
||||||
|
[1, 3],
|
||||||
|
),
|
||||||
|
]
|
@ -1,13 +1,26 @@
|
|||||||
"""Test PGVector functionality."""
|
"""Test PGVector functionality."""
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import Any, Dict, Generator, List, Type, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from langchain_community.vectorstores.pgvector import PGVector
|
from langchain_community.vectorstores.pgvector import (
|
||||||
|
SUPPORTED_OPERATORS,
|
||||||
|
PGVector,
|
||||||
|
)
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
|
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
|
||||||
|
DOCUMENTS,
|
||||||
|
TYPE_1_FILTERING_TEST_CASES,
|
||||||
|
TYPE_2_FILTERING_TEST_CASES,
|
||||||
|
TYPE_3_FILTERING_TEST_CASES,
|
||||||
|
TYPE_4_FILTERING_TEST_CASES,
|
||||||
|
TYPE_5_FILTERING_TEST_CASES,
|
||||||
|
)
|
||||||
|
|
||||||
# The connection string matches the default settings in the docker-compose file
|
# The connection string matches the default settings in the docker-compose file
|
||||||
# located in the root of the repository: [root]/docker/docker-compose.yml
|
# located in the root of the repository: [root]/docker/docker-compose.yml
|
||||||
@ -42,7 +55,7 @@ class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
|
|||||||
return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
|
return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
|
||||||
|
|
||||||
|
|
||||||
def test_pgvector() -> None:
|
def test_pgvector(pgvector: PGVector) -> None:
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
texts = ["foo", "bar", "baz"]
|
texts = ["foo", "bar", "baz"]
|
||||||
docsearch = PGVector.from_texts(
|
docsearch = PGVector.from_texts(
|
||||||
@ -375,3 +388,255 @@ def test_pgvector_with_custom_engine_args() -> None:
|
|||||||
)
|
)
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
assert output == [Document(page_content="foo")]
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
# We should reuse this test-case across other integrations
|
||||||
|
# Add database fixture using pytest
|
||||||
|
@pytest.fixture
|
||||||
|
def pgvector() -> Generator[PGVector, None, None]:
|
||||||
|
"""Create a PGVector instance."""
|
||||||
|
store = PGVector.from_documents(
|
||||||
|
documents=DOCUMENTS,
|
||||||
|
collection_name="test_collection",
|
||||||
|
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
relevance_score_fn=lambda d: d * 0,
|
||||||
|
use_jsonb=True,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
yield store
|
||||||
|
# Do clean up
|
||||||
|
finally:
|
||||||
|
store.drop_tables()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
|
||||||
|
def test_pgvector_with_with_metadata_filters_1(
|
||||||
|
pgvector: PGVector,
|
||||||
|
test_filter: Dict[str, Any],
|
||||||
|
expected_ids: List[int],
|
||||||
|
) -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
|
||||||
|
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
|
||||||
|
def test_pgvector_with_with_metadata_filters_2(
|
||||||
|
pgvector: PGVector,
|
||||||
|
test_filter: Dict[str, Any],
|
||||||
|
expected_ids: List[int],
|
||||||
|
) -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
|
||||||
|
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
|
||||||
|
def test_pgvector_with_with_metadata_filters_3(
|
||||||
|
pgvector: PGVector,
|
||||||
|
test_filter: Dict[str, Any],
|
||||||
|
expected_ids: List[int],
|
||||||
|
) -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
|
||||||
|
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
|
||||||
|
def test_pgvector_with_with_metadata_filters_4(
|
||||||
|
pgvector: PGVector,
|
||||||
|
test_filter: Dict[str, Any],
|
||||||
|
expected_ids: List[int],
|
||||||
|
) -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
|
||||||
|
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
|
||||||
|
def test_pgvector_with_with_metadata_filters_5(
|
||||||
|
pgvector: PGVector,
|
||||||
|
test_filter: Dict[str, Any],
|
||||||
|
expected_ids: List[int],
|
||||||
|
) -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
|
||||||
|
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invalid_filter",
|
||||||
|
[
|
||||||
|
["hello"],
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"$name": "foo",
|
||||||
|
},
|
||||||
|
{"$or": {}},
|
||||||
|
{"$and": {}},
|
||||||
|
{"$between": {}},
|
||||||
|
{"$eq": {}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None:
|
||||||
|
"""Verify that invalid filters raise an error."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
pgvector._create_filter_clause(invalid_filter)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"filter,compiled",
|
||||||
|
[
|
||||||
|
({"id 'evil code'": 2}, ValueError),
|
||||||
|
(
|
||||||
|
{"id": "'evil code' == 2"},
|
||||||
|
(
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, "
|
||||||
|
"'$.id == $value', "
|
||||||
|
"'{\"value\": \"''evil code'' == 2\"}')"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": 'a"b'},
|
||||||
|
(
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, "
|
||||||
|
"'$.name == $value', "
|
||||||
|
'\'{"value": "a\\\\"b"}\')'
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_evil_code(
|
||||||
|
pgvector: PGVector, filter: Any, compiled: Union[Type[Exception], str]
|
||||||
|
) -> None:
|
||||||
|
"""Test evil code."""
|
||||||
|
if isinstance(compiled, str):
|
||||||
|
clause = pgvector._create_filter_clause(filter)
|
||||||
|
compiled_stmt = str(
|
||||||
|
clause.compile(
|
||||||
|
dialect=postgresql.dialect(),
|
||||||
|
compile_kwargs={
|
||||||
|
# This substitutes the parameters with their actual values
|
||||||
|
"literal_binds": True
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert compiled_stmt == compiled
|
||||||
|
else:
|
||||||
|
with pytest.raises(compiled):
|
||||||
|
pgvector._create_filter_clause(filter)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"filter,compiled",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{"id": 2},
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
|
||||||
|
"'{\"value\": 2}')",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$eq": 2}},
|
||||||
|
(
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
|
||||||
|
"'{\"value\": 2}')"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": "foo"},
|
||||||
|
(
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, "
|
||||||
|
"'$.name == $value', "
|
||||||
|
'\'{"value": "foo"}\')'
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$ne": 2}},
|
||||||
|
(
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id != $value', "
|
||||||
|
"'{\"value\": 2}')"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$gt": 2}},
|
||||||
|
(
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id > $value', "
|
||||||
|
"'{\"value\": 2}')"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$gte": 2}},
|
||||||
|
(
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id >= $value', "
|
||||||
|
"'{\"value\": 2}')"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$lt": 2}},
|
||||||
|
(
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id < $value', "
|
||||||
|
"'{\"value\": 2}')"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"id": {"$lte": 2}},
|
||||||
|
(
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id <= $value', "
|
||||||
|
"'{\"value\": 2}')"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": {"$ilike": "foo"}},
|
||||||
|
"langchain_pg_embedding.cmetadata ->> 'name' ILIKE 'foo'",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"name": {"$like": "foo"}},
|
||||||
|
"langchain_pg_embedding.cmetadata ->> 'name' LIKE 'foo'",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"$or": [{"id": 1}, {"id": 2}]},
|
||||||
|
# Please note that this might not be super optimized
|
||||||
|
# Another way to phrase the query is as
|
||||||
|
# langchain_pg_embedding.cmetadata @@ '($.id == 1 || $.id == 2)'
|
||||||
|
"jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
|
||||||
|
"'{\"value\": 1}') OR jsonb_path_match(langchain_pg_embedding.cmetadata, "
|
||||||
|
"'$.id == $value', '{\"value\": 2}')",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_pgvector_query_compilation(
|
||||||
|
pgvector: PGVector, filter: Any, compiled: str
|
||||||
|
) -> None:
|
||||||
|
"""Test translation from IR to SQL"""
|
||||||
|
clause = pgvector._create_filter_clause(filter)
|
||||||
|
compiled_stmt = str(
|
||||||
|
clause.compile(
|
||||||
|
dialect=postgresql.dialect(),
|
||||||
|
compile_kwargs={
|
||||||
|
# This substitutes the parameters with their actual values
|
||||||
|
"literal_binds": True
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert compiled_stmt == compiled
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_operators() -> None:
|
||||||
|
"""Verify that all operators have been categorized."""
|
||||||
|
assert sorted(SUPPORTED_OPERATORS) == [
|
||||||
|
"$and",
|
||||||
|
"$between",
|
||||||
|
"$eq",
|
||||||
|
"$gt",
|
||||||
|
"$gte",
|
||||||
|
"$ilike",
|
||||||
|
"$in",
|
||||||
|
"$like",
|
||||||
|
"$lt",
|
||||||
|
"$lte",
|
||||||
|
"$ne",
|
||||||
|
"$nin",
|
||||||
|
"$or",
|
||||||
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user