mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 01:19:31 +00:00
Kendra retriever api (#6616)
## Description Replaces [Kendra Retriever](https://github.com/hwchase17/langchain/blob/master/langchain/retrievers/aws_kendra_index_retriever.py) with an updated version that uses the new [retriever API](https://docs.aws.amazon.com/kendra/latest/dg/searching-retrieve.html) which is better suited for retrieval augmented generation (RAG) systems. **Note**: This change requires the latest version (1.26.159) of boto3 to work. `pip install -U boto3` to upgrade the boto3 version. cc @hupe1980 cc @dev2049
This commit is contained in:
parent
4e5d78579b
commit
b1de927f1b
@ -26,7 +26,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install boto3"
|
||||
"%pip install boto3"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -36,7 +36,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import boto3\n",
|
||||
"from langchain.retrievers import AwsKendraIndexRetriever"
|
||||
"from langchain.retrievers import AmazonKendraRetriever"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -53,11 +53,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kclient = boto3.client(\"kendra\", region_name=\"us-east-1\")\n",
|
||||
"\n",
|
||||
"retriever = AwsKendraIndexRetriever(\n",
|
||||
" kclient=kclient,\n",
|
||||
" kendraindex=\"kendraindex\",\n",
|
||||
"retriever = AmazonKendraRetriever(\n",
|
||||
" index_id=\"c0806df7-e76b-4bce-9b5c-d5582f6b1a03\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -66,7 +63,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now you can use retrieved documents from AWS Kendra Index"
|
||||
"Now you can use retrieved documents from Kendra index"
|
||||
]
|
||||
},
|
||||
{
|
@ -1,11 +1,11 @@
|
||||
from langchain.retrievers.arxiv import ArxivRetriever
|
||||
from langchain.retrievers.aws_kendra_index_retriever import AwsKendraIndexRetriever
|
||||
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
|
||||
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
||||
from langchain.retrievers.databerry import DataberryRetriever
|
||||
from langchain.retrievers.docarray import DocArrayRetriever
|
||||
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
|
||||
from langchain.retrievers.kendra import AmazonKendraRetriever
|
||||
from langchain.retrievers.knn import KNNRetriever
|
||||
from langchain.retrievers.llama_index import (
|
||||
LlamaIndexGraphRetriever,
|
||||
@ -30,8 +30,8 @@ from langchain.retrievers.zep import ZepRetriever
|
||||
from langchain.retrievers.zilliz import ZillizRetriever
|
||||
|
||||
__all__ = [
|
||||
"AmazonKendraRetriever",
|
||||
"ArxivRetriever",
|
||||
"AwsKendraIndexRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
"ChatGPTPluginRetriever",
|
||||
"ContextualCompressionRetriever",
|
||||
|
@ -1,95 +0,0 @@
|
||||
"""Retriever wrapper for AWS Kendra."""
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class AwsKendraIndexRetriever(BaseRetriever):
|
||||
"""Wrapper around AWS Kendra."""
|
||||
|
||||
kendraindex: str
|
||||
"""Kendra index id"""
|
||||
k: int
|
||||
"""Number of documents to query for."""
|
||||
languagecode: str
|
||||
"""Languagecode used for querying."""
|
||||
kclient: Any
|
||||
""" boto3 client for Kendra. """
|
||||
|
||||
def __init__(
|
||||
self, kclient: Any, kendraindex: str, k: int = 3, languagecode: str = "en"
|
||||
):
|
||||
self.kendraindex = kendraindex
|
||||
self.k = k
|
||||
self.languagecode = languagecode
|
||||
self.kclient = kclient
|
||||
|
||||
def _clean_result(self, res_text: str) -> str:
|
||||
return re.sub("\s+", " ", res_text).replace("...", "")
|
||||
|
||||
def _get_top_n_results(self, resp: Dict, count: int) -> Document:
|
||||
r = resp["ResultItems"][count]
|
||||
doc_title = r["DocumentTitle"]["Text"]
|
||||
doc_uri = r["DocumentURI"]
|
||||
r_type = r["Type"]
|
||||
|
||||
if (
|
||||
r["AdditionalAttributes"]
|
||||
and r["AdditionalAttributes"][0]["Key"] == "AnswerText"
|
||||
):
|
||||
res_text = r["AdditionalAttributes"][0]["Value"]["TextWithHighlightsValue"][
|
||||
"Text"
|
||||
]
|
||||
else:
|
||||
res_text = r["DocumentExcerpt"]["Text"]
|
||||
|
||||
doc_excerpt = self._clean_result(res_text)
|
||||
combined_text = f"""Document Title: {doc_title}
|
||||
Document Excerpt: {doc_excerpt}
|
||||
"""
|
||||
|
||||
return Document(
|
||||
page_content=combined_text,
|
||||
metadata={
|
||||
"source": doc_uri,
|
||||
"title": doc_title,
|
||||
"excerpt": doc_excerpt,
|
||||
"type": r_type,
|
||||
},
|
||||
)
|
||||
|
||||
def _kendra_query(self, kquery: str) -> List[Document]:
|
||||
response = self.kclient.query(
|
||||
IndexId=self.kendraindex,
|
||||
QueryText=kquery.strip(),
|
||||
AttributeFilter={
|
||||
"AndAllFilters": [
|
||||
{
|
||||
"EqualsTo": {
|
||||
"Key": "_language_code",
|
||||
"Value": {
|
||||
"StringValue": self.languagecode,
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
if len(response["ResultItems"]) > self.k:
|
||||
r_count = self.k
|
||||
else:
|
||||
r_count = len(response["ResultItems"])
|
||||
|
||||
return [self._get_top_n_results(response, i) for i in range(0, r_count)]
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
"""Run search on Kendra index and get top k documents
|
||||
|
||||
docs = get_relevant_documents('This is my query')
|
||||
"""
|
||||
return self._kendra_query(query)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError("AwsKendraIndexRetriever does not support async")
|
272
langchain/retrievers/kendra.py
Normal file
272
langchain/retrievers/kendra.py
Normal file
@ -0,0 +1,272 @@
|
||||
import re
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
|
||||
def clean_excerpt(excerpt: str) -> str:
|
||||
if not excerpt:
|
||||
return excerpt
|
||||
res = re.sub("\s+", " ", excerpt).replace("...", "")
|
||||
return res
|
||||
|
||||
|
||||
def combined_text(title: str, excerpt: str) -> str:
|
||||
if not title or not excerpt:
|
||||
return ""
|
||||
return f"Document Title: {title} \nDocument Excerpt: \n{excerpt}\n"
|
||||
|
||||
|
||||
class Highlight(BaseModel, extra=Extra.allow):
|
||||
BeginOffset: int
|
||||
EndOffset: int
|
||||
TopAnswer: Optional[bool]
|
||||
Type: Optional[str]
|
||||
|
||||
|
||||
class TextWithHighLights(BaseModel, extra=Extra.allow):
|
||||
Text: str
|
||||
Highlights: Optional[Any]
|
||||
|
||||
|
||||
class AdditionalResultAttribute(BaseModel, extra=Extra.allow):
|
||||
Key: str
|
||||
ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
|
||||
Value: Optional[TextWithHighLights]
|
||||
|
||||
def get_value_text(self) -> str:
|
||||
if not self.Value:
|
||||
return ""
|
||||
else:
|
||||
return self.Value.Text
|
||||
|
||||
|
||||
class QueryResultItem(BaseModel, extra=Extra.allow):
|
||||
DocumentId: str
|
||||
DocumentTitle: TextWithHighLights
|
||||
DocumentURI: Optional[str]
|
||||
FeedbackToken: Optional[str]
|
||||
Format: Optional[str]
|
||||
Id: Optional[str]
|
||||
Type: Optional[str]
|
||||
AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = []
|
||||
DocumentExcerpt: Optional[TextWithHighLights]
|
||||
|
||||
def get_attribute_value(self) -> str:
|
||||
if not self.AdditionalAttributes:
|
||||
return ""
|
||||
if not self.AdditionalAttributes[0]:
|
||||
return ""
|
||||
else:
|
||||
return self.AdditionalAttributes[0].get_value_text()
|
||||
|
||||
def get_excerpt(self) -> str:
|
||||
if (
|
||||
self.AdditionalAttributes
|
||||
and self.AdditionalAttributes[0].Key == "AnswerText"
|
||||
):
|
||||
excerpt = self.get_attribute_value()
|
||||
elif self.DocumentExcerpt:
|
||||
excerpt = self.DocumentExcerpt.Text
|
||||
else:
|
||||
excerpt = ""
|
||||
|
||||
return clean_excerpt(excerpt)
|
||||
|
||||
def to_doc(self) -> Document:
|
||||
title = self.DocumentTitle.Text
|
||||
source = self.DocumentURI
|
||||
excerpt = self.get_excerpt()
|
||||
type = self.Type
|
||||
page_content = combined_text(title, excerpt)
|
||||
metadata = {"source": source, "title": title, "excerpt": excerpt, "type": type}
|
||||
return Document(page_content=page_content, metadata=metadata)
|
||||
|
||||
|
||||
class QueryResult(BaseModel, extra=Extra.allow):
|
||||
ResultItems: List[QueryResultItem]
|
||||
|
||||
def get_top_k_docs(self, top_n: int) -> List[Document]:
|
||||
items_len = len(self.ResultItems)
|
||||
count = items_len if items_len < top_n else top_n
|
||||
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class DocumentAttributeValue(BaseModel, extra=Extra.allow):
|
||||
DateValue: Optional[str]
|
||||
LongValue: Optional[int]
|
||||
StringListValue: Optional[List[str]]
|
||||
StringValue: Optional[str]
|
||||
|
||||
|
||||
class DocumentAttribute(BaseModel, extra=Extra.allow):
|
||||
Key: str
|
||||
Value: DocumentAttributeValue
|
||||
|
||||
|
||||
class RetrieveResultItem(BaseModel, extra=Extra.allow):
|
||||
Content: Optional[str]
|
||||
DocumentAttributes: Optional[List[DocumentAttribute]] = []
|
||||
DocumentId: Optional[str]
|
||||
DocumentTitle: Optional[str]
|
||||
DocumentURI: Optional[str]
|
||||
Id: Optional[str]
|
||||
|
||||
def get_excerpt(self) -> str:
|
||||
if not self.Content:
|
||||
return ""
|
||||
return clean_excerpt(self.Content)
|
||||
|
||||
def to_doc(self) -> Document:
|
||||
title = self.DocumentTitle if self.DocumentTitle else ""
|
||||
source = self.DocumentURI
|
||||
excerpt = self.get_excerpt()
|
||||
page_content = combined_text(title, excerpt)
|
||||
metadata = {"source": source, "title": title, "excerpt": excerpt}
|
||||
return Document(page_content=page_content, metadata=metadata)
|
||||
|
||||
|
||||
class RetrieveResult(BaseModel, extra=Extra.allow):
|
||||
QueryId: str
|
||||
ResultItems: List[RetrieveResultItem]
|
||||
|
||||
def get_top_k_docs(self, top_n: int) -> List[Document]:
|
||||
items_len = len(self.ResultItems)
|
||||
count = items_len if items_len < top_n else top_n
|
||||
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class AmazonKendraRetriever(BaseRetriever):
|
||||
"""Retriever class to query documents from Amazon Kendra Index.
|
||||
|
||||
Args:
|
||||
index_id: Kendra index id
|
||||
|
||||
region_name: The aws region e.g., `us-west-2`.
|
||||
Fallsback to AWS_DEFAULT_REGION env variable
|
||||
or region specified in ~/.aws/config.
|
||||
|
||||
credentials_profile_name: The name of the profile in the ~/.aws/credentials
|
||||
or ~/.aws/config files, which has either access keys or role information
|
||||
specified. If not specified, the default credential profile or, if on an
|
||||
EC2 instance, credentials from IMDS will be used.
|
||||
|
||||
top_k: No of results to return
|
||||
|
||||
attribute_filter: Additional filtering of results based on metadata
|
||||
See: https://docs.aws.amazon.com/kendra/latest/APIReference
|
||||
|
||||
client: boto3 client for Kendra
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
retriever = AmazonKendraRetriever(
|
||||
index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03"
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_id: str,
|
||||
region_name: Optional[str] = None,
|
||||
credentials_profile_name: Optional[str] = None,
|
||||
top_k: int = 3,
|
||||
attribute_filter: Optional[Dict] = None,
|
||||
client: Optional[Any] = None,
|
||||
):
|
||||
self.index_id = index_id
|
||||
self.top_k = top_k
|
||||
self.attribute_filter = attribute_filter
|
||||
|
||||
if client is not None:
|
||||
self.client = client
|
||||
return
|
||||
|
||||
try:
|
||||
import boto3
|
||||
|
||||
if credentials_profile_name is not None:
|
||||
session = boto3.Session(profile_name=credentials_profile_name)
|
||||
else:
|
||||
# use default credentials
|
||||
session = boto3.Session()
|
||||
|
||||
client_params = {}
|
||||
if region_name is not None:
|
||||
client_params["region_name"] = region_name
|
||||
|
||||
self.client = session.client("kendra", **client_params)
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Could not import boto3 python package. "
|
||||
"Please install it with `pip install boto3`."
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not load credentials to authenticate with AWS client. "
|
||||
"Please check that credentials in the specified "
|
||||
"profile name are valid."
|
||||
) from e
|
||||
|
||||
def _kendra_query(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int,
|
||||
attribute_filter: Optional[Dict] = None,
|
||||
) -> List[Document]:
|
||||
if attribute_filter is not None:
|
||||
response = self.client.retrieve(
|
||||
IndexId=self.index_id,
|
||||
QueryText=query.strip(),
|
||||
PageSize=top_k,
|
||||
AttributeFilter=attribute_filter,
|
||||
)
|
||||
else:
|
||||
response = self.client.retrieve(
|
||||
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
|
||||
)
|
||||
r_result = RetrieveResult.parse_obj(response)
|
||||
result_len = len(r_result.ResultItems)
|
||||
|
||||
if result_len == 0:
|
||||
# retrieve API returned 0 results, call query API
|
||||
if attribute_filter is not None:
|
||||
response = self.client.query(
|
||||
IndexId=self.index_id,
|
||||
QueryText=query.strip(),
|
||||
PageSize=top_k,
|
||||
AttributeFilter=attribute_filter,
|
||||
)
|
||||
else:
|
||||
response = self.client.query(
|
||||
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
|
||||
)
|
||||
q_result = QueryResult.parse_obj(response)
|
||||
docs = q_result.get_top_k_docs(top_k)
|
||||
else:
|
||||
docs = r_result.get_top_k_docs(top_k)
|
||||
return docs
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
"""Run search on Kendra index and get top k documents
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
docs = retriever.get_relevant_documents('This is my query')
|
||||
|
||||
"""
|
||||
docs = self._kendra_query(query, self.top_k, self.attribute_filter)
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError("Async version is not implemented for Kendra yet.")
|
Loading…
Reference in New Issue
Block a user