mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 19:57:51 +00:00
langchain[minor]: Add PebbloRetrievalQA chain with Identity & Semantic Enforcement support (#20641)
- **Description:** PebbloRetrievalQA chain introduces identity enforcement using vector-db metadata filtering - **Dependencies:** None - **Issue:** None - **Documentation:** Adding documentation for PebbloRetrievalQA chain in a separate PR(https://github.com/langchain-ai/langchain/pull/20746) - **Unit tests:** New unit-tests added --------- Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
This commit is contained in:
parent
f2f970f93d
commit
54e003268e
@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
Chains module for langchain_community
|
||||||
|
|
||||||
|
This module contains the community chains.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_community.chains.pebblo_retrieval.base import PebbloRetrievalQA
|
||||||
|
|
||||||
|
# Public names exported by this package. They are resolved on first access by
# the module-level ``__getattr__`` instead of being imported eagerly.
__all__ = ["PebbloRetrievalQA"]

# Maps each lazily exported name to the dotted path of the module defining it.
_module_lookup = {
    "PebbloRetrievalQA": "langchain_community.chains.pebblo_retrieval.base"
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
    """Lazily resolve a package attribute listed in ``_module_lookup``.

    The module registered for ``name`` is imported on demand and the attribute
    of the same name is returned from it.

    Raises:
        AttributeError: If ``name`` is not registered in ``_module_lookup``.
    """
    module_path = _module_lookup.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return getattr(importlib.import_module(module_path), name)
|
@ -0,0 +1,218 @@
|
|||||||
|
"""
|
||||||
|
Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answering
|
||||||
|
against a vector database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
)
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.pydantic_v1 import Extra, Field, validator
|
||||||
|
from langchain_core.vectorstores import VectorStoreRetriever
|
||||||
|
|
||||||
|
from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
|
||||||
|
SUPPORTED_VECTORSTORES,
|
||||||
|
set_enforcement_filters,
|
||||||
|
)
|
||||||
|
from langchain_community.chains.pebblo_retrieval.models import (
|
||||||
|
AuthContext,
|
||||||
|
SemanticContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PebbloRetrievalQA(Chain):
    """
    Question-answering chain over a vector database that applies Identity and
    Semantic Enforcement filters to the retriever before documents are fetched.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain used to combine the retrieved documents into an answer."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Whether the retrieved documents are returned next to the answer."""

    retriever: VectorStoreRetriever = Field(exclude=True)
    """Vector store retriever the documents are fetched from."""
    auth_context_key: str = "auth_context"  #: :meta private:
    """Input key under which the auth context for identity enforcement is read."""
    semantic_context_key: str = "semantic_context"  #: :meta private:
    """Input key under which the semantic context is read."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Names of the query, auth-context and semantic-context inputs.

        :meta private:
        """
        return [self.input_key, self.auth_context_key, self.semantic_context_key]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces.

        :meta private:
        """
        if self.return_source_documents:
            return [self.output_key, "source_documents"]
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "pebblo_retrieval_qa"

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "PebbloRetrievalQA":
        """Build the chain around a combine-documents chain of ``chain_type``.

        ``chain_type_kwargs`` are forwarded to ``load_qa_chain``; all remaining
        keyword arguments are passed to the class constructor.
        """
        from langchain.chains.question_answering import load_qa_chain

        qa_chain = load_qa_chain(
            llm, chain_type=chain_type, **(chain_type_kwargs or {})
        )
        return cls(combine_documents_chain=qa_chain, **kwargs)

    @validator("retriever", pre=True, always=True)
    def validate_vectorstore(
        cls, retriever: VectorStoreRetriever
    ) -> VectorStoreRetriever:
        """
        Ensure the retriever's vectorstore is one of SUPPORTED_VECTORSTORES.

        Raises a ValueError naming the offending vectorstore type otherwise.
        """
        if isinstance(retriever.vectorstore, tuple(SUPPORTED_VECTORSTORES)):
            return retriever
        raise ValueError(
            f"Vectorstore must be an instance of one of the supported "
            f"vectorstores: {SUPPORTED_VECTORSTORES}. "
            f"Got {type(retriever.vectorstore).__name__} instead."
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve documents for the query and combine them into an answer.

        The auth and semantic contexts are read from ``inputs`` and handed to
        ``_get_docs``. If chain has 'return_source_documents' as 'True', the
        retrieved documents are returned as well under 'source_documents'.

        Example:
        .. code-block:: python

            res = chain(
                {
                    'query': 'This is my query',
                    'auth_context': None,
                    'semantic_context': None,
                }
            )
            answer = res['result']
        """
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        auth_context = inputs.get(self.auth_context_key)
        semantic_context = inputs.get(self.semantic_context_key)

        # Forward the run manager only when `_get_docs` declares the parameter.
        extra_kwargs: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(self._get_docs).parameters:
            extra_kwargs["run_manager"] = manager
        docs = self._get_docs(
            question, auth_context, semantic_context, **extra_kwargs
        )

        answer = self.combine_documents_chain.run(
            input_documents=docs, question=question, callbacks=manager.get_child()
        )

        outputs: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            outputs["source_documents"] = docs
        return outputs

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously retrieve documents and combine them into an answer.

        The auth and semantic contexts are read from ``inputs`` and handed to
        ``_aget_docs``. If chain has 'return_source_documents' as 'True', the
        retrieved documents are returned as well under 'source_documents'.

        Example:
        .. code-block:: python

            res = await chain.acall(
                {
                    'query': 'This is my query',
                    'auth_context': None,
                    'semantic_context': None,
                }
            )
            answer = res['result']
        """
        manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        auth_context = inputs.get(self.auth_context_key)
        semantic_context = inputs.get(self.semantic_context_key)

        # Forward the run manager only when `_aget_docs` declares the parameter.
        extra_kwargs: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(self._aget_docs).parameters:
            extra_kwargs["run_manager"] = manager
        docs = await self._aget_docs(
            question, auth_context, semantic_context, **extra_kwargs
        )

        answer = await self.combine_documents_chain.arun(
            input_documents=docs, question=question, callbacks=manager.get_child()
        )

        outputs: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            outputs["source_documents"] = docs
        return outputs

    def _get_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Apply the enforcement filters to the retriever, then fetch docs."""
        # NOTE(review): the filters are set on the retriever object itself, so
        # they appear to outlive this call and to be visible to other calls
        # sharing this chain instance -- confirm this is intended.
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        child_callbacks = run_manager.get_child()
        return self.retriever.get_relevant_documents(
            question, callbacks=child_callbacks
        )

    async def _aget_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Apply the enforcement filters to the retriever, then fetch docs."""
        # NOTE(review): the filters are set on the retriever object itself, so
        # they appear to outlive this call and to be visible to other calls
        # sharing this chain instance -- confirm this is intended.
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        child_callbacks = run_manager.get_child()
        return await self.retriever.aget_relevant_documents(
            question, callbacks=child_callbacks
        )
|
@ -0,0 +1,265 @@
|
|||||||
|
"""
|
||||||
|
Identity & Semantic Enforcement filters for PebbloRetrievalQA chain:
|
||||||
|
|
||||||
|
This module contains methods for applying Identity and Semantic Enforcement filters
|
||||||
|
in the PebbloRetrievalQA chain.
|
||||||
|
These filters are used to control the retrieval of documents based on authorization and
|
||||||
|
semantic context.
|
||||||
|
The Identity Enforcement filter ensures that only authorized identities can access
|
||||||
|
certain documents, while the Semantic Enforcement filter controls document retrieval
|
||||||
|
based on semantic context.
|
||||||
|
|
||||||
|
The methods in this module are designed to work with different types of vector stores.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
from langchain_core.vectorstores import VectorStoreRetriever
|
||||||
|
|
||||||
|
from langchain_community.chains.pebblo_retrieval.models import (
|
||||||
|
AuthContext,
|
||||||
|
SemanticContext,
|
||||||
|
)
|
||||||
|
from langchain_community.vectorstores import Pinecone, Qdrant
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Vector store classes for which this module implements enforcement filters.
SUPPORTED_VECTORSTORES = [Pinecone, Qdrant]
|
||||||
|
|
||||||
|
|
||||||
|
def set_enforcement_filters(
    retriever: VectorStoreRetriever,
    auth_context: Optional[AuthContext],
    semantic_context: Optional[SemanticContext],
) -> None:
    """
    Install the identity and semantic enforcement filters on the retriever.

    Each filter is only set when its context is provided; a context of None
    leaves the corresponding filter untouched.
    """
    if auth_context is not None:
        _set_identity_enforcement_filter(
            retriever=retriever, auth_context=auth_context
        )
    if semantic_context is not None:
        _set_semantic_enforcement_filter(
            retriever=retriever, semantic_context=semantic_context
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_qdrant_semantic_filter(
    search_kwargs: dict, semantic_context: Optional[SemanticContext]
) -> None:
    """
    Set semantic enforcement filter in search_kwargs for Qdrant vectorstore.

    The denied topics and entities become 'must_not' conditions of the Qdrant
    filter stored under search_kwargs["filter"]. Conditions already present on
    those two metadata keys are dropped and replaced without any notice.

    Raises:
        ValueError: If qdrant-client cannot be imported.
        TypeError: If an existing filter is not a qdrant-client Filter.
    """
    try:
        from qdrant_client.http import models as rest
    except ImportError as e:
        raise ValueError(
            "Could not import `qdrant-client.http` python package. "
            "Please install it with `pip install qdrant-client`."
        ) from e

    topics_key = "metadata.pebblo_semantic_topics"
    entities_key = "metadata.pebblo_semantic_entities"

    # One deny condition per restriction present in the context.
    deny_conditions: List[
        Union[
            rest.FieldCondition,
            rest.IsEmptyCondition,
            rest.IsNullCondition,
            rest.HasIdCondition,
            rest.NestedCondition,
            rest.Filter,
        ]
    ] = []
    if semantic_context is not None:
        restrictions = (
            (topics_key, semantic_context.pebblo_semantic_topics),
            (entities_key, semantic_context.pebblo_semantic_entities),
        )
        for metadata_key, restriction in restrictions:
            if restriction is None:
                continue
            deny_conditions.append(
                rest.FieldCondition(
                    key=metadata_key,
                    match=rest.MatchAny(any=restriction.deny),
                )
            )

    # No filter yet: create one that carries only the deny conditions.
    if "filter" not in search_kwargs:
        search_kwargs["filter"] = rest.Filter(must_not=deny_conditions)
        return

    current_filter = search_kwargs["filter"]
    if not isinstance(current_filter, rest.Filter):
        raise TypeError(
            "Using dict as a `filter` is deprecated. "
            "Please use qdrant-client filters directly: "
            "https://qdrant.tech/documentation/concepts/filtering/"
        )

    if isinstance(current_filter.must_not, list):
        # Keep unrelated conditions, discard earlier semantic ones.
        retained = [
            condition
            for condition in current_filter.must_not
            if getattr(condition, "key", None) not in (topics_key, entities_key)
        ]
        current_filter.must_not = retained
        current_filter.must_not.extend(deny_conditions)
    else:
        current_filter.must_not = deny_conditions
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_qdrant_authorization_filter(
    search_kwargs: dict, auth_context: Optional[AuthContext]
) -> None:
    """
    Set identity enforcement filter in search_kwargs for Qdrant vectorstore.

    Adds a 'must' condition matching 'metadata.authorized_identities' against
    any entry of auth_context.user_auth. A condition already present on that
    key is replaced. Nothing is changed when auth_context is None.

    Raises:
        ValueError: If qdrant-client cannot be imported.
        TypeError: If an existing filter is not a qdrant-client Filter.
    """
    try:
        from qdrant_client.http import models as rest
    except ImportError as e:
        raise ValueError(
            "Could not import `qdrant-client.http` python package. "
            "Please install it with `pip install qdrant-client`."
        ) from e

    if auth_context is None:
        return

    identity_key = "metadata.authorized_identities"
    identity_condition = rest.FieldCondition(
        key=identity_key,
        match=rest.MatchAny(any=auth_context.user_auth),
    )

    # No filter yet: create one that carries only the identity condition.
    if "filter" not in search_kwargs:
        search_kwargs["filter"] = rest.Filter(must=[identity_condition])
        return

    current_filter = search_kwargs["filter"]
    if not isinstance(current_filter, rest.Filter):
        raise TypeError(
            "Using dict as a `filter` is deprecated. "
            "Please use qdrant-client filters directly: "
            "https://qdrant.tech/documentation/concepts/filtering/"
        )

    if current_filter.must:
        # Keep unrelated conditions, discard an earlier identity condition.
        retained = [
            condition
            for condition in current_filter.must
            if not (hasattr(condition, "key") and condition.key == identity_key)
        ]
        current_filter.must = retained
        current_filter.must.append(identity_condition)
    else:
        current_filter.must = [identity_condition]
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_pinecone_semantic_filter(
    search_kwargs: dict, semantic_context: Optional[SemanticContext]
) -> None:
    """
    Set semantic enforcement filter in search_kwargs for Pinecone vectorstore.

    For each restriction present in the context, a "$nin" clause on the
    matching metadata field is written into search_kwargs["filter"], which is
    created if missing. Nothing is changed when semantic_context is None.
    """
    if semantic_context is None:
        return

    if semantic_context.pebblo_semantic_topics is not None:
        topics_clause = {"$nin": semantic_context.pebblo_semantic_topics.deny}
        search_kwargs.setdefault("filter", {})[
            "pebblo_semantic_topics"
        ] = topics_clause

    if semantic_context.pebblo_semantic_entities is not None:
        entities_clause = {"$nin": semantic_context.pebblo_semantic_entities.deny}
        search_kwargs.setdefault("filter", {})[
            "pebblo_semantic_entities"
        ] = entities_clause
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_pinecone_authorization_filter(
    search_kwargs: dict, auth_context: Optional[AuthContext]
) -> None:
    """
    Set identity enforcement filter in search_kwargs for Pinecone vectorstore.

    Writes an "$in" clause on 'authorized_identities' built from
    auth_context.user_auth into search_kwargs["filter"], which is created if
    missing. Nothing is changed when auth_context is None.
    """
    if auth_context is None:
        return
    identity_clause = {"$in": auth_context.user_auth}
    search_kwargs.setdefault("filter", {})["authorized_identities"] = identity_clause
|
||||||
|
|
||||||
|
|
||||||
|
def _set_identity_enforcement_filter(
    retriever: VectorStoreRetriever, auth_context: Optional[AuthContext]
) -> None:
    """
    Set identity enforcement filter in search_kwargs.

    This method sets the identity enforcement filter in the search_kwargs
    of the retriever based on the type of the vectorstore. For a vectorstore
    that is neither Pinecone nor Qdrant no filter is set and a warning is
    logged, so that skipped enforcement does not go unnoticed.
    """
    search_kwargs = retriever.search_kwargs
    if isinstance(retriever.vectorstore, Pinecone):
        _apply_pinecone_authorization_filter(search_kwargs, auth_context)
    elif isinstance(retriever.vectorstore, Qdrant):
        _apply_qdrant_authorization_filter(search_kwargs, auth_context)
    else:
        # Previously this case fell through silently.
        logger.warning(
            "Identity enforcement filter not applied: unsupported vectorstore %s",
            type(retriever.vectorstore).__name__,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _set_semantic_enforcement_filter(
    retriever: VectorStoreRetriever, semantic_context: Optional[SemanticContext]
) -> None:
    """
    Set semantic enforcement filter in search_kwargs.

    This method sets the semantic enforcement filter in the search_kwargs
    of the retriever based on the type of the vectorstore. For a vectorstore
    that is neither Pinecone nor Qdrant no filter is set and a warning is
    logged, so that skipped enforcement does not go unnoticed.
    """
    search_kwargs = retriever.search_kwargs
    if isinstance(retriever.vectorstore, Pinecone):
        _apply_pinecone_semantic_filter(search_kwargs, semantic_context)
    elif isinstance(retriever.vectorstore, Qdrant):
        _apply_qdrant_semantic_filter(search_kwargs, semantic_context)
    else:
        # Previously this case fell through silently.
        logger.warning(
            "Semantic enforcement filter not applied: unsupported vectorstore %s",
            type(retriever.vectorstore).__name__,
        )
|
@ -0,0 +1,62 @@
|
|||||||
|
"""Models for the PebbloRetrievalQA chain."""
|
||||||
|
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class AuthContext(BaseModel):
    """Authorization context describing the user a query is run for."""

    name: Optional[str] = None
    user_id: str
    # List of user authorizations, which may include their User ID and
    # the groups they are part of.
    user_auth: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticEntities(BaseModel):
    """Semantic entity filter, expressed as a list of denied entities."""

    deny: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticTopics(BaseModel):
    """Semantic topic filter, expressed as a list of denied topics."""

    deny: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticContext(BaseModel):
    """Semantic context made of an entity filter and/or a topic filter."""

    pebblo_semantic_entities: Optional[SemanticEntities] = None
    pebblo_semantic_topics: Optional[SemanticTopics] = None

    def __init__(self, **data: Any) -> None:
        """Build the model, then require that at least one filter is set.

        Raises:
            ValueError: If both filters are None.
        """
        super().__init__(**data)

        provided = (self.pebblo_semantic_entities, self.pebblo_semantic_topics)
        if all(restriction is None for restriction in provided):
            raise ValueError(
                "semantic_context must contain 'pebblo_semantic_entities' or "
                "'pebblo_semantic_topics'"
            )
|
||||||
|
|
||||||
|
|
||||||
|
class ChainInput(BaseModel):
    """Input for PebbloRetrievalQA chain."""

    query: str
    auth_context: Optional[AuthContext] = None
    semantic_context: Optional[SemanticContext] = None

    def dict(self, **kwargs: Any) -> dict:
        """Return the fields as a dict, keeping both contexts as model objects.

        The result of the parent ``dict`` is used as the base; 'auth_context'
        and 'semantic_context' are then set to the attribute values themselves
        instead of their dict form.
        """
        serialized = super().dict(**kwargs)
        serialized.update(
            auth_context=self.auth_context,
            semantic_context=self.semantic_context,
        )
        return serialized
|
129
libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py
Normal file
129
libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the PebbloRetrievalQA chain
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
|
CallbackManagerForRetrieverRun,
|
||||||
|
)
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
|
||||||
|
|
||||||
|
from langchain_community.chains import PebbloRetrievalQA
|
||||||
|
from langchain_community.chains.pebblo_retrieval.models import (
|
||||||
|
AuthContext,
|
||||||
|
ChainInput,
|
||||||
|
SemanticContext,
|
||||||
|
)
|
||||||
|
from langchain_community.vectorstores.chroma import Chroma
|
||||||
|
from langchain_community.vectorstores.pinecone import Pinecone
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
class FakeRetriever(VectorStoreRetriever):
    """
    Retriever test double that echoes the query back as a single document
    """

    vectorstore: VectorStore = Mock()

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        echoed = Document(page_content=query)
        return [echoed]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        echoed = Document(page_content=query)
        return [echoed]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def unsupported_retriever() -> FakeRetriever:
    """
    Create a FakeRetriever whose vectorstore reports the Chroma class
    """
    fake = FakeRetriever()
    fake.search_kwargs = {}
    # Make the mocked vectorstore pass isinstance checks for Chroma
    fake.vectorstore.__class__ = Chroma
    return fake
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def retriever() -> FakeRetriever:
    """
    Create a FakeRetriever whose vectorstore reports the Pinecone class
    """
    fake = FakeRetriever()
    fake.search_kwargs = {}
    # Make the mocked vectorstore pass isinstance checks for Pinecone
    fake.vectorstore.__class__ = Pinecone
    return fake
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA:
    """
    Create a PebbloRetrievalQA "stuff" chain backed by FakeLLM and the
    Pinecone-flavoured fake retriever
    """
    return PebbloRetrievalQA.from_chain_type(
        llm=FakeLLM(), chain_type="stuff", retriever=retriever
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke(pebblo_retrieval_qa: PebbloRetrievalQA) -> None:
    """
    Test that invoking the chain with auth and semantic contexts returns a
    non-None response
    """
    question = "What is the meaning of life?"

    # Fake identity: one user belonging to two groups
    fake_auth = AuthContext(
        user_id="fake_user@email.com",
        user_auth=["fake-group", "fake-group2"],
    )

    # Fake semantic restrictions: one denied topic and one denied entity
    semantic_restrictions = {
        "pebblo_semantic_topics": {"deny": ["harmful-advice"]},
        "pebblo_semantic_entities": {"deny": ["credit-card"]},
    }
    fake_semantic = SemanticContext(**semantic_restrictions)

    chain_input = ChainInput(
        query=question, auth_context=fake_auth, semantic_context=fake_semantic
    )
    response = pebblo_retrieval_qa.invoke(chain_input.dict())
    assert response is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_vectorstore(
    retriever: FakeRetriever, unsupported_retriever: FakeRetriever
) -> None:
    """
    Test that only supported vectorstores pass validation
    """
    expected_message = (
        "Vectorstore must be an instance of one of the supported vectorstores"
    )

    # A supported vectorstore (Pinecone) must be accepted without error
    PebbloRetrievalQA.from_chain_type(
        llm=FakeLLM(),
        chain_type="stuff",
        retriever=retriever,
    )

    # An unsupported vectorstore (Chroma) must be rejected with a ValueError
    with pytest.raises(ValueError) as exc_info:
        PebbloRetrievalQA.from_chain_type(
            llm=FakeLLM(),
            chain_type="stuff",
            retriever=unsupported_retriever,
        )
    assert expected_message in str(exc_info.value)
|
Loading…
Reference in New Issue
Block a user