mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
community[patch]: Add support for filtering AWS Kendra search results by score confidence (#12920)
**Description:** Adds support for filtering Kendra search results by score confidence, which makes the results more accurate. For example ``` retriever = AmazonKendraRetriever( index_id=kendra_index_id, top_k=5, region_name=region, min_score_confidence=0.75 ) ``` The threshold is a float between 0.0 and 1.0; Kendra's confidence labels map to NOT_AVAILABLE=0.0, LOW=0.25, MEDIUM=0.50, HIGH=0.75 and VERY_HIGH=1.0. With a threshold of 0.75 (HIGH), the results will not include records whose score confidence is "NOT_AVAILABLE", "LOW" or "MEDIUM". Relevant docs https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kendra/client/query.html https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kendra/client/retrieve.html **Issue:** Resolves #11801 **twitter:** [@SmitCode](https://twitter.com/SmitCode)
This commit is contained in:
parent
390ef6abe3
commit
aed46cd6f2
@ -1,11 +1,27 @@
|
|||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator, validator
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Extra,
|
||||||
|
Field,
|
||||||
|
root_validator,
|
||||||
|
validator,
|
||||||
|
)
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
def clean_excerpt(excerpt: str) -> str:
|
def clean_excerpt(excerpt: str) -> str:
|
||||||
@ -153,6 +169,8 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg]
|
|||||||
"""The document URI."""
|
"""The document URI."""
|
||||||
DocumentAttributes: Optional[List[DocumentAttribute]] = []
|
DocumentAttributes: Optional[List[DocumentAttribute]] = []
|
||||||
"""The document attributes."""
|
"""The document attributes."""
|
||||||
|
ScoreAttributes: Optional[dict]
|
||||||
|
"""The kendra score confidence"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_title(self) -> str:
|
def get_title(self) -> str:
|
||||||
@ -178,6 +196,13 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg]
|
|||||||
"""Document attributes dict."""
|
"""Document attributes dict."""
|
||||||
return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}
|
return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}
|
||||||
|
|
||||||
|
def get_score_attribute(self) -> str:
    """Return the Kendra score-confidence label for this result item.

    Returns:
        The value stored under ``"ScoreConfidence"`` in ``ScoreAttributes``,
        or ``"NOT_AVAILABLE"`` when ``ScoreAttributes`` is ``None`` or does
        not contain that key.
    """
    if self.ScoreAttributes is None:
        return "NOT_AVAILABLE"
    # Use .get so a ScoreAttributes dict that lacks the key falls back to the
    # same sentinel instead of raising KeyError.
    return self.ScoreAttributes.get("ScoreConfidence", "NOT_AVAILABLE")
|
||||||
|
|
||||||
def to_doc(
|
def to_doc(
|
||||||
self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
|
self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
|
||||||
) -> Document:
|
) -> Document:
|
||||||
@ -192,9 +217,9 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg]
|
|||||||
"title": self.get_title(),
|
"title": self.get_title(),
|
||||||
"excerpt": self.get_excerpt(),
|
"excerpt": self.get_excerpt(),
|
||||||
"document_attributes": self.get_document_attributes_dict(),
|
"document_attributes": self.get_document_attributes_dict(),
|
||||||
|
"score": self.get_score_attribute(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return Document(page_content=page_content, metadata=metadata)
|
return Document(page_content=page_content, metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
@ -290,6 +315,15 @@ class RetrieveResult(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
|
|||||||
"""The result items."""
|
"""The result items."""
|
||||||
|
|
||||||
|
|
||||||
|
# Numeric rank for each Kendra score-confidence label; _filter_by_score_confidence
# compares these values against AmazonKendraRetriever.min_score_confidence.
KENDRA_CONFIDENCE_MAPPING = {
|
||||||
|
"NOT_AVAILABLE": 0.0,
|
||||||
|
"LOW": 0.25,
|
||||||
|
"MEDIUM": 0.50,
|
||||||
|
"HIGH": 0.75,
|
||||||
|
"VERY_HIGH": 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class AmazonKendraRetriever(BaseRetriever):
|
class AmazonKendraRetriever(BaseRetriever):
|
||||||
"""`Amazon Kendra Index` retriever.
|
"""`Amazon Kendra Index` retriever.
|
||||||
|
|
||||||
@ -336,6 +370,7 @@ class AmazonKendraRetriever(BaseRetriever):
|
|||||||
page_content_formatter: Callable[[ResultItem], str] = combined_text
|
page_content_formatter: Callable[[ResultItem], str] = combined_text
|
||||||
client: Any
|
client: Any
|
||||||
user_context: Optional[Dict] = None
|
user_context: Optional[Dict] = None
|
||||||
|
min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)]
|
||||||
|
|
||||||
@validator("top_k")
|
@validator("top_k")
|
||||||
def validate_top_k(cls, value: int) -> int:
|
def validate_top_k(cls, value: int) -> int:
|
||||||
@ -406,6 +441,25 @@ class AmazonKendraRetriever(BaseRetriever):
|
|||||||
]
|
]
|
||||||
return top_docs
|
return top_docs
|
||||||
|
|
||||||
|
def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
    """Drop documents whose score confidence is below the required threshold.

    Args:
        docs: Documents to filter. Each document's ``metadata["score"]`` is
            looked up in ``KENDRA_CONFIDENCE_MAPPING``.

    Returns:
        ``docs`` unchanged when ``min_score_confidence`` is unset (or zero).
        Otherwise, only the documents whose ``"score"`` metadata is a string
        that maps to a value greater than or equal to
        ``min_score_confidence``. Labels missing from the mapping rank as
        ``0.0``; documents with a missing or non-string score are dropped.
    """
    if not self.min_score_confidence:
        return docs
    threshold = self.min_score_confidence
    return [
        doc
        for doc in docs
        # isinstance() is also False for a missing/None score, so no separate
        # "is not None" check is needed before the mapping lookup.
        if isinstance(doc.metadata.get("score"), str)
        and KENDRA_CONFIDENCE_MAPPING.get(doc.metadata["score"], 0.0) >= threshold
    ]
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@ -422,4 +476,4 @@ class AmazonKendraRetriever(BaseRetriever):
|
|||||||
"""
|
"""
|
||||||
result_items = self._kendra_query(query)
|
result_items = self._kendra_query(query)
|
||||||
top_k_docs = self._get_top_k_docs(result_items)
|
top_k_docs = self._get_top_k_docs(result_items)
|
||||||
return top_k_docs
|
return self._filter_by_score_confidence(top_k_docs)
|
||||||
|
Loading…
Reference in New Issue
Block a user