From 0d44746430c8d61d6613879308db023418b71f66 Mon Sep 17 00:00:00 2001 From: Vikram Shitole Date: Tue, 24 Oct 2023 00:38:26 +0530 Subject: [PATCH] 10634: Added the capability to inject boto3 client in SagemakerEndpointEmbeddings (#12146) **Description: Allow to inject boto3 client for Cross account access type of scenarios in using SagemakerEndpointEmbeddings and also updated the documentation for same in the sample notebook** **Issue:SagemakerEndpointEmbeddings cross account capability #10634 #10184** Dependencies: None Tag maintainer: Twitter handle:lethargicoder Co-authored-by: Vikram(VS) --- docs/docs/integrations/llms/sagemaker.ipynb | 81 ++++++++++++++++++- .../text_embedding/sagemaker-endpoint.ipynb | 15 +++- .../embeddings/sagemaker_endpoint.py | 16 +++- 3 files changed, 108 insertions(+), 4 deletions(-) diff --git a/docs/docs/integrations/llms/sagemaker.ipynb b/docs/docs/integrations/llms/sagemaker.ipynb index 067aeaaa560..32659dd2aaa 100644 --- a/docs/docs/integrations/llms/sagemaker.ipynb +++ b/docs/docs/integrations/llms/sagemaker.ipynb @@ -82,6 +82,15 @@ "]" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example to initialize with external boto3 session\n", + "\n", + "### for cross account scenarios" + ] + }, { "cell_type": "code", "execution_count": null, @@ -92,7 +101,77 @@ "source": [ "from typing import Dict\n", "\n", - "from langchain.prompts import PromptTemplate\nfrom langchain.llms import SagemakerEndpoint\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.llms import SagemakerEndpoint\n", + "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", + "from langchain.chains.question_answering import load_qa_chain\n", + "import json\n", + "import boto3\n", + "\n", + "query = \"\"\"How long was Elizabeth hospitalized?\n", + "\"\"\"\n", + "\n", + "prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n", + "\n", + "{context}\n", + "\n", + "Question: {question}\n", + "Answer:\"\"\"\n", + "PROMPT = PromptTemplate(\n", + " template=prompt_template, input_variables=[\"context\", \"question\"]\n", + ")\n", + "\n", + "roleARN = 'arn:aws:iam::123456789:role/cross-account-role'\n", + "sts_client = boto3.client('sts')\n", + "response = sts_client.assume_role(RoleArn=roleARN, \n", + " RoleSessionName='CrossAccountSession')\n", + "\n", + "client = boto3.client(\n", + " \"sagemaker-runtime\",\n", + " region_name=\"us-west-2\", \n", + " aws_access_key_id=response['Credentials']['AccessKeyId'],\n", + " aws_secret_access_key=response['Credentials']['SecretAccessKey'],\n", + " aws_session_token = response['Credentials']['SessionToken']\n", + ")\n", + "\n", + "class ContentHandler(LLMContentHandler):\n", + " content_type = \"application/json\"\n", + " accepts = \"application/json\"\n", + "\n", + " def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n", + " input_str = json.dumps({prompt: prompt, **model_kwargs})\n", + " return input_str.encode(\"utf-8\")\n", + "\n", + " def transform_output(self, output: bytes) -> str:\n", + " response_json = json.loads(output.read().decode(\"utf-8\"))\n", + " return response_json[0][\"generated_text\"]\n", + "\n", + "\n", + "content_handler = ContentHandler()\n", + "\n", + "chain = load_qa_chain(\n", + " llm=SagemakerEndpoint(\n", + " endpoint_name=\"endpoint-name\",\n", + " client=client,\n", + " model_kwargs={\"temperature\": 1e-10},\n", + " content_handler=content_handler,\n", + " ),\n", + " prompt=PROMPT,\n", + ")\n", + "\n", + "chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Dict\n", + "\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.llms import SagemakerEndpoint\n", "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", "from langchain.chains.question_answering import load_qa_chain\n", "import json\n", diff --git a/docs/docs/integrations/text_embedding/sagemaker-endpoint.ipynb b/docs/docs/integrations/text_embedding/sagemaker-endpoint.ipynb index ec80112e101..98d423890db 100644 --- a/docs/docs/integrations/text_embedding/sagemaker-endpoint.ipynb +++ b/docs/docs/integrations/text_embedding/sagemaker-endpoint.ipynb @@ -43,7 +43,7 @@ "from langchain.embeddings import SagemakerEndpointEmbeddings\n", "from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n", "import json\n", - "\n", + "import boto3\n", "\n", "class ContentHandler(EmbeddingsContentHandler):\n", " content_type = \"application/json\"\n", @@ -87,7 +87,18 @@ " endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n", " region_name=\"us-east-1\",\n", " content_handler=content_handler,\n", - ")" + ")\n", + "\n", + "\n", + "# client = boto3.client(\n", + "# \"sagemaker-runtime\",\n", + "# region_name=\"us-west-2\" \n", + "# )\n", + "# embeddings = SagemakerEndpointEmbeddings(\n", + "# endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\", \n", + "# client=client\n", + "# content_handler=content_handler,\n", + "# )" ] }, { diff --git a/libs/langchain/langchain/embeddings/sagemaker_endpoint.py b/libs/langchain/langchain/embeddings/sagemaker_endpoint.py index 6bfd29f2c22..0e724624ae8 100644 --- a/libs/langchain/langchain/embeddings/sagemaker_endpoint.py +++ b/libs/langchain/langchain/embeddings/sagemaker_endpoint.py @@ -46,8 +46,18 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings): region_name=region_name, credentials_profile_name=credentials_profile_name ) + + #Use with boto3 client + client = boto3.client( + "sagemaker-runtime", + region_name=region_name + ) + se = SagemakerEndpointEmbeddings( + endpoint_name=endpoint_name, + client=client + ) """ - client: Any #: :meta private: + client: Any = None endpoint_name: str = "" """The name of the endpoint from the deployed Sagemaker model. @@ -106,6 +116,10 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: + """Dont do anything if client provided externally""" + if values.get("client") is not None: + return values + """Validate that AWS credentials to and python package exists in environment.""" try: import boto3