mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-02 13:08:57 +00:00
- **Description:** PebbloRetrievalQA chain introduces identity enforcement using vector-db metadata filtering - **Dependencies:** None - **Issue:** None - **Documentation:** Adding documentation for PebbloRetrievalQA chain in a separate PR (https://github.com/langchain-ai/langchain/pull/20746) - **Unit tests:** New unit tests added --------- Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
219 lines
7.5 KiB
Python
219 lines
7.5 KiB
Python
"""
|
|
Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answering
|
|
against a vector database.
|
|
"""
|
|
|
|
import inspect
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForChainRun,
|
|
CallbackManagerForChainRun,
|
|
)
|
|
from langchain_core.documents import Document
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
from langchain_core.pydantic_v1 import Extra, Field, validator
|
|
from langchain_core.vectorstores import VectorStoreRetriever
|
|
|
|
from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
|
|
SUPPORTED_VECTORSTORES,
|
|
set_enforcement_filters,
|
|
)
|
|
from langchain_community.chains.pebblo_retrieval.models import (
|
|
AuthContext,
|
|
SemanticContext,
|
|
)
|
|
|
|
|
|
class PebbloRetrievalQA(Chain):
    """
    Question-answering chain over a vector database that applies identity
    and semantic enforcement filters to the retriever before retrieval.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain that turns the retrieved documents into an answer."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Whether the retrieved documents are included in the output."""

    retriever: VectorStoreRetriever = Field(exclude=True)
    """Vector store retriever used to fetch the documents."""
    auth_context_key: str = "auth_context"  #: :meta private:
    """Input key for the identity-enforcement auth context."""
    semantic_context_key: str = "semantic_context"  #: :meta private:
    """Input key for the semantic-enforcement context."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve documents for the query and combine them into an answer.

        The answer is returned under ``output_key``. When
        'return_source_documents' is 'True', the retrieved documents are
        returned as well under the key 'source_documents'.

        Example:
        .. code-block:: python

        res = indexqa({'query': 'This is my query'})
        answer, docs = res['result'], res['source_documents']
        """
        manager = (
            run_manager
            if run_manager
            else CallbackManagerForChainRun.get_noop_manager()
        )
        query = inputs[self.input_key]
        auth = inputs.get(self.auth_context_key)
        semantic = inputs.get(self.semantic_context_key)

        # Forward the run manager only if ``_get_docs`` declares a
        # ``run_manager`` parameter in its signature.
        fetch = self._get_docs
        extra: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(fetch).parameters:
            extra["run_manager"] = manager
        documents = fetch(query, auth, semantic, **extra)

        answer = self.combine_documents_chain.run(
            input_documents=documents,
            question=query,
            callbacks=manager.get_child(),
        )

        result: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            result["source_documents"] = documents
        return result

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously retrieve documents and combine them into an answer.

        The answer is returned under ``output_key``. When
        'return_source_documents' is 'True', the retrieved documents are
        returned as well under the key 'source_documents'.
        """
        manager = (
            run_manager
            if run_manager
            else AsyncCallbackManagerForChainRun.get_noop_manager()
        )
        query = inputs[self.input_key]
        auth = inputs.get(self.auth_context_key)
        semantic = inputs.get(self.semantic_context_key)

        # Forward the run manager only if ``_aget_docs`` declares a
        # ``run_manager`` parameter in its signature.
        fetch = self._aget_docs
        extra: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(fetch).parameters:
            extra["run_manager"] = manager
        documents = await fetch(query, auth, semantic, **extra)

        answer = await self.combine_documents_chain.arun(
            input_documents=documents,
            question=query,
            callbacks=manager.get_child(),
        )

        result: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            result["source_documents"] = documents
        return result

    class Config:
        """Pydantic configuration for this chain."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Names of the query, auth-context and semantic-context input keys.

        :meta private:
        """
        return [self.input_key, self.auth_context_key, self.semantic_context_key]

    @property
    def output_keys(self) -> List[str]:
        """Names of the output keys this chain produces.

        :meta private:
        """
        if self.return_source_documents:
            return [self.output_key, "source_documents"]
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "pebblo_retrieval_qa"

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "PebbloRetrievalQA":
        """Build the chain around a QA chain loaded for ``chain_type``.

        ``chain_type_kwargs`` are forwarded to ``load_qa_chain``; remaining
        keyword arguments are forwarded to the class constructor.
        """
        from langchain.chains.question_answering import load_qa_chain

        qa_chain = load_qa_chain(
            llm, chain_type=chain_type, **(chain_type_kwargs or {})
        )
        return cls(combine_documents_chain=qa_chain, **kwargs)

    @validator("retriever", pre=True, always=True)
    def validate_vectorstore(
        cls, retriever: VectorStoreRetriever
    ) -> VectorStoreRetriever:
        """
        Check that the retriever's vectorstore is an instance of one of the
        classes in SUPPORTED_VECTORSTORES; raise ValueError otherwise.
        """
        for supported_class in SUPPORTED_VECTORSTORES:
            if isinstance(retriever.vectorstore, supported_class):
                return retriever
        raise ValueError(
            f"Vectorstore must be an instance of one of the supported "
            f"vectorstores: {SUPPORTED_VECTORSTORES}. "
            f"Got {type(retriever.vectorstore).__name__} instead."
        )

    def _get_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Apply the enforcement filters, then fetch documents for the question."""
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        child_callbacks = run_manager.get_child()
        return self.retriever.get_relevant_documents(
            question, callbacks=child_callbacks
        )

    async def _aget_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Apply the enforcement filters, then fetch documents asynchronously."""
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        child_callbacks = run_manager.get_child()
        return await self.retriever.aget_relevant_documents(
            question, callbacks=child_callbacks
        )
|