From 1732103e19660af7c3b5972d4d2d8f169a1453ed Mon Sep 17 00:00:00 2001 From: Nimisha Mehta <116048415+nimimeht@users.noreply.github.com> Date: Tue, 21 Feb 2023 17:02:04 -0800 Subject: [PATCH] Updates to Sagemaker Endpoint (#1217) --- .../document_loaders/examples/sagemaker.ipynb | 183 ++++++++++++++++++ langchain/llms/sagemaker_endpoint.py | 153 ++++++++++----- 2 files changed, 283 insertions(+), 53 deletions(-) create mode 100644 docs/modules/document_loaders/examples/sagemaker.ipynb diff --git a/docs/modules/document_loaders/examples/sagemaker.ipynb b/docs/modules/document_loaders/examples/sagemaker.ipynb new file mode 100644 index 00000000000..b75cb1f3bb1 --- /dev/null +++ b/docs/modules/document_loaders/examples/sagemaker.ipynb @@ -0,0 +1,183 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "Collecting langchain\n", + " Downloading langchain-0.0.80-py3-none-any.whl (222 kB)\n", + "\u001b[K |████████████████████████████████| 222 kB 2.1 MB/s eta 0:00:01\n", + "\u001b[?25hRequirement already satisfied: numpy<2,>=1 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (1.24.1)\n", + "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (3.8.3)\n", + "Collecting pydantic<2,>=1\n", + " Downloading pydantic-1.10.4-cp39-cp39-macosx_11_0_arm64.whl (2.6 MB)\n", + "\u001b[K |████████████████████████████████| 2.6 MB 3.3 MB/s eta 0:00:01\n", + "\u001b[?25hCollecting SQLAlchemy<2,>=1\n", + " Downloading SQLAlchemy-1.4.46.tar.gz (8.5 MB)\n", + "\u001b[K |████████████████████████████████| 8.5 MB 23.4 MB/s eta 0:00:01\n", + "\u001b[?25hCollecting tenacity<9.0.0,>=8.1.0\n", + " Downloading tenacity-8.2.0-py3-none-any.whl (24 kB)\n", + "Requirement already satisfied: requests<3,>=2 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (2.28.2)\n", + "Requirement already satisfied: PyYAML<7,>=6 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (6.0)\n", + "Collecting dataclasses-json<0.6.0,>=0.5.7\n", + " Downloading dataclasses_json-0.5.7-py3-none-any.whl (25 kB)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (4.0.2)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.4)\n", + "Requirement already satisfied: attrs>=17.3.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (22.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.3)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.8.2)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n", + "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (2.1.1)\n", + "Collecting marshmallow<4.0.0,>=3.3.0\n", + " Downloading marshmallow-3.19.0-py3-none-any.whl (49 kB)\n", + "\u001b[K |████████████████████████████████| 49 kB 26.9 MB/s eta 0:00:01\n", + "\u001b[?25hCollecting marshmallow-enum<2.0.0,>=1.5.1\n", + " Downloading marshmallow_enum-1.5.1-py2.py3-none-any.whl (4.2 kB)\n", + "Collecting typing-inspect>=0.4.0\n", + " Downloading typing_inspect-0.8.0-py3-none-any.whl (8.7 kB)\n", + "Requirement already satisfied: packaging>=17.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from marshmallow<4.0.0,>=3.3.0->dataclasses-json<0.6.0,>=0.5.7->langchain) (23.0)\n", + "Requirement already satisfied: typing-extensions>=4.2.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from pydantic<2,>=1->langchain) (4.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from requests<3,>=2->langchain) (3.4)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from requests<3,>=2->langchain) (1.26.14)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from requests<3,>=2->langchain) (2022.12.7)\n", + "Collecting mypy-extensions>=0.3.0\n", + " Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n", + "Building wheels for collected packages: SQLAlchemy\n", + " Building wheel for SQLAlchemy (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for SQLAlchemy: filename=SQLAlchemy-1.4.46-cp39-cp39-macosx_10_9_universal2.whl size=1578667 sha256=9991d70fde083b993d7fe1fd61fca33a279e921f1b8296b02037e24b8cac1097\n", + " Stored in directory: /Users/nmehta/Library/Caches/pip/wheels/3c/99/65/57cf5a0ec6e7f3b803a68d31694501e168997e03e80adc903d\n", + "Successfully built SQLAlchemy\n", + "Installing collected packages: mypy-extensions, marshmallow, typing-inspect, marshmallow-enum, tenacity, SQLAlchemy, pydantic, dataclasses-json, langchain\n", + "\u001b[33m WARNING: The script langchain-server is installed in '/Users/nmehta/Library/Python/3.9/bin' which is not on PATH.\n", + " Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\n", + "Successfully installed SQLAlchemy-1.4.46 dataclasses-json-0.5.7 langchain-0.0.80 marshmallow-3.19.0 marshmallow-enum-1.5.1 mypy-extensions-1.0.0 pydantic-1.10.4 tenacity-8.2.0 typing-inspect-0.8.0\n", + "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.0 is available.\n", + "You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n", + "Defaulting to user installation because normal site-packages is not writeable\n", + "Collecting html2text\n", + " Downloading html2text-2020.1.16-py3-none-any.whl (32 kB)\n", + "Installing collected packages: html2text\n", + "\u001b[33m WARNING: The script html2text is installed in '/Users/nmehta/Library/Python/3.9/bin' which is not on PATH.\n", + " Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\n", + "Successfully installed html2text-2020.1.16\n", + "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.0 is available.\n", + "You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n" + ] + } + ], + "source": [ + "!pip3 install langchain" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.docstore.document import Document" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "example_doc_1 = \"\"\"\n", + "Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.\n", + "Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.\n", + "Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n", + "\"\"\"\n", + "\n", + "docs = [\n", + " Document(\n", + " page_content=example_doc_1,\n", + " )\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'output_text': '3 days'}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain import PromptTemplate, HuggingFaceHub, LLMChain, SagemakerEndpoint\n", + "from langchain.chains.question_answering import load_qa_chain\n", + "import json\n", + "\n", + "query = \"\"\"How long was Elizabeth hospitalized?\n", + "\"\"\"\n", + "\n", + "prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n", + "\n", + "{context}\n", + "\n", + "Question: {question}\n", + "Answer:\"\"\"\n", + "PROMPT = PromptTemplate(\n", + " template=prompt_template, input_variables=[\"context\", \"question\"]\n", + ")\n", + "\n", + "def model_input_transform_fn(prompt, model_kwargs):\n", + " parameter_payload = {\"inputs\": prompt, \"parameters\": model_kwargs}\n", + " return json.dumps(parameter_payload).encode(\"utf-8\") \n", + "\n", + "chain = load_qa_chain(llm=SagemakerEndpoint(\n", + " endpoint_name=\"my-sagemaker-model-endpoint\", \n", + " credentials_profile_name=\"credentials-profile-name\", \n", + " region_name=\"us-west-2\", \n", + " model_kwargs={\"temperature\":1e-10},\n", + " content_type=\"application/json\", \n", + " model_input_transform_fn=model_input_transform_fn), \n", + " prompt=PROMPT) \n", + "\n", + "chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + }, + "vscode": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py index 739bf923f6d..344f46f4dec 100644 --- a/langchain/llms/sagemaker_endpoint.py +++ b/langchain/llms/sagemaker_endpoint.py @@ -1,22 +1,29 @@ """Wrapper around Sagemaker InvokeEndpoint API.""" import json -from typing import Any, List, Mapping, Optional +from typing import Any, Callable, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra +from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -VALID_TASKS = ("text2text-generation", "text-generation") - class SagemakerEndpoint(LLM, BaseModel): """Wrapper around custom Sagemaker Inference Endpoints. - To use, you should pass the AWS IAM Role and Role Session Name as - named parameters to the constructor. + To use, you must supply the endpoint name from your deployed + Sagemaker model & the region where it is deployed. - Only supports `text-generation` and `text2text-generation` for now. + To authenticate, the AWS client uses the following methods to + automatically load credentials: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + + If a specific credential profile should be used, you must pass + the name of the profile from the ~/.aws/credentials file that is to be used. + + Make sure the credentials / roles used have the required policies to + access the Sagemaker endpoint. + See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html """ """ @@ -27,34 +34,104 @@ class SagemakerEndpoint(LLM, BaseModel): endpoint_name = ( "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/abcdefghijklmnop/invocations" ) + region_name = ( + "us-west-2" + ) + credentials_profile_name = ( + "default" + ) se = SagemakerEndpoint( endpoint_name=endpoint_name, - role_arn="role_arn", - role_session_name="role_session_name" + region_name=region_name, + credentials_profile_name=credentials_profile_name ) """ + client: Any #: :meta private: endpoint_name: str = "" - """# The name of the endpoint. Must be unique within an AWS Region.""" - task: Optional[str] = None - """Task to call the model with. Should be a task that returns `generated_text`.""" - model_kwargs: Optional[dict] = None - """Key word arguments to pass to the model.""" + """The name of the endpoint from the deployed Sagemaker model. + Must be unique within an AWS Region.""" - role_arn: Optional[str] = None - role_session_name: Optional[str] = None + region_name: str = "" + """The aws region where the Sagemaker model is deployed, eg. `us-west-2`.""" + + credentials_profile_name: Optional[str] = None + """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which + has either access keys or role information specified. + If not specified, the default credential profile or, if on an EC2 instance, + credentials from IMDS will be used. + See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + """ + + content_type: Optional[str] = "application/json" + """The MIME type of the input data in the request body to be used in the header + for the request to the Sagemaker invoke_endpoint API. + Defaults to "application/json".""" + + model_input_transform_fn: Callable[[str, Dict], bytes] + """ + Function which takes the prompt (str) and model_kwargs (dict) and transforms + the input to the format which the model can accept as the request Body. + Should return bytes or seekable file-like object in the format specified in the + content_type request header. + """ + + """ + Example: + .. code-block:: python + + def model_input_transform_fn(prompt, model_kwargs): + parameter_payload = {"inputs": prompt, "parameters": model_kwargs} + return json.dumps(parameter_payload).encode("utf-8") + """ + + model_kwargs: Optional[Dict] = None + """Key word arguments to pass to the model.""" class Config: """Configuration for this pydantic object.""" extra = Extra.forbid + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that AWS credentials to and python package exists in environment.""" + try: + import boto3 + + try: + if values["credentials_profile_name"] is not None: + session = boto3.Session( + profile_name=values["credentials_profile_name"] + ) + else: + # use default credentials + session = boto3.Session() + + values["client"] = session.client( + "sagemaker-runtime", region_name=values["region_name"] + ) + + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e + + except ImportError: + raise ValueError( + "Could not import boto3 python package. " + "Please it install it with `pip install boto3`." + ) + return values + @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" _model_kwargs = self.model_kwargs or {} return { - **{"endpoint_name": self.endpoint_name, "task": self.task}, + **{"endpoint_name": self.endpoint_name}, **{"model_kwargs": _model_kwargs}, } @@ -78,46 +155,16 @@ class SagemakerEndpoint(LLM, BaseModel): response = se("Tell me a joke.") """ - import boto3 - - session = boto3.Session(profile_name="test-profile-name") - sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2") - - # TODO: use AWS IAM assumed roles to authenticate from the EC2 instance - # def role_arn_to_session(**args): - # """ - # Usage : - # session = role_arn_to_session( - # RoleArn='arn:aws:iam::012345678901:role/example-role', - # RoleSessionName='ExampleSessionName') - # client = session.client('sqs') - # """ - # client = boto3.client('sts') - # response = client.assume_role(**args) - # return boto3.Session( - # aws_access_key_id=response['Credentials']['AccessKeyId'], - # aws_secret_access_key=response['Credentials']['SecretAccessKey'], - # aws_session_token=response['Credentials']['SessionToken']) - - # session = role_arn_to_session(RoleArn="$role-arn", - # RoleSessionName="test-role-session-name") - # sagemaker_runtime = session.client( - # "sagemaker-runtime", region_name="us-west-2" - # ) - _model_kwargs = self.model_kwargs or {} - - # payload samples - parameter_payload = {"inputs": prompt, "parameters": _model_kwargs} - - input_en = json.dumps(parameter_payload).encode("utf-8") + if self.model_input_transform_fn is None: + raise NotImplementedError("model_input_transform_fn not implemented") # send request try: - response = sagemaker_runtime.invoke_endpoint( + response = self.client.invoke_endpoint( EndpointName=self.endpoint_name, - Body=input_en, - ContentType="application/json", + Body=self.model_input_transform_fn(prompt, _model_kwargs), + ContentType=self.content_type, ) except Exception as e: raise ValueError(f"Error raised by inference endpoint: {e}") @@ -126,7 +173,7 @@ class SagemakerEndpoint(LLM, BaseModel): text = response_json[0]["generated_text"] if stop is not None: # This is a bit hacky, but I can't figure out a better way to enforce - # stop tokens when making calls to huggingface_hub. + # stop tokens when making calls to the sagemaker endpoint. text = enforce_stop_tokens(text, stop) return text