This commit is contained in:
Harrison Chase 2023-02-09 23:28:46 -08:00
parent 5cba2a1ecc
commit f5879e73cb
2 changed files with 23 additions and 21 deletions

View File

@ -6,10 +6,10 @@ from langchain.llms.anthropic import Anthropic
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.llms.cohere import Cohere from langchain.llms.cohere import Cohere
from langchain.llms.huggingface_hub import HuggingFaceHub from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
from langchain.llms.huggingface_pipeline import HuggingFacePipeline from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.nlpcloud import NLPCloud from langchain.llms.nlpcloud import NLPCloud
from langchain.llms.openai import AzureOpenAI, OpenAI from langchain.llms.openai import AzureOpenAI, OpenAI
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
__all__ = [ __all__ = [
"Anthropic", "Anthropic",

View File

@ -1,13 +1,11 @@
"""Wrapper around Sagemaker InvokeEndpoint API.""" """Wrapper around Sagemaker InvokeEndpoint API."""
from typing import Any, Dict, List, Mapping, Optional
import boto3
import json import json
from pydantic import BaseModel, Extra, root_validator from typing import Any, List, Mapping, Optional
from pydantic import BaseModel, Extra
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
VALID_TASKS = ("text2text-generation", "text-generation") VALID_TASKS = ("text2text-generation", "text-generation")
@ -15,7 +13,8 @@ VALID_TASKS = ("text2text-generation", "text-generation")
class SagemakerEndpoint(LLM, BaseModel): class SagemakerEndpoint(LLM, BaseModel):
"""Wrapper around custom Sagemaker Inference Endpoints. """Wrapper around custom Sagemaker Inference Endpoints.
To use, you should pass the AWS IAM Role and Role Session Name as named parameters to the constructor. To use, you should pass the AWS IAM Role and Role Session Name as
named parameters to the constructor.
Only supports `text-generation` and `text2text-generation` for now. Only supports `text-generation` and `text2text-generation` for now.
""" """
@ -36,7 +35,7 @@ class SagemakerEndpoint(LLM, BaseModel):
""" """
endpoint_name: str = "" endpoint_name: str = ""
"""# The name of the endpoint. The name must be unique within an AWS Region in your AWS account.""" """# The name of the endpoint. Must be unique within an AWS Region."""
task: Optional[str] = None task: Optional[str] = None
"""Task to call the model with. Should be a task that returns `generated_text`.""" """Task to call the model with. Should be a task that returns `generated_text`."""
model_kwargs: Optional[dict] = None model_kwargs: Optional[dict] = None
@ -79,6 +78,7 @@ class SagemakerEndpoint(LLM, BaseModel):
response = se("Tell me a joke.") response = se("Tell me a joke.")
""" """
import boto3
session = boto3.Session(profile_name="test-profile-name") session = boto3.Session(profile_name="test-profile-name")
sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2") sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
@ -101,30 +101,32 @@ class SagemakerEndpoint(LLM, BaseModel):
# session = role_arn_to_session(RoleArn="$role-arn", # session = role_arn_to_session(RoleArn="$role-arn",
# RoleSessionName="test-role-session-name") # RoleSessionName="test-role-session-name")
# sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2") # sagemaker_runtime = session.client(
# "sagemaker-runtime", region_name="us-west-2"
# )
_model_kwargs = self.model_kwargs or {} _model_kwargs = self.model_kwargs or {}
# payload samples # payload samples
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs} parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
input_en = json.dumps(parameter_payload).encode('utf-8') input_en = json.dumps(parameter_payload).encode("utf-8")
# send request # send request
try: try:
response = sagemaker_runtime.invoke_endpoint( response = sagemaker_runtime.invoke_endpoint(
EndpointName=self.endpoint_name, EndpointName=self.endpoint_name,
Body=input_en, Body=input_en,
ContentType='application/json' ContentType="application/json",
) )
except Exception as e: except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}") raise ValueError(f"Error raised by inference endpoint: {e}")
response_json = json.loads(response["Body"].read().decode("utf-8"))
text = response_json[0]["generated_text"]
if stop is not None: if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce # This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub. # stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop) text = enforce_stop_tokens(text, stop)
response_json = json.loads(response['Body'].read().decode('utf-8')) return text
return response_json[0]["generated_text"]