From 9fb6805be43f8c1a7c9a6e037db3b161c1075c5c Mon Sep 17 00:00:00 2001 From: david qiu Date: Tue, 28 Nov 2023 14:10:23 -0800 Subject: [PATCH] langchain[minor]: Add retriever for Knowledge Bases for Amazon Bedrock (#13980) - **Description:** Adds a retriever implementation for [Knowledge Bases for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/), a new service announced at AWS re:Invent, shortly before this PR was opened. This depends on the `bedrock-agent-runtime` service, which will be included in a future version of `boto3` and of `botocore`. We will open a follow-up PR documenting the minimum required versions of `boto3` and `botocore` after that information is available. - **Issue:** N/A - **Dependencies:** `boto3>=1.33.2, botocore>=1.33.2` - **Tag maintainer:** @baskaryan - **Twitter handles:** `@pjain7` `@dead_letter_q` This PR includes a documentation notebook under `docs/docs/integrations/retrievers`, which I (@dlqqq) have verified independently. EDIT: `bedrock-agent-runtime` service is now included in `boto3>=1.33.2`: https://github.com/boto/boto3/commit/5cf793f49369607f2557fd11741c357df866e4a5 --------- Co-authored-by: Piyush Jain Co-authored-by: Erick Friis Co-authored-by: Bagatur --- .../integrations/retrievers/bedrock.ipynb | 117 +++++++++++++++++ .../langchain/retrievers/__init__.py | 2 + .../langchain/langchain/retrievers/bedrock.py | 124 ++++++++++++++++++ .../unit_tests/retrievers/test_imports.py | 1 + 4 files changed, 244 insertions(+) create mode 100644 docs/docs/integrations/retrievers/bedrock.ipynb create mode 100644 libs/langchain/langchain/retrievers/bedrock.py diff --git a/docs/docs/integrations/retrievers/bedrock.ipynb b/docs/docs/integrations/retrievers/bedrock.ipynb new file mode 100644 index 00000000000..2f51537ad33 --- /dev/null +++ b/docs/docs/integrations/retrievers/bedrock.ipynb @@ -0,0 +1,117 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b6636c27-35da-4ba7-8313-eca21660cab3", + "metadata": {}, + 
"source": [ + "# Amazon Bedrock (Knowledge Bases)\n", + "\n", + "> [Knowledge Bases for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/) is an Amazon Web Services (AWS) offering which lets you quickly build RAG applications by using your private data to customize foundation model (FM) responses.\n", + "\n", + "> Implementing RAG requires organizations to perform several cumbersome steps to convert data into embeddings (vectors), store the embeddings in a specialized vector database, and build custom integrations into the database to search and retrieve text relevant to the user’s query. This can be time-consuming and inefficient.\n", + "\n", + "> With Knowledge Bases for Amazon Bedrock, simply point to the location of your data in Amazon S3, and Knowledge Bases for Amazon Bedrock takes care of the entire ingestion workflow into your vector database. If you do not have an existing vector database, Amazon Bedrock creates an Amazon OpenSearch Serverless vector store for you. For retrievals, use the LangChain - Amazon Bedrock integration via the Retrieve API to retrieve relevant results for a user query from knowledge bases.\n", + "\n", + "> A knowledge base can be configured through the [AWS Console](https://aws.amazon.com/console/) or by using the [AWS SDKs](https://aws.amazon.com/developer/tools/)." 
+ ] + }, + { + "cell_type": "markdown", + "id": "b34c8cbe-c6e5-4398-adf1-4925204bcaed", + "metadata": {}, + "source": [ + "## Using the Knowledge Bases Retriever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26c97d36-911c-4fe0-a478-546192728f30", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install boto3" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "30337664-8844-4dfe-97db-077abb51af68", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import AmazonKnowledgeBasesRetriever\n", + "\n", + "retriever = AmazonKnowledgeBasesRetriever(\n", + " knowledge_base_id=\"PUIJP4EQUA\",\n", + " retrieval_config={\"vectorSearchConfiguration\": {\"numberOfResults\": 4}},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9fefa50-f0fb-40e3-b4e4-67c5b232a090", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"What did the president say about Ketanji Brown?\"\n", + "\n", + "retriever.get_relevant_documents(query=query)" + ] + }, + { + "cell_type": "markdown", + "id": "7de9b61b-597b-4aba-95fb-49d11e84510e", + "metadata": {}, + "source": [ + "### Using in a QA Chain" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fd71709-aaed-42b5-a990-e3067bfa7143", + "metadata": {}, + "outputs": [], + "source": [ + "from botocore.client import Config\n", + "\n", + "from langchain.chains import RetrievalQA\n", + "from langchain.llms import Bedrock\n", + "\n", + "model_kwargs_claude = {\"temperature\": 0, \"top_k\": 10, \"max_tokens_to_sample\": 3000}\n", + "\n", + "llm = Bedrock(model_id=\"anthropic.claude-v2\", model_kwargs=model_kwargs_claude)\n", + "\n", + "qa = RetrievalQA.from_chain_type(\n", + " llm=llm, retriever=retriever, return_source_documents=True\n", + ")\n", + "\n", + "qa(query)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/retrievers/__init__.py b/libs/langchain/langchain/retrievers/__init__.py index 0652252651a..0a695730997 100644 --- a/libs/langchain/langchain/retrievers/__init__.py +++ b/libs/langchain/langchain/retrievers/__init__.py @@ -21,6 +21,7 @@ the backbone of a retriever, but there are other types of retrievers as well. from langchain.retrievers.arcee import ArceeRetriever from langchain.retrievers.arxiv import ArxivRetriever from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever +from langchain.retrievers.bedrock import AmazonKnowledgeBasesRetriever from langchain.retrievers.bm25 import BM25Retriever from langchain.retrievers.chaindesk import ChaindeskRetriever from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever @@ -72,6 +73,7 @@ from langchain.retrievers.zilliz import ZillizRetriever __all__ = [ "AmazonKendraRetriever", + "AmazonKnowledgeBasesRetriever", "ArceeRetriever", "ArxivRetriever", "AzureCognitiveSearchRetriever", diff --git a/libs/langchain/langchain/retrievers/bedrock.py b/libs/langchain/langchain/retrievers/bedrock.py new file mode 100644 index 00000000000..06d7b8cdaa7 --- /dev/null +++ b/libs/langchain/langchain/retrievers/bedrock.py @@ -0,0 +1,124 @@ +from typing import Any, Dict, List, Optional + +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.retrievers import BaseRetriever + + +class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg] + numberOfResults: int = 4 + + +class 
RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg] + vectorSearchConfiguration: VectorSearchConfig + + +class AmazonKnowledgeBasesRetriever(BaseRetriever): + """A retriever class for `Amazon Bedrock Knowledge Bases`. + + See https://aws.amazon.com/bedrock/knowledge-bases for more info. + + Args: + knowledge_base_id: Knowledge Base ID. + region_name: The aws region e.g., `us-west-2`. + Fallback to AWS_DEFAULT_REGION env variable or region specified in + ~/.aws/config. + credentials_profile_name: The name of the profile in the ~/.aws/credentials + or ~/.aws/config files, which has either access keys or role information + specified. If not specified, the default credential profile or, if on an + EC2 instance, credentials from IMDS will be used. + client: boto3 client for bedrock agent runtime. + retrieval_config: Configuration for retrieval. + + Example: + .. code-block:: python + + from langchain.retrievers import AmazonKnowledgeBasesRetriever + + retriever = AmazonKnowledgeBasesRetriever( + knowledge_base_id="", + retrieval_config={ + "vectorSearchConfiguration": { + "numberOfResults": 4 + } + }, + ) + """ + + knowledge_base_id: str + region_name: Optional[str] = None + credentials_profile_name: Optional[str] = None + endpoint_url: Optional[str] = None + client: Any + retrieval_config: RetrievalConfig + + @root_validator(pre=True) + def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("client") is not None: + return values + + try: + import boto3 + from botocore.client import Config + from botocore.exceptions import UnknownServiceError + + if values.get("credentials_profile_name"): + session = boto3.Session(profile_name=values["credentials_profile_name"]) + else: + # use default credentials + session = boto3.Session() + + client_params = { + "config": Config( + connect_timeout=120, read_timeout=120, retries={"max_attempts": 0} + ) + } + if values.get("region_name"): + client_params["region_name"] = values["region_name"] 
+ + if values.get("endpoint_url"): + client_params["endpoint_url"] = values["endpoint_url"] + + values["client"] = session.client("bedrock-agent-runtime", **client_params) + + return values + except ImportError: + raise ModuleNotFoundError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." + ) + except UnknownServiceError as e: + raise ModuleNotFoundError( + "Ensure that you have installed the latest boto3 package " + "that contains the API for `bedrock-runtime-agent`." + ) from e + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + response = self.client.retrieve( + retrievalQuery={"text": query.strip()}, + knowledgeBaseId=self.knowledge_base_id, + retrievalConfiguration=self.retrieval_config.dict(), + ) + results = response["retrievalResults"] + documents = [] + for result in results: + documents.append( + Document( + page_content=result["content"]["text"], + metadata={ + "location": result["location"], + "score": result["score"] if "score" in result else 0, + }, + ) + ) + + return documents diff --git a/libs/langchain/tests/unit_tests/retrievers/test_imports.py b/libs/langchain/tests/unit_tests/retrievers/test_imports.py index f0120a8d643..a26d7d48918 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_imports.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_imports.py @@ -2,6 +2,7 @@ from langchain.retrievers import __all__ EXPECTED_ALL = [ "AmazonKendraRetriever", + "AmazonKnowledgeBasesRetriever", "ArceeRetriever", "ArxivRetriever", "AzureCognitiveSearchRetriever",