From 5cba2a1ecce19e1d53b7a608eb88a32f2efc9618 Mon Sep 17 00:00:00 2001
From: Nimisha Mehta <116048415+nimimeht@users.noreply.github.com>
Date: Thu, 9 Feb 2023 23:22:01 -0800
Subject: [PATCH] Adding a SagemakerEndpoint class (#953)

---
 langchain/__init__.py                |   3 +-
 langchain/llms/__init__.py           |   3 +
 langchain/llms/sagemaker_endpoint.py | 127 +++++++++++++++++++++++++++
 3 files changed, 132 insertions(+), 1 deletion(-)
 create mode 100644 langchain/llms/sagemaker_endpoint.py

diff --git a/langchain/__init__.py b/langchain/__init__.py
index 3096f77a474..a0bd2ae30f2 100644
--- a/langchain/__init__.py
+++ b/langchain/__init__.py
@@ -22,7 +22,7 @@ from langchain.chains import (
     VectorDBQAWithSourcesChain,
 )
 from langchain.docstore import InMemoryDocstore, Wikipedia
-from langchain.llms import Anthropic, Cohere, HuggingFaceHub, OpenAI
+from langchain.llms import Anthropic, Cohere, HuggingFaceHub, OpenAI, SagemakerEndpoint
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.prompts import (
     BasePromptTemplate,
@@ -60,6 +60,7 @@ __all__ = [
     "ReActChain",
     "Wikipedia",
     "HuggingFaceHub",
+    "SagemakerEndpoint",
     "HuggingFacePipeline",
     "SQLDatabase",
     "SQLDatabaseChain",
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index dac7fb67bc5..87ff1821623 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -6,6 +6,7 @@ from langchain.llms.anthropic import Anthropic
 from langchain.llms.base import BaseLLM
 from langchain.llms.cohere import Cohere
 from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.llms.nlpcloud import NLPCloud
 from langchain.llms.openai import AzureOpenAI, OpenAI
@@ -16,6 +17,7 @@ __all__ = [
     "NLPCloud",
     "OpenAI",
     "HuggingFaceHub",
+    "SagemakerEndpoint",
     "HuggingFacePipeline",
     "AI21",
     "AzureOpenAI",
@@ -26,6 +28,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "anthropic": Anthropic,
     "cohere": Cohere,
     "huggingface_hub": HuggingFaceHub,
+    "sagemaker_endpoint": SagemakerEndpoint,
     "nlpcloud": NLPCloud,
     "openai": OpenAI,
     "huggingface_pipeline": HuggingFacePipeline,
diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py
new file mode 100644
index 00000000000..ec4319d5cfb
--- /dev/null
+++ b/langchain/llms/sagemaker_endpoint.py
@@ -0,0 +1,127 @@
+"""Wrapper around Sagemaker InvokeEndpoint API."""
+import json
+from typing import Any, List, Mapping, Optional
+
+import boto3
+from pydantic import BaseModel, Extra
+
+from langchain.llms.base import LLM
+from langchain.llms.utils import enforce_stop_tokens
+
+VALID_TASKS = ("text2text-generation", "text-generation")
+
+
+class SagemakerEndpoint(LLM, BaseModel):
+    """Wrapper around custom Sagemaker Inference Endpoints.
+
+    To use, you should pass the AWS IAM role ARN and role session name as
+    named parameters to the constructor.
+
+    Only supports `text-generation` and `text2text-generation` for now.
+
+    Example:
+        .. code-block:: python
+
+            from langchain import SagemakerEndpoint
+            se = SagemakerEndpoint(
+                endpoint_name="my-endpoint-name",
+                role_arn="role_arn",
+                role_session_name="role_session_name",
+            )
+    """
+
+    endpoint_name: str = ""
+    """The name of the endpoint.
+    The name must be unique within an AWS Region in your AWS account."""
+    task: Optional[str] = None
+    """Task to call the model with. Should be a task that returns `generated_text`."""
+    model_kwargs: Optional[dict] = None
+    """Keyword arguments to pass to the model."""
+
+    role_arn: Optional[str] = None
+    role_session_name: Optional[str] = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        _model_kwargs = self.model_kwargs or {}
+        return {
+            **{"endpoint_name": self.endpoint_name, "task": self.task},
+            **{"model_kwargs": _model_kwargs},
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "sagemaker_endpoint"
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        """Call out to Sagemaker inference endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = se("Tell me a joke.")
+        """
+        # Placeholder credentials: a named profile and a hardcoded region are
+        # used until the assumed-role authentication sketched below is wired up.
+        session = boto3.Session(profile_name="test-profile-name")
+        sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
+
+        # TODO: use AWS IAM assumed roles to authenticate from the EC2 instance
+        # def role_arn_to_session(**args):
+        #     """
+        #     Usage :
+        #         session = role_arn_to_session(
+        #             RoleArn='arn:aws:iam::012345678901:role/example-role',
+        #             RoleSessionName='ExampleSessionName')
+        #         client = session.client('sqs')
+        #     """
+        #     client = boto3.client('sts')
+        #     response = client.assume_role(**args)
+        #     return boto3.Session(
+        #         aws_access_key_id=response['Credentials']['AccessKeyId'],
+        #         aws_secret_access_key=response['Credentials']['SecretAccessKey'],
+        #         aws_session_token=response['Credentials']['SessionToken'])
+
+        # session = role_arn_to_session(RoleArn="$role-arn",
+        #     RoleSessionName="test-role-session-name")
+        # sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
+
+        _model_kwargs = self.model_kwargs or {}
+
+        # Build the JSON request body expected by HuggingFace-style inference containers.
+        parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
+        input_en = json.dumps(parameter_payload).encode("utf-8")
+
+        # send request
+        try:
+            response = sagemaker_runtime.invoke_endpoint(
+                EndpointName=self.endpoint_name,
+                Body=input_en,
+                ContentType="application/json",
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by inference endpoint: {e}")
+
+        response_json = json.loads(response["Body"].read().decode("utf-8"))
+        text = response_json[0]["generated_text"]
+
+        if stop is not None:
+            # This is a bit hacky, but the endpoint does not accept a stop
+            # parameter, so stop sequences are enforced client-side.
+            text = enforce_stop_tokens(text, stop)
+
+        return text
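
Usage note: a minimal sketch of calling the new wrapper end to end. The
endpoint name and the `model_kwargs` keys below are illustrative assumptions
(valid generation parameters depend on the model deployed behind the
endpoint), and boto3 must be able to resolve AWS credentials on the calling
machine.

.. code-block:: python

    from langchain import SagemakerEndpoint

    # Hypothetical endpoint serving a HuggingFace text-generation model;
    # substitute the name of your own deployed endpoint.
    se = SagemakerEndpoint(
        endpoint_name="my-endpoint-name",
        model_kwargs={"max_new_tokens": 50, "temperature": 0.7},
    )

    # _call sends {"inputs": prompt, "parameters": model_kwargs} as JSON to
    # the endpoint and returns response_json[0]["generated_text"].
    print(se("Tell me a joke."))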