mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-23 11:30:37 +00:00
langchain[minor]: Add retriever for Knowledge Bases for Amazon Bedrock (#13980)
- **Description:** Adds a retriever implementation for [Knowledge Bases
for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/), a
new service announced at AWS re:Invent, shortly before this PR was
opened. This depends on the `bedrock-agent-runtime` service, which will
be included in a future version of `boto3` and of `botocore`. We will
open a follow-up PR documenting the minimum required versions of `boto3`
and `botocore` after that information is available.
- **Issue:** N/A
- **Dependencies:** `boto3>=1.33.2, botocore>=1.33.2`
- **Tag maintainer:** @baskaryan
- **Twitter handles:** `@pjain7` `@dead_letter_q`
This PR includes a documentation notebook under
`docs/docs/integrations/retrievers`, which I (@dlqqq) have verified
independently.
EDIT: `bedrock-agent-runtime` service is now included in
`boto3>=1.33.2`:
5cf793f493
---------
Co-authored-by: Piyush Jain <piyushjain@duck.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -21,6 +21,7 @@ the backbone of a retriever, but there are other types of retrievers as well.
|
||||
from langchain.retrievers.arcee import ArceeRetriever
|
||||
from langchain.retrievers.arxiv import ArxivRetriever
|
||||
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
|
||||
from langchain.retrievers.bedrock import AmazonKnowledgeBasesRetriever
|
||||
from langchain.retrievers.bm25 import BM25Retriever
|
||||
from langchain.retrievers.chaindesk import ChaindeskRetriever
|
||||
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||
@@ -72,6 +73,7 @@ from langchain.retrievers.zilliz import ZillizRetriever
|
||||
|
||||
__all__ = [
|
||||
"AmazonKendraRetriever",
|
||||
"AmazonKnowledgeBasesRetriever",
|
||||
"ArceeRetriever",
|
||||
"ArxivRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
|
124
libs/langchain/langchain/retrievers/bedrock.py
Normal file
124
libs/langchain/langchain/retrievers/bedrock.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
|
||||
numberOfResults: int = 4
|
||||
|
||||
|
||||
class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
|
||||
vectorSearchConfiguration: VectorSearchConfig
|
||||
|
||||
|
||||
class AmazonKnowledgeBasesRetriever(BaseRetriever):
|
||||
"""A retriever class for `Amazon Bedrock Knowledge Bases`.
|
||||
|
||||
See https://aws.amazon.com/bedrock/knowledge-bases for more info.
|
||||
|
||||
Args:
|
||||
knowledge_base_id: Knowledge Base ID.
|
||||
region_name: The aws region e.g., `us-west-2`.
|
||||
Fallback to AWS_DEFAULT_REGION env variable or region specified in
|
||||
~/.aws/config.
|
||||
credentials_profile_name: The name of the profile in the ~/.aws/credentials
|
||||
or ~/.aws/config files, which has either access keys or role information
|
||||
specified. If not specified, the default credential profile or, if on an
|
||||
EC2 instance, credentials from IMDS will be used.
|
||||
client: boto3 client for bedrock agent runtime.
|
||||
retrieval_config: Configuration for retrieval.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.retrievers import AmazonKnowledgeBasesRetriever
|
||||
|
||||
retriever = AmazonKnowledgeBasesRetriever(
|
||||
knowledge_base_id="<knowledge-base-id>",
|
||||
retrieval_config={
|
||||
"vectorSearchConfiguration": {
|
||||
"numberOfResults": 4
|
||||
}
|
||||
},
|
||||
)
|
||||
"""
|
||||
|
||||
knowledge_base_id: str
|
||||
region_name: Optional[str] = None
|
||||
credentials_profile_name: Optional[str] = None
|
||||
endpoint_url: Optional[str] = None
|
||||
client: Any
|
||||
retrieval_config: RetrievalConfig
|
||||
|
||||
@root_validator(pre=True)
|
||||
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if values.get("client") is not None:
|
||||
return values
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import UnknownServiceError
|
||||
|
||||
if values.get("credentials_profile_name"):
|
||||
session = boto3.Session(profile_name=values["credentials_profile_name"])
|
||||
else:
|
||||
# use default credentials
|
||||
session = boto3.Session()
|
||||
|
||||
client_params = {
|
||||
"config": Config(
|
||||
connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
|
||||
)
|
||||
}
|
||||
if values.get("region_name"):
|
||||
client_params["region_name"] = values["region_name"]
|
||||
|
||||
if values.get("endpoint_url"):
|
||||
client_params["endpoint_url"] = values["endpoint_url"]
|
||||
|
||||
values["client"] = session.client("bedrock-agent-runtime", **client_params)
|
||||
|
||||
return values
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Could not import boto3 python package. "
|
||||
"Please install it with `pip install boto3`."
|
||||
)
|
||||
except UnknownServiceError as e:
|
||||
raise ModuleNotFoundError(
|
||||
"Ensure that you have installed the latest boto3 package "
|
||||
"that contains the API for `bedrock-runtime-agent`."
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not load credentials to authenticate with AWS client. "
|
||||
"Please check that credentials in the specified "
|
||||
"profile name are valid."
|
||||
) from e
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
response = self.client.retrieve(
|
||||
retrievalQuery={"text": query.strip()},
|
||||
knowledgeBaseId=self.knowledge_base_id,
|
||||
retrievalConfiguration=self.retrieval_config.dict(),
|
||||
)
|
||||
results = response["retrievalResults"]
|
||||
documents = []
|
||||
for result in results:
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=result["content"]["text"],
|
||||
metadata={
|
||||
"location": result["location"],
|
||||
"score": result["score"] if "score" in result else 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return documents
|
@@ -2,6 +2,7 @@ from langchain.retrievers import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AmazonKendraRetriever",
|
||||
"AmazonKnowledgeBasesRetriever",
|
||||
"ArceeRetriever",
|
||||
"ArxivRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
|
Reference in New Issue
Block a user