From aed46cd6f2ca856d699c5196ca1788de3288318a Mon Sep 17 00:00:00 2001 From: Smit Parmar Date: Fri, 8 Mar 2024 06:58:09 +0530 Subject: [PATCH] community[patch]: Add support for filtering AWS Kendra search results by score confidence (#12920) **Description:** Adds support for filtering Kendra search results by score confidence, which makes the results more accurate. For example ``` retriever = AmazonKendraRetriever( index_id=kendra_index_id, top_k=5, region_name=region, min_score_confidence=0.75 ) ``` The result will not include records whose score confidence is below "HIGH" (0.75), i.e. "NOT_AVAILABLE", "LOW" or "MEDIUM". Relevant docs: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kendra/client/query.html https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kendra/client/retrieve.html **Issue:** Resolves #11801 **twitter:** [@SmitCode](https://twitter.com/SmitCode) --- .../langchain_community/retrievers/kendra.py | 62 +++++++++++++++++-- 1 file changed, 58 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/retrievers/kendra.py b/libs/community/langchain_community/retrievers/kendra.py index e93ef3ff569..b4480cae62c 100644 --- a/libs/community/langchain_community/retrievers/kendra.py +++ b/libs/community/langchain_community/retrievers/kendra.py @@ -1,11 +1,27 @@ import re from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Union, +) from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator, validator +from langchain_core.pydantic_v1 import ( + BaseModel, + Extra, + Field, + root_validator, + validator, +) from langchain_core.retrievers import BaseRetriever +from typing_extensions import Annotated def clean_excerpt(excerpt: str) -> str: @@ 
-153,6 +169,8 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg] """The document URI.""" DocumentAttributes: Optional[List[DocumentAttribute]] = [] """The document attributes.""" + ScoreAttributes: Optional[dict] + """The kendra score confidence""" @abstractmethod def get_title(self) -> str: @@ -178,6 +196,13 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg] """Document attributes dict.""" return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])} + def get_score_attribute(self) -> str: + """Document Score Confidence""" + if self.ScoreAttributes is not None: + return self.ScoreAttributes["ScoreConfidence"] + else: + return "NOT_AVAILABLE" + def to_doc( self, page_content_formatter: Callable[["ResultItem"], str] = combined_text ) -> Document: @@ -192,9 +217,9 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg] "title": self.get_title(), "excerpt": self.get_excerpt(), "document_attributes": self.get_document_attributes_dict(), + "score": self.get_score_attribute(), } ) - return Document(page_content=page_content, metadata=metadata) @@ -290,6 +315,15 @@ class RetrieveResult(BaseModel, extra=Extra.allow): # type: ignore[call-arg] """The result items.""" +KENDRA_CONFIDENCE_MAPPING = { + "NOT_AVAILABLE": 0.0, + "LOW": 0.25, + "MEDIUM": 0.50, + "HIGH": 0.75, + "VERY_HIGH": 1.0, +} + + class AmazonKendraRetriever(BaseRetriever): """`Amazon Kendra Index` retriever. 
@@ -336,6 +370,7 @@ class AmazonKendraRetriever(BaseRetriever): page_content_formatter: Callable[[ResultItem], str] = combined_text client: Any user_context: Optional[Dict] = None + min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)] @validator("top_k") def validate_top_k(cls, value: int) -> int: @@ -406,6 +441,25 @@ class AmazonKendraRetriever(BaseRetriever): ] return top_docs + def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]: + """ + Filter out the records that have a score confidence + lower than the required threshold. + """ + if not self.min_score_confidence: + return docs + filtered_docs = [ + item + for item in docs + if ( + item.metadata.get("score") is not None + and isinstance(item.metadata["score"], str) + and KENDRA_CONFIDENCE_MAPPING.get(item.metadata["score"], 0.0) + >= self.min_score_confidence + ) + ] + return filtered_docs + def _get_relevant_documents( self, query: str, @@ -422,4 +476,4 @@ class AmazonKendraRetriever(BaseRetriever): """ result_items = self._kendra_query(query) top_k_docs = self._get_top_k_docs(result_items) - return top_k_docs + return self._filter_by_score_confidence(top_k_docs)