Kendra retriever api (#6616)

## Description
Replaces [Kendra
Retriever](https://github.com/hwchase17/langchain/blob/master/langchain/retrievers/aws_kendra_index_retriever.py)
with an updated version that uses the new [retriever
API](https://docs.aws.amazon.com/kendra/latest/dg/searching-retrieve.html)
which is better suited for retrieval augmented generation (RAG) systems.

**Note**: This change requires the latest version (1.26.159) of boto3 to
work. `pip install -U boto3` to upgrade the boto3 version.

cc @hupe1980
cc @dev2049
This commit is contained in:
Piyush Jain 2023-06-23 14:59:35 -07:00 committed by GitHub
parent 4e5d78579b
commit b1de927f1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 279 additions and 105 deletions

View File

@ -26,7 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"#!pip install boto3"
"%pip install boto3"
]
},
{
@ -36,7 +36,7 @@
"outputs": [],
"source": [
"import boto3\n",
"from langchain.retrievers import AwsKendraIndexRetriever"
"from langchain.retrievers import AmazonKendraRetriever"
]
},
{
@ -53,11 +53,8 @@
"metadata": {},
"outputs": [],
"source": [
"kclient = boto3.client(\"kendra\", region_name=\"us-east-1\")\n",
"\n",
"retriever = AwsKendraIndexRetriever(\n",
" kclient=kclient,\n",
" kendraindex=\"kendraindex\",\n",
"retriever = AmazonKendraRetriever(\n",
" index_id=\"c0806df7-e76b-4bce-9b5c-d5582f6b1a03\"\n",
")"
]
},
@ -66,7 +63,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now you can use retrieved documents from AWS Kendra Index"
"Now you can use retrieved documents from Kendra index"
]
},
{

View File

@ -1,11 +1,11 @@
from langchain.retrievers.arxiv import ArxivRetriever
from langchain.retrievers.aws_kendra_index_retriever import AwsKendraIndexRetriever
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.databerry import DataberryRetriever
from langchain.retrievers.docarray import DocArrayRetriever
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.kendra import AmazonKendraRetriever
from langchain.retrievers.knn import KNNRetriever
from langchain.retrievers.llama_index import (
LlamaIndexGraphRetriever,
@ -30,8 +30,8 @@ from langchain.retrievers.zep import ZepRetriever
from langchain.retrievers.zilliz import ZillizRetriever
__all__ = [
"AmazonKendraRetriever",
"ArxivRetriever",
"AwsKendraIndexRetriever",
"AzureCognitiveSearchRetriever",
"ChatGPTPluginRetriever",
"ContextualCompressionRetriever",

View File

@ -1,95 +0,0 @@
"""Retriever wrapper for AWS Kendra."""
import re
from typing import Any, Dict, List
from langchain.schema import BaseRetriever, Document
class AwsKendraIndexRetriever(BaseRetriever):
"""Wrapper around AWS Kendra."""
kendraindex: str
"""Kendra index id"""
k: int
"""Number of documents to query for."""
languagecode: str
"""Languagecode used for querying."""
kclient: Any
""" boto3 client for Kendra. """
def __init__(
self, kclient: Any, kendraindex: str, k: int = 3, languagecode: str = "en"
):
self.kendraindex = kendraindex
self.k = k
self.languagecode = languagecode
self.kclient = kclient
def _clean_result(self, res_text: str) -> str:
return re.sub("\s+", " ", res_text).replace("...", "")
def _get_top_n_results(self, resp: Dict, count: int) -> Document:
r = resp["ResultItems"][count]
doc_title = r["DocumentTitle"]["Text"]
doc_uri = r["DocumentURI"]
r_type = r["Type"]
if (
r["AdditionalAttributes"]
and r["AdditionalAttributes"][0]["Key"] == "AnswerText"
):
res_text = r["AdditionalAttributes"][0]["Value"]["TextWithHighlightsValue"][
"Text"
]
else:
res_text = r["DocumentExcerpt"]["Text"]
doc_excerpt = self._clean_result(res_text)
combined_text = f"""Document Title: {doc_title}
Document Excerpt: {doc_excerpt}
"""
return Document(
page_content=combined_text,
metadata={
"source": doc_uri,
"title": doc_title,
"excerpt": doc_excerpt,
"type": r_type,
},
)
def _kendra_query(self, kquery: str) -> List[Document]:
response = self.kclient.query(
IndexId=self.kendraindex,
QueryText=kquery.strip(),
AttributeFilter={
"AndAllFilters": [
{
"EqualsTo": {
"Key": "_language_code",
"Value": {
"StringValue": self.languagecode,
},
}
}
]
},
)
if len(response["ResultItems"]) > self.k:
r_count = self.k
else:
r_count = len(response["ResultItems"])
return [self._get_top_n_results(response, i) for i in range(0, r_count)]
def get_relevant_documents(self, query: str) -> List[Document]:
"""Run search on Kendra index and get top k documents
docs = get_relevant_documents('This is my query')
"""
return self._kendra_query(query)
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError("AwsKendraIndexRetriever does not support async")

View File

@ -0,0 +1,272 @@
import re
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Extra
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
def clean_excerpt(excerpt: str) -> str:
if not excerpt:
return excerpt
res = re.sub("\s+", " ", excerpt).replace("...", "")
return res
def combined_text(title: str, excerpt: str) -> str:
if not title or not excerpt:
return ""
return f"Document Title: {title} \nDocument Excerpt: \n{excerpt}\n"
class Highlight(BaseModel, extra=Extra.allow):
BeginOffset: int
EndOffset: int
TopAnswer: Optional[bool]
Type: Optional[str]
class TextWithHighLights(BaseModel, extra=Extra.allow):
Text: str
Highlights: Optional[Any]
class AdditionalResultAttribute(BaseModel, extra=Extra.allow):
Key: str
ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
Value: Optional[TextWithHighLights]
def get_value_text(self) -> str:
if not self.Value:
return ""
else:
return self.Value.Text
class QueryResultItem(BaseModel, extra=Extra.allow):
DocumentId: str
DocumentTitle: TextWithHighLights
DocumentURI: Optional[str]
FeedbackToken: Optional[str]
Format: Optional[str]
Id: Optional[str]
Type: Optional[str]
AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = []
DocumentExcerpt: Optional[TextWithHighLights]
def get_attribute_value(self) -> str:
if not self.AdditionalAttributes:
return ""
if not self.AdditionalAttributes[0]:
return ""
else:
return self.AdditionalAttributes[0].get_value_text()
def get_excerpt(self) -> str:
if (
self.AdditionalAttributes
and self.AdditionalAttributes[0].Key == "AnswerText"
):
excerpt = self.get_attribute_value()
elif self.DocumentExcerpt:
excerpt = self.DocumentExcerpt.Text
else:
excerpt = ""
return clean_excerpt(excerpt)
def to_doc(self) -> Document:
title = self.DocumentTitle.Text
source = self.DocumentURI
excerpt = self.get_excerpt()
type = self.Type
page_content = combined_text(title, excerpt)
metadata = {"source": source, "title": title, "excerpt": excerpt, "type": type}
return Document(page_content=page_content, metadata=metadata)
class QueryResult(BaseModel, extra=Extra.allow):
ResultItems: List[QueryResultItem]
def get_top_k_docs(self, top_n: int) -> List[Document]:
items_len = len(self.ResultItems)
count = items_len if items_len < top_n else top_n
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
return docs
class DocumentAttributeValue(BaseModel, extra=Extra.allow):
DateValue: Optional[str]
LongValue: Optional[int]
StringListValue: Optional[List[str]]
StringValue: Optional[str]
class DocumentAttribute(BaseModel, extra=Extra.allow):
Key: str
Value: DocumentAttributeValue
class RetrieveResultItem(BaseModel, extra=Extra.allow):
Content: Optional[str]
DocumentAttributes: Optional[List[DocumentAttribute]] = []
DocumentId: Optional[str]
DocumentTitle: Optional[str]
DocumentURI: Optional[str]
Id: Optional[str]
def get_excerpt(self) -> str:
if not self.Content:
return ""
return clean_excerpt(self.Content)
def to_doc(self) -> Document:
title = self.DocumentTitle if self.DocumentTitle else ""
source = self.DocumentURI
excerpt = self.get_excerpt()
page_content = combined_text(title, excerpt)
metadata = {"source": source, "title": title, "excerpt": excerpt}
return Document(page_content=page_content, metadata=metadata)
class RetrieveResult(BaseModel, extra=Extra.allow):
QueryId: str
ResultItems: List[RetrieveResultItem]
def get_top_k_docs(self, top_n: int) -> List[Document]:
items_len = len(self.ResultItems)
count = items_len if items_len < top_n else top_n
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
return docs
class AmazonKendraRetriever(BaseRetriever):
"""Retriever class to query documents from Amazon Kendra Index.
Args:
index_id: Kendra index id
region_name: The aws region e.g., `us-west-2`.
Fallsback to AWS_DEFAULT_REGION env variable
or region specified in ~/.aws/config.
credentials_profile_name: The name of the profile in the ~/.aws/credentials
or ~/.aws/config files, which has either access keys or role information
specified. If not specified, the default credential profile or, if on an
EC2 instance, credentials from IMDS will be used.
top_k: No of results to return
attribute_filter: Additional filtering of results based on metadata
See: https://docs.aws.amazon.com/kendra/latest/APIReference
client: boto3 client for Kendra
Example:
.. code-block:: python
retriever = AmazonKendraRetriever(
index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03"
)
"""
def __init__(
self,
index_id: str,
region_name: Optional[str] = None,
credentials_profile_name: Optional[str] = None,
top_k: int = 3,
attribute_filter: Optional[Dict] = None,
client: Optional[Any] = None,
):
self.index_id = index_id
self.top_k = top_k
self.attribute_filter = attribute_filter
if client is not None:
self.client = client
return
try:
import boto3
if credentials_profile_name is not None:
session = boto3.Session(profile_name=credentials_profile_name)
else:
# use default credentials
session = boto3.Session()
client_params = {}
if region_name is not None:
client_params["region_name"] = region_name
self.client = session.client("kendra", **client_params)
except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e
def _kendra_query(
self,
query: str,
top_k: int,
attribute_filter: Optional[Dict] = None,
) -> List[Document]:
if attribute_filter is not None:
response = self.client.retrieve(
IndexId=self.index_id,
QueryText=query.strip(),
PageSize=top_k,
AttributeFilter=attribute_filter,
)
else:
response = self.client.retrieve(
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
)
r_result = RetrieveResult.parse_obj(response)
result_len = len(r_result.ResultItems)
if result_len == 0:
# retrieve API returned 0 results, call query API
if attribute_filter is not None:
response = self.client.query(
IndexId=self.index_id,
QueryText=query.strip(),
PageSize=top_k,
AttributeFilter=attribute_filter,
)
else:
response = self.client.query(
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
)
q_result = QueryResult.parse_obj(response)
docs = q_result.get_top_k_docs(top_k)
else:
docs = r_result.get_top_k_docs(top_k)
return docs
def get_relevant_documents(self, query: str) -> List[Document]:
"""Run search on Kendra index and get top k documents
Example:
.. code-block:: python
docs = retriever.get_relevant_documents('This is my query')
"""
docs = self._kendra_query(query, self.top_k, self.attribute_filter)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError("Async version is not implemented for Kendra yet.")