community[minor]: Support PGVector in PebbloRetrievalQA (#23874)

- **Description:** Support PGVector in PebbloRetrievalQA
  - Identity and Semantic Enforcement support for PGVector
  - Refactor Vectorstore validation and name check
  - Clear the overridden identity and semantic enforcement filters
- **Issue:** NA
- **Dependencies:** NA
- **Tests:** NA (already added)
- **Docs:** Updated
- **Twitter handle:** [@Raj__725](https://twitter.com/Raj__725)
This commit is contained in:
Rajendra Kadam 2024-07-06 01:32:25 +05:30 committed by GitHub
parent e0186df56b
commit 8b84457b17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 276 additions and 12 deletions

View File

@ -60,6 +60,7 @@
"**PebbloRetrievalQA chain supports the following vector databases:**\n",
"- Qdrant\n",
"- Pinecone\n",
"- Postgres(utilizing the pgvector extension)\n",
"\n",
"\n",
"**Load vector database with authorization and semantic information in metadata:**\n",

View File

@ -272,14 +272,11 @@ class PebbloRetrievalQA(Chain):
"""
Validate that the vectorstore of the retriever is supported vectorstores.
"""
if not any(
isinstance(retriever.vectorstore, supported_class)
for supported_class in SUPPORTED_VECTORSTORES
):
if retriever.vectorstore.__class__.__name__ not in SUPPORTED_VECTORSTORES:
raise ValueError(
f"Vectorstore must be an instance of one of the supported "
f"vectorstores: {SUPPORTED_VECTORSTORES}. "
f"Got {type(retriever.vectorstore).__name__} instead."
f"Got '{retriever.vectorstore.__class__.__name__}' instead."
)
return retriever

View File

@ -13,7 +13,7 @@ The methods in this module are designed to work with different types of vector s
"""
import logging
from typing import List, Optional, Union
from typing import Any, List, Optional, Union
from langchain_core.vectorstores import VectorStoreRetriever
@ -21,11 +21,33 @@ from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
SemanticContext,
)
from langchain_community.vectorstores import Pinecone, Qdrant
logger = logging.getLogger(__name__)
SUPPORTED_VECTORSTORES = [Pinecone, Qdrant]
# Class names of the supported vector stores. Elsewhere in this change they are
# compared against `retriever.vectorstore.__class__.__name__`.
PINECONE = "Pinecone"
QDRANT = "Qdrant"
PGVECTOR = "PGVector"
# Set of vector store class names accepted for a retriever.
SUPPORTED_VECTORSTORES = {PINECONE, QDRANT, PGVECTOR}
def clear_enforcement_filters(retriever: VectorStoreRetriever) -> None:
    """
    Clear the identity and semantic enforcement filters in the retriever search_kwargs.

    Only retrievers whose vectorstore class is named ``PGVector`` are touched;
    for every other vectorstore this is a no-op.

    Args:
        retriever: Retriever whose ``search_kwargs["filter"]`` may contain
            previously applied ``authorized_identities``,
            ``pebblo_semantic_topics`` or ``pebblo_semantic_entities`` filters.
    """
    if retriever.vectorstore.__class__.__name__ != PGVECTOR:
        return

    search_kwargs = retriever.search_kwargs
    if "filter" not in search_kwargs:
        return

    for pebblo_filter_key in (
        "authorized_identities",
        "pebblo_semantic_topics",
        "pebblo_semantic_entities",
    ):
        # Re-read the filter on every pass: clearing one key may replace
        # search_kwargs["filter"] with a different object (e.g. when an "$and"
        # collapses to its only remaining member), so a reference captured up
        # front would be stale and later removals would miss the live filter.
        _pgvector_clear_pebblo_filters(
            search_kwargs, search_kwargs["filter"], pebblo_filter_key
        )
def set_enforcement_filters(
@ -36,6 +58,8 @@ def set_enforcement_filters(
"""
Set identity and semantic enforcement filters in the retriever.
"""
# Clear existing enforcement filters
clear_enforcement_filters(retriever)
if auth_context is not None:
_set_identity_enforcement_filter(retriever, auth_context)
if semantic_context is not None:
@ -233,6 +257,244 @@ def _apply_pinecone_authorization_filter(
}
def _apply_pgvector_filter(
    search_kwargs: dict, filters: Optional[Any], pebblo_filter: dict
) -> None:
    """
    Apply pebblo filters in the search_kwargs filters.

    The existing filter and ``pebblo_filter`` are combined so that both apply.
    Depending on the shape of the existing filter, it is either updated in
    place or ``search_kwargs["filter"]`` is replaced with an ``$and`` wrapper.

    Args:
        search_kwargs: Retriever search kwargs whose ``"filter"`` entry is
            updated or replaced.
        filters: The current filter (the callers pass
            ``search_kwargs.get("filter")``). Must be a dictionary or None.
        pebblo_filter: Either a field filter (identity enforcement) or a
            ``{"$not": [...]}`` filter (semantic enforcement).

    Raises:
        ValueError: If ``filters`` is neither a dictionary nor None, if it uses
            an unsupported top-level operator, if an operator's value is not a
            list, if operators are mixed with fields at the top level, or if
            ``pebblo_filter`` uses an operator other than ``$not`` where it is
            inspected.
    """
    if filters is not None and not isinstance(filters, dict):
        raise ValueError(
            f"Invalid filter. Expected a dictionary/None but got type: {type(filters)}"
        )

    if not filters:
        # No existing conditions (None or an empty dict): the pebblo filter
        # becomes the whole filter. The "filter" entry may be missing or
        # explicitly None, so don't assume a dictionary is already stored.
        current = search_kwargs.get("filter")
        if current is None:
            current = search_kwargs["filter"] = {}
        current.update(pebblo_filter)
        return

    if len(filters) == 1:
        # The only operators allowed at the top level are $and, $or, and $not
        key, value = list(filters.items())[0]
        if key.startswith("$"):
            operator = key.lower()
            if operator not in ("$and", "$or", "$not"):
                raise ValueError(
                    f"Invalid filter condition. Expected $and, $or or $not "
                    f"but got: {key}"
                )
            if not isinstance(value, list):
                raise ValueError(
                    f"Expected a list, but got {type(value)} for value: {value}"
                )

            if operator == "$and":
                # Add pebblo_filter to the $and list as it is
                value.append(pebblo_filter)
            elif operator == "$or":
                # Keep the $or intact and require the pebblo filter as well
                search_kwargs["filter"] = {"$and": [filters, pebblo_filter]}
            else:
                # Existing top-level $not: check what kind of pebblo_filter it is
                _key, _value = list(pebblo_filter.items())[0]
                if not _key.startswith("$"):
                    # It's a field (Auth filter), move filters into $and
                    search_kwargs["filter"] = {"$and": [filters, pebblo_filter]}
                elif _key.lower() == "$not":
                    # Semantic filter: merge its conditions into the existing
                    # $not list. The conditions are added one by one (extend);
                    # appending the list itself would nest a list inside the
                    # $not list instead of adding filter dictionaries.
                    if isinstance(_value, list):
                        value.extend(_value)
                    else:
                        value.append(_value)
                    logger.warning(
                        "Adding $not operator to the existing $not operator"
                    )
                else:
                    # Only $not operator is supported in pebblo_filter
                    raise ValueError(
                        f"Invalid filter key. Expected '$not' but got: {_key}"
                    )
            return
    else:
        # With several top-level keys, all of them have to be fields
        for key in filters:
            if key.startswith("$"):
                raise ValueError(
                    f"Invalid filter condition. Expected a field but got: {key}"
                )

    # From here on the existing filter consists of plain fields only.
    _key = list(pebblo_filter.items())[0][0]
    if not _key.startswith("$"):
        # Field filter (Auth filter): add it next to the existing fields
        filters.update(pebblo_filter)
    elif _key.lower() == "$not":
        # $not operator (Semantic filter): move filters into $and
        search_kwargs["filter"] = {"$and": [filters, pebblo_filter]}
    else:
        # Only $not operator is allowed in pebblo_filter
        raise ValueError(f"Invalid filter key. Expected '$not' but got: {_key}")
def _pgvector_clear_pebblo_filters(
    search_kwargs: dict, filters: Optional[dict], pebblo_filter_key: str
) -> None:
    """
    Remove pebblo filters from the search_kwargs filters.

    Looks for ``pebblo_filter_key`` in the shapes produced when pebblo filters
    are applied: a plain field, a member of a top-level ``$and`` list, or a
    condition inside a ``$not`` list (top-level or nested in ``$and``).
    Only the first match is removed.

    Args:
        search_kwargs: Retriever search kwargs; ``search_kwargs["filter"]`` is
            replaced when an operator collapses or becomes empty.
        filters: The filter to clean (dictionary or None); modified in place.
        pebblo_filter_key: Metadata key of the pebblo filter to remove, e.g.
            ``"authorized_identities"``.

    Raises:
        ValueError: If ``filters`` is neither a dictionary nor None, if its
            single top-level operator is not ``$and``/``$or``/``$not``, or if
            that operator's value is not a list.
    """
    if filters is None:
        # If filters is None, there is nothing to clear
        return
    if not isinstance(filters, dict):
        raise ValueError(
            f"Invalid filter. Expected a dictionary/None but got type: {type(filters)}"
        )
    if not filters:
        # Got an empty dictionary for filters, nothing to clear
        return

    if len(filters) > 1:
        # Several top-level keys: they are all fields, so drop ours if present
        filters.pop(pebblo_filter_key, None)
        return

    key, value = list(filters.items())[0]
    if not key.startswith("$"):
        # A single field: remove it only if it is the pebblo filter
        if key == pebblo_filter_key:
            filters.pop(key)
        return

    # The only operators allowed at the top level are $and, $or, and $not
    operator = key.lower()
    if operator not in ("$and", "$or", "$not"):
        raise ValueError(
            f"Invalid filter condition. Expected $and, $or or $not but got: {key}"
        )
    if not isinstance(value, list):
        raise ValueError(f"Expected a list, but got {type(value)} for value: {value}")

    if operator == "$or":
        # Pebblo filters are never placed inside $or
        return

    if operator == "$not":
        # Remove the pebblo (semantic) filter from the $not list
        for i, member in enumerate(value):
            if pebblo_filter_key in member:
                value.pop(i)
                break
        if len(value) == 0:
            # If no filter is left, then unset the filter
            search_kwargs["filter"] = {}
        return

    # operator == "$and"
    removed = False
    for i, member in enumerate(value):
        if pebblo_filter_key in member:
            # Auth filter. It may have been merged into a dictionary that also
            # holds the caller's own fields; in that case remove only our key
            # so that the caller's conditions are preserved.
            if isinstance(member, dict) and len(member) > 1:
                member.pop(pebblo_filter_key)
            else:
                value.pop(i)
            removed = True
            break
        if "$not" in member:
            # Semantic filter nested in a $not list
            negated = member["$not"]
            match = next(
                (j for j, nested in enumerate(negated) if pebblo_filter_key in nested),
                None,
            )
            if match is not None:
                if len(negated) == 1:
                    # It was the only condition, remove the whole $not member
                    value.pop(i)
                else:
                    negated.pop(match)
                removed = True
                break

    if len(value) == 1:
        # If only one filter is left, then remove the $and operator
        search_kwargs["filter"] = value[0]
    elif removed and not value:
        # Nothing is left after the removal: don't keep an empty $and around
        search_kwargs["filter"] = {}
def _apply_pgvector_semantic_filter(
    search_kwargs: dict, semantic_context: Optional[SemanticContext]
) -> None:
    """
    Set semantic enforcement filter in search_kwargs for PGVector vectorstore.

    For each of ``pebblo_semantic_topics`` and ``pebblo_semantic_entities``
    that is set on the context, an ``$eq`` condition on its ``deny`` value is
    built. The conditions are wrapped in one ``$not`` filter and combined with
    the existing ``search_kwargs["filter"]``. Nothing happens when the context
    is None or neither attribute is set.
    """
    if semantic_context is None:
        return

    current_filter = search_kwargs.get("filter")

    denied_conditions = []
    for metadata_key in ("pebblo_semantic_topics", "pebblo_semantic_entities"):
        restriction = getattr(semantic_context, metadata_key)
        if restriction is not None:
            # One condition per configured restriction, keyed by metadata field
            condition: dict = {metadata_key: {"$eq": restriction.deny}}
            denied_conditions.append(condition)

    if denied_conditions:
        _apply_pgvector_filter(
            search_kwargs, current_filter, {"$not": denied_conditions}
        )
def _apply_pgvector_authorization_filter(
    search_kwargs: dict, auth_context: Optional[AuthContext]
) -> None:
    """
    Set identity enforcement filter in search_kwargs for PGVector vectorstore.

    Builds an ``authorized_identities`` ``$eq`` condition from
    ``auth_context.user_auth`` and combines it with the existing
    ``search_kwargs["filter"]``. Does nothing when ``auth_context`` is None.
    """
    if auth_context is None:
        return

    identity_condition: dict = {
        "authorized_identities": {"$eq": auth_context.user_auth}
    }
    current_filter = search_kwargs.get("filter")
    _apply_pgvector_filter(search_kwargs, current_filter, identity_condition)
def _set_identity_enforcement_filter(
retriever: VectorStoreRetriever, auth_context: Optional[AuthContext]
) -> None:
@ -243,10 +505,12 @@ def _set_identity_enforcement_filter(
of the retriever based on the type of the vectorstore.
"""
search_kwargs = retriever.search_kwargs
if isinstance(retriever.vectorstore, Pinecone):
if retriever.vectorstore.__class__.__name__ == PINECONE:
_apply_pinecone_authorization_filter(search_kwargs, auth_context)
elif isinstance(retriever.vectorstore, Qdrant):
elif retriever.vectorstore.__class__.__name__ == QDRANT:
_apply_qdrant_authorization_filter(search_kwargs, auth_context)
elif retriever.vectorstore.__class__.__name__ == PGVECTOR:
_apply_pgvector_authorization_filter(search_kwargs, auth_context)
def _set_semantic_enforcement_filter(
@ -259,7 +523,9 @@ def _set_semantic_enforcement_filter(
of the retriever based on the type of the vectorstore.
"""
search_kwargs = retriever.search_kwargs
if isinstance(retriever.vectorstore, Pinecone):
if retriever.vectorstore.__class__.__name__ == PINECONE:
_apply_pinecone_semantic_filter(search_kwargs, semantic_context)
elif isinstance(retriever.vectorstore, Qdrant):
elif retriever.vectorstore.__class__.__name__ == QDRANT:
_apply_qdrant_semantic_filter(search_kwargs, semantic_context)
elif retriever.vectorstore.__class__.__name__ == PGVECTOR:
_apply_pgvector_semantic_filter(search_kwargs, semantic_context)