mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-02 03:15:11 +00:00
10634: Added the capability to inject boto3 client in SagemakerEndpointEmbeddings (#12146)
**Description: Allow injecting a boto3 client into SagemakerEndpointEmbeddings for cross-account access scenarios, and update the sample notebook documentation accordingly** **Issue: SagemakerEndpointEmbeddings cross account capability #10634 #10184** Dependencies: None Tag maintainer: Twitter handle: lethargicoder Co-authored-by: Vikram(VS) <vssht@amazon.com>
This commit is contained in:
parent
ff79a99825
commit
0d44746430
@ -82,6 +82,15 @@
|
|||||||
"]"
|
"]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Example to initialize with an external boto3 client\n",
|
||||||
|
"\n",
|
||||||
|
"### for cross account scenarios"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
@ -92,7 +101,77 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"from typing import Dict\n",
|
"from typing import Dict\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from langchain.prompts import PromptTemplate\nfrom langchain.llms import SagemakerEndpoint\n",
|
"from langchain.prompts import PromptTemplate\n",
|
||||||
|
"from langchain.llms import SagemakerEndpoint\n",
|
||||||
|
"from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
|
||||||
|
"from langchain.chains.question_answering import load_qa_chain\n",
|
||||||
|
"import json\n",
|
||||||
|
"import boto3\n",
|
||||||
|
"\n",
|
||||||
|
"query = \"\"\"How long was Elizabeth hospitalized?\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
|
||||||
|
"\n",
|
||||||
|
"{context}\n",
|
||||||
|
"\n",
|
||||||
|
"Question: {question}\n",
|
||||||
|
"Answer:\"\"\"\n",
|
||||||
|
"PROMPT = PromptTemplate(\n",
|
||||||
|
" template=prompt_template, input_variables=[\"context\", \"question\"]\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"roleARN = 'arn:aws:iam::123456789:role/cross-account-role'\n",
|
||||||
|
"sts_client = boto3.client('sts')\n",
|
||||||
|
"response = sts_client.assume_role(RoleArn=roleARN, \n",
|
||||||
|
" RoleSessionName='CrossAccountSession')\n",
|
||||||
|
"\n",
|
||||||
|
"client = boto3.client(\n",
|
||||||
|
" \"sagemaker-runtime\",\n",
|
||||||
|
" region_name=\"us-west-2\", \n",
|
||||||
|
" aws_access_key_id=response['Credentials']['AccessKeyId'],\n",
|
||||||
|
" aws_secret_access_key=response['Credentials']['SecretAccessKey'],\n",
|
||||||
|
" aws_session_token = response['Credentials']['SessionToken']\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"class ContentHandler(LLMContentHandler):\n",
|
||||||
|
" content_type = \"application/json\"\n",
|
||||||
|
" accepts = \"application/json\"\n",
|
||||||
|
"\n",
|
||||||
|
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
||||||
|
" input_str = json.dumps({prompt: prompt, **model_kwargs})\n",
|
||||||
|
" return input_str.encode(\"utf-8\")\n",
|
||||||
|
"\n",
|
||||||
|
" def transform_output(self, output: bytes) -> str:\n",
|
||||||
|
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
||||||
|
" return response_json[0][\"generated_text\"]\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"content_handler = ContentHandler()\n",
|
||||||
|
"\n",
|
||||||
|
"chain = load_qa_chain(\n",
|
||||||
|
" llm=SagemakerEndpoint(\n",
|
||||||
|
" endpoint_name=\"endpoint-name\",\n",
|
||||||
|
" client=client,\n",
|
||||||
|
" model_kwargs={\"temperature\": 1e-10},\n",
|
||||||
|
" content_handler=content_handler,\n",
|
||||||
|
" ),\n",
|
||||||
|
" prompt=PROMPT,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from typing import Dict\n",
|
||||||
|
"\n",
|
||||||
|
"from langchain.prompts import PromptTemplate\n",
|
||||||
|
"from langchain.llms import SagemakerEndpoint\n",
|
||||||
"from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
|
"from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
|
||||||
"from langchain.chains.question_answering import load_qa_chain\n",
|
"from langchain.chains.question_answering import load_qa_chain\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
|
@ -43,7 +43,7 @@
|
|||||||
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
|
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
|
||||||
"from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
|
"from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"\n",
|
"import boto3\n",
|
||||||
"\n",
|
"\n",
|
||||||
"class ContentHandler(EmbeddingsContentHandler):\n",
|
"class ContentHandler(EmbeddingsContentHandler):\n",
|
||||||
" content_type = \"application/json\"\n",
|
" content_type = \"application/json\"\n",
|
||||||
@ -87,7 +87,18 @@
|
|||||||
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
|
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
|
||||||
" region_name=\"us-east-1\",\n",
|
" region_name=\"us-east-1\",\n",
|
||||||
" content_handler=content_handler,\n",
|
" content_handler=content_handler,\n",
|
||||||
")"
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# client = boto3.client(\n",
|
||||||
|
"# \"sagemaker-runtime\",\n",
|
||||||
|
"# region_name=\"us-west-2\" \n",
|
||||||
|
"# )\n",
|
||||||
|
"# embeddings = SagemakerEndpointEmbeddings(\n",
|
||||||
|
"# endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\", \n",
|
||||||
|
"# client=client,\n",
|
||||||
|
"# content_handler=content_handler,\n",
|
||||||
|
"# )"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -46,8 +46,18 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
|||||||
region_name=region_name,
|
region_name=region_name,
|
||||||
credentials_profile_name=credentials_profile_name
|
credentials_profile_name=credentials_profile_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
#Use with boto3 client
|
||||||
|
client = boto3.client(
|
||||||
|
"sagemaker-runtime",
|
||||||
|
region_name=region_name
|
||||||
|
)
|
||||||
|
se = SagemakerEndpointEmbeddings(
|
||||||
|
endpoint_name=endpoint_name,
|
||||||
|
client=client
|
||||||
|
)
|
||||||
"""
|
"""
|
||||||
client: Any #: :meta private:
|
client: Any = None
|
||||||
|
|
||||||
endpoint_name: str = ""
|
endpoint_name: str = ""
|
||||||
"""The name of the endpoint from the deployed Sagemaker model.
|
"""The name of the endpoint from the deployed Sagemaker model.
|
||||||
@ -106,6 +116,10 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
"""Don't do anything if a client is provided externally"""
|
||||||
|
if values.get("client") is not None:
|
||||||
|
return values
|
||||||
|
|
||||||
"""Validate that AWS credentials and the boto3 python package exist in the environment."""
|
"""Validate that AWS credentials and the boto3 python package exist in the environment."""
|
||||||
try:
|
try:
|
||||||
import boto3
|
import boto3
|
||||||
|
Loading…
Reference in New Issue
Block a user