From f5879e73cbe16a7b075310c3e7183f1eabcd9a07 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 9 Feb 2023 23:28:46 -0800 Subject: [PATCH] cr --- langchain/llms/__init__.py | 2 +- langchain/llms/sagemaker_endpoint.py | 42 +++++++++++++++------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py index 87ff1821623..825e8460535 100644 --- a/langchain/llms/__init__.py +++ b/langchain/llms/__init__.py @@ -6,10 +6,10 @@ from langchain.llms.anthropic import Anthropic from langchain.llms.base import BaseLLM from langchain.llms.cohere import Cohere from langchain.llms.huggingface_hub import HuggingFaceHub -from langchain.llms.sagemaker_endpoint import SagemakerEndpoint from langchain.llms.huggingface_pipeline import HuggingFacePipeline from langchain.llms.nlpcloud import NLPCloud from langchain.llms.openai import AzureOpenAI, OpenAI +from langchain.llms.sagemaker_endpoint import SagemakerEndpoint __all__ = [ "Anthropic", diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py index ec4319d5cfb..739bf923f6d 100644 --- a/langchain/llms/sagemaker_endpoint.py +++ b/langchain/llms/sagemaker_endpoint.py @@ -1,13 +1,11 @@ """Wrapper around Sagemaker InvokeEndpoint API.""" -from typing import Any, Dict, List, Mapping, Optional - -import boto3 import json -from pydantic import BaseModel, Extra, root_validator +from typing import Any, List, Mapping, Optional + +from pydantic import BaseModel, Extra from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.utils import get_from_dict_or_env VALID_TASKS = ("text2text-generation", "text-generation") @@ -15,7 +13,8 @@ VALID_TASKS = ("text2text-generation", "text-generation") class SagemakerEndpoint(LLM, BaseModel): """Wrapper around custom Sagemaker Inference Endpoints. - To use, you should pass the AWS IAM Role and Role Session Name as named parameters to the constructor. + To use, you should pass the AWS IAM Role and Role Session Name as + named parameters to the constructor. Only supports `text-generation` and `text2text-generation` for now. """ @@ -36,7 +35,7 @@ class SagemakerEndpoint(LLM, BaseModel): """ endpoint_name: str = "" - """# The name of the endpoint. The name must be unique within an AWS Region in your AWS account.""" + """# The name of the endpoint. Must be unique within an AWS Region.""" task: Optional[str] = None """Task to call the model with. Should be a task that returns `generated_text`.""" model_kwargs: Optional[dict] = None @@ -79,7 +78,8 @@ class SagemakerEndpoint(LLM, BaseModel): response = se("Tell me a joke.") """ - + import boto3 + session = boto3.Session(profile_name="test-profile-name") sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2") @@ -100,31 +100,33 @@ class SagemakerEndpoint(LLM, BaseModel): # aws_session_token=response['Credentials']['SessionToken']) # session = role_arn_to_session(RoleArn="$role-arn", - # RoleSessionName="test-role-session-name") - # sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2") + # RoleSessionName="test-role-session-name") + # sagemaker_runtime = session.client( + # "sagemaker-runtime", region_name="us-west-2" + # ) _model_kwargs = self.model_kwargs or {} # payload samples parameter_payload = {"inputs": prompt, "parameters": _model_kwargs} - input_en = json.dumps(parameter_payload).encode('utf-8') + input_en = json.dumps(parameter_payload).encode("utf-8") # send request try: response = sagemaker_runtime.invoke_endpoint( - EndpointName=self.endpoint_name, - Body=input_en, - ContentType='application/json' - ) - except Exception as e: + EndpointName=self.endpoint_name, + Body=input_en, + ContentType="application/json", + ) + except Exception as e: raise ValueError(f"Error raised by inference endpoint: {e}") - + + response_json = json.loads(response["Body"].read().decode("utf-8")) + text = response_json[0]["generated_text"] if stop is not None: # This is a bit hacky, but I can't figure out a better way to enforce # stop tokens when making calls to huggingface_hub. text = enforce_stop_tokens(text, stop) - response_json = json.loads(response['Body'].read().decode('utf-8')) - - return response_json[0]["generated_text"] + return text