mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
community[patch]: Added support for filter out AWS Kendra search by score confidence (#12920)
**Description:** It will add support for filter out kendra search by score confidence which will make result more accurate. For example ``` retriever = AmazonKendraRetriever( index_id=kendra_index_id, top_k=5, region_name=region, score_confidence="HIGH" ) ``` Result will not include the records which has score confidence "LOW" or "MEDIUM". Relevant docs https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kendra/client/query.html https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kendra/client/retrieve.html **Issue:** the issue # it resolve #11801 **twitter:** [@SmitCode](https://twitter.com/SmitCode)
This commit is contained in:
parent
390ef6abe3
commit
aed46cd6f2
@ -1,11 +1,27 @@
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator, validator
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
root_validator,
|
||||
validator,
|
||||
)
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
def clean_excerpt(excerpt: str) -> str:
|
||||
@ -153,6 +169,8 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg]
|
||||
"""The document URI."""
|
||||
DocumentAttributes: Optional[List[DocumentAttribute]] = []
|
||||
"""The document attributes."""
|
||||
ScoreAttributes: Optional[dict]
|
||||
"""The kendra score confidence"""
|
||||
|
||||
@abstractmethod
|
||||
def get_title(self) -> str:
|
||||
@ -178,6 +196,13 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg]
|
||||
"""Document attributes dict."""
|
||||
return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}
|
||||
|
||||
def get_score_attribute(self) -> str:
|
||||
"""Document Score Confidence"""
|
||||
if self.ScoreAttributes is not None:
|
||||
return self.ScoreAttributes["ScoreConfidence"]
|
||||
else:
|
||||
return "NOT_AVAILABLE"
|
||||
|
||||
def to_doc(
|
||||
self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
|
||||
) -> Document:
|
||||
@ -192,9 +217,9 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg]
|
||||
"title": self.get_title(),
|
||||
"excerpt": self.get_excerpt(),
|
||||
"document_attributes": self.get_document_attributes_dict(),
|
||||
"score": self.get_score_attribute(),
|
||||
}
|
||||
)
|
||||
|
||||
return Document(page_content=page_content, metadata=metadata)
|
||||
|
||||
|
||||
@ -290,6 +315,15 @@ class RetrieveResult(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
|
||||
"""The result items."""
|
||||
|
||||
|
||||
KENDRA_CONFIDENCE_MAPPING = {
|
||||
"NOT_AVAILABLE": 0.0,
|
||||
"LOW": 0.25,
|
||||
"MEDIUM": 0.50,
|
||||
"HIGH": 0.75,
|
||||
"VERY_HIGH": 1.0,
|
||||
}
|
||||
|
||||
|
||||
class AmazonKendraRetriever(BaseRetriever):
|
||||
"""`Amazon Kendra Index` retriever.
|
||||
|
||||
@ -336,6 +370,7 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
page_content_formatter: Callable[[ResultItem], str] = combined_text
|
||||
client: Any
|
||||
user_context: Optional[Dict] = None
|
||||
min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)]
|
||||
|
||||
@validator("top_k")
|
||||
def validate_top_k(cls, value: int) -> int:
|
||||
@ -406,6 +441,25 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
]
|
||||
return top_docs
|
||||
|
||||
def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
|
||||
"""
|
||||
Filter out the records that have a score confidence
|
||||
greater than the required threshold.
|
||||
"""
|
||||
if not self.min_score_confidence:
|
||||
return docs
|
||||
filtered_docs = [
|
||||
item
|
||||
for item in docs
|
||||
if (
|
||||
item.metadata.get("score") is not None
|
||||
and isinstance(item.metadata["score"], str)
|
||||
and KENDRA_CONFIDENCE_MAPPING.get(item.metadata["score"], 0.0)
|
||||
>= self.min_score_confidence
|
||||
)
|
||||
]
|
||||
return filtered_docs
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
@ -422,4 +476,4 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
"""
|
||||
result_items = self._kendra_query(query)
|
||||
top_k_docs = self._get_top_k_docs(result_items)
|
||||
return top_k_docs
|
||||
return self._filter_by_score_confidence(top_k_docs)
|
||||
|
Loading…
Reference in New Issue
Block a user