diff --git a/docs/docs/integrations/providers/pebblo/pebblo_retrieval_qa.ipynb b/docs/docs/integrations/providers/pebblo/pebblo_retrieval_qa.ipynb index 14cd3c1603f..2e64dc8956f 100644 --- a/docs/docs/integrations/providers/pebblo/pebblo_retrieval_qa.ipynb +++ b/docs/docs/integrations/providers/pebblo/pebblo_retrieval_qa.ipynb @@ -60,6 +60,7 @@ "**PebbloRetrievalQA chain supports the following vector databases:**\n", "- Qdrant\n", "- Pinecone\n", + "- Postgres(utilizing the pgvector extension)\n", "\n", "\n", "**Load vector database with authorization and semantic information in metadata:**\n", diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/base.py b/libs/community/langchain_community/chains/pebblo_retrieval/base.py index 4fb76923148..97c939b4fce 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/base.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/base.py @@ -272,14 +272,11 @@ class PebbloRetrievalQA(Chain): """ Validate that the vectorstore of the retriever is supported vectorstores. """ - if not any( - isinstance(retriever.vectorstore, supported_class) - for supported_class in SUPPORTED_VECTORSTORES - ): + if retriever.vectorstore.__class__.__name__ not in SUPPORTED_VECTORSTORES: raise ValueError( f"Vectorstore must be an instance of one of the supported " f"vectorstores: {SUPPORTED_VECTORSTORES}. " - f"Got {type(retriever.vectorstore).__name__} instead." + f"Got '{retriever.vectorstore.__class__.__name__}' instead." ) return retriever diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py b/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py index 1761acb09fb..570cbdfa783 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py @@ -13,7 +13,7 @@ The methods in this module are designed to work with different types of vector s """ import logging -from typing import List, Optional, Union +from typing import Any, List, Optional, Union from langchain_core.vectorstores import VectorStoreRetriever @@ -21,11 +21,33 @@ from langchain_community.chains.pebblo_retrieval.models import ( AuthContext, SemanticContext, ) -from langchain_community.vectorstores import Pinecone, Qdrant logger = logging.getLogger(__name__) -SUPPORTED_VECTORSTORES = [Pinecone, Qdrant] +PINECONE = "Pinecone" +QDRANT = "Qdrant" +PGVECTOR = "PGVector" + +SUPPORTED_VECTORSTORES = {PINECONE, QDRANT, PGVECTOR} + + +def clear_enforcement_filters(retriever: VectorStoreRetriever) -> None: + """ + Clear the identity and semantic enforcement filters in the retriever search_kwargs. + """ + if retriever.vectorstore.__class__.__name__ == PGVECTOR: + search_kwargs = retriever.search_kwargs + if "filter" in search_kwargs: + filters = search_kwargs["filter"] + _pgvector_clear_pebblo_filters( + search_kwargs, filters, "authorized_identities" + ) + _pgvector_clear_pebblo_filters( + search_kwargs, filters, "pebblo_semantic_topics" + ) + _pgvector_clear_pebblo_filters( + search_kwargs, filters, "pebblo_semantic_entities" + ) def set_enforcement_filters( @@ -36,6 +58,8 @@ def set_enforcement_filters( """ Set identity and semantic enforcement filters in the retriever. """ + # Clear existing enforcement filters + clear_enforcement_filters(retriever) if auth_context is not None: _set_identity_enforcement_filter(retriever, auth_context) if semantic_context is not None: @@ -233,6 +257,244 @@ def _apply_pinecone_authorization_filter( } +def _apply_pgvector_filter( + search_kwargs: dict, filters: Optional[Any], pebblo_filter: dict +) -> None: + """ + Apply pebblo filters in the search_kwargs filters. + """ + if isinstance(filters, dict): + if len(filters) == 1: + # The only operators allowed at the top level are $and, $or, and $not + # First check if an operator or a field + key, value = list(filters.items())[0] + if key.startswith("$"): + # Then it's an operator + if key.lower() not in ["$and", "$or", "$not"]: + raise ValueError( + f"Invalid filter condition. Expected $and, $or or $not " + f"but got: {key}" + ) + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + + # Here we handle the $and, $or, and $not operators(Semantic filters) + if key.lower() == "$and": + # Add pebblo_filter to the $and list as it is + value.append(pebblo_filter) + elif key.lower() == "$not": + # Check if pebblo_filter is an operator or a field + _key, _value = list(pebblo_filter.items())[0] + if _key.startswith("$"): + # Then it's a operator + if _key.lower() == "$not": + # It's Semantic filter, add it's value to filters + value.append(_value) + logger.warning( + "Adding $not operator to the existing $not operator" + ) + return + else: + # Only $not operator is supported in pebblo_filter + raise ValueError( + f"Invalid filter key. Expected '$not' but got: {_key}" + ) + else: + # Then it's a field(Auth filter), move filters into $and + search_kwargs["filter"] = {"$and": [filters, pebblo_filter]} + return + elif key.lower() == "$or": + search_kwargs["filter"] = {"$and": [filters, pebblo_filter]} + else: + # Then it's a field and we can check pebblo_filter now + # Check if pebblo_filter is an operator or a field + _key, _ = list(pebblo_filter.items())[0] + if _key.startswith("$"): + # Then it's a operator + if _key.lower() == "$not": + # It's a $not operator(Semantic filter), move filters into $and + search_kwargs["filter"] = {"$and": [filters, pebblo_filter]} + return + else: + # Only $not operator is allowed in pebblo_filter + raise ValueError( + f"Invalid filter key. Expected '$not' but got: {_key}" + ) + else: + # Then it's a field(This handles Auth filter) + filters.update(pebblo_filter) + return + elif len(filters) > 1: + # Then all keys have to be fields (they cannot be operators) + for key in filters.keys(): + if key.startswith("$"): + raise ValueError( + f"Invalid filter condition. Expected a field but got: {key}" + ) + # filters should all be fields and we can check pebblo_filter now + # Check if pebblo_filter is an operator or a field + _key, _ = list(pebblo_filter.items())[0] + if _key.startswith("$"): + # Then it's a operator + if _key.lower() == "$not": + # It's a $not operator(Semantic filter), move filters into '$and' + search_kwargs["filter"] = {"$and": [filters, pebblo_filter]} + return + else: + # Only $not operator is supported in pebblo_filter + raise ValueError( + f"Invalid filter key. Expected '$not' but got: {_key}" + ) + else: + # Then it's a field(This handles Auth filter) + filters.update(pebblo_filter) + return + else: + # Got an empty dictionary for filters, set pebblo_filter in filter + search_kwargs.setdefault("filter", {}).update(pebblo_filter) + elif filters is None: + # If filters is None, set pebblo_filter as a new filter + search_kwargs.setdefault("filter", {}).update(pebblo_filter) + else: + raise ValueError( + f"Invalid filter. Expected a dictionary/None but got type: {type(filters)}" + ) + + +def _pgvector_clear_pebblo_filters( + search_kwargs: dict, filters: dict, pebblo_filter_key: str +) -> None: + """ + Remove pebblo filters from the search_kwargs filters. + """ + if isinstance(filters, dict): + if len(filters) == 1: + # The only operators allowed at the top level are $and, $or, and $not + # First check if an operator or a field + key, value = list(filters.items())[0] + if key.startswith("$"): + # Then it's an operator + # Validate the operator's key and value type + if key.lower() not in ["$and", "$or", "$not"]: + raise ValueError( + f"Invalid filter condition. Expected $and, $or or $not " + f"but got: {key}" + ) + elif not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + + # Here we handle the $and, $or, and $not operators + if key.lower() == "$and": + # Remove the pebblo filter from the $and list + for i, _filter in enumerate(value): + if pebblo_filter_key in _filter: + # This handles Auth filter + value.pop(i) + break + # Check for $not operator with Semantic filter + if "$not" in _filter: + sem_filter_found = False + # This handles Semantic filter + for j, nested_filter in enumerate(_filter["$not"]): + if pebblo_filter_key in nested_filter: + if len(_filter["$not"]) == 1: + # If only one filter is left, + # then remove the $not operator + value.pop(i) + else: + value[i]["$not"].pop(j) + sem_filter_found = True + break + if sem_filter_found: + break + if len(value) == 1: + # If only one filter is left, then remove the $and operator + search_kwargs["filter"] = value[0] + elif key.lower() == "$not": + # Remove the pebblo filter from the $not list + for i, _filter in enumerate(value): + if pebblo_filter_key in _filter: + # This removes Semantic filter + value.pop(i) + break + if len(value) == 0: + # If no filter is left, then unset the filter + search_kwargs["filter"] = {} + elif key.lower() == "$or": + # If $or, pebblo filter will not be present + return + else: + # Then it's a field, check if it's a pebblo filter + if key == pebblo_filter_key: + filters.pop(key) + return + elif len(filters) > 1: + # Then all keys have to be fields (they cannot be operators) + if pebblo_filter_key in filters: + # This handles Auth filter + filters.pop(pebblo_filter_key) + return + else: + # Got an empty dictionary for filters, ignore the filter + return + elif filters is None: + # If filters is None, ignore the filter + return + else: + raise ValueError( + f"Invalid filter. Expected a dictionary/None but got type: {type(filters)}" + ) + + +def _apply_pgvector_semantic_filter( + search_kwargs: dict, semantic_context: Optional[SemanticContext] +) -> None: + """ + Set semantic enforcement filter in search_kwargs for PGVector vectorstore. + """ + # Check if semantic_context is provided + if semantic_context is not None: + _semantic_filters = [] + filters = search_kwargs.get("filter") + if semantic_context.pebblo_semantic_topics is not None: + # Add pebblo_semantic_topics filter to search_kwargs + topic_filter: dict = { + "pebblo_semantic_topics": { + "$eq": semantic_context.pebblo_semantic_topics.deny + } + } + _semantic_filters.append(topic_filter) + + if semantic_context.pebblo_semantic_entities is not None: + # Add pebblo_semantic_entities filter to search_kwargs + entity_filter: dict = { + "pebblo_semantic_entities": { + "$eq": semantic_context.pebblo_semantic_entities.deny + } + } + _semantic_filters.append(entity_filter) + + if len(_semantic_filters) > 0: + semantic_filter: dict = {"$not": _semantic_filters} + _apply_pgvector_filter(search_kwargs, filters, semantic_filter) + + +def _apply_pgvector_authorization_filter( + search_kwargs: dict, auth_context: Optional[AuthContext] +) -> None: + """ + Set identity enforcement filter in search_kwargs for PGVector vectorstore. + """ + if auth_context is not None: + auth_filter: dict = {"authorized_identities": {"$eq": auth_context.user_auth}} + filters = search_kwargs.get("filter") + _apply_pgvector_filter(search_kwargs, filters, auth_filter) + + def _set_identity_enforcement_filter( retriever: VectorStoreRetriever, auth_context: Optional[AuthContext] ) -> None: @@ -243,10 +505,12 @@ def _set_identity_enforcement_filter( of the retriever based on the type of the vectorstore. """ search_kwargs = retriever.search_kwargs - if isinstance(retriever.vectorstore, Pinecone): + if retriever.vectorstore.__class__.__name__ == PINECONE: _apply_pinecone_authorization_filter(search_kwargs, auth_context) - elif isinstance(retriever.vectorstore, Qdrant): + elif retriever.vectorstore.__class__.__name__ == QDRANT: _apply_qdrant_authorization_filter(search_kwargs, auth_context) + elif retriever.vectorstore.__class__.__name__ == PGVECTOR: + _apply_pgvector_authorization_filter(search_kwargs, auth_context) def _set_semantic_enforcement_filter( @@ -259,7 +523,9 @@ def _set_semantic_enforcement_filter( of the retriever based on the type of the vectorstore. """ search_kwargs = retriever.search_kwargs - if isinstance(retriever.vectorstore, Pinecone): + if retriever.vectorstore.__class__.__name__ == PINECONE: _apply_pinecone_semantic_filter(search_kwargs, semantic_context) - elif isinstance(retriever.vectorstore, Qdrant): + elif retriever.vectorstore.__class__.__name__ == QDRANT: _apply_qdrant_semantic_filter(search_kwargs, semantic_context) + elif retriever.vectorstore.__class__.__name__ == PGVECTOR: + _apply_pgvector_semantic_filter(search_kwargs, semantic_context)