This commit is contained in:
Harrison Chase 2023-02-09 23:28:46 -08:00
parent 5cba2a1ecc
commit f5879e73cb
2 changed files with 23 additions and 21 deletions

View File

@ -6,10 +6,10 @@ from langchain.llms.anthropic import Anthropic
from langchain.llms.base import BaseLLM
from langchain.llms.cohere import Cohere
from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.nlpcloud import NLPCloud
from langchain.llms.openai import AzureOpenAI, OpenAI
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
__all__ = [
"Anthropic",

View File

@ -1,13 +1,11 @@
"""Wrapper around Sagemaker InvokeEndpoint API."""
from typing import Any, Dict, List, Mapping, Optional
import boto3
import json
from pydantic import BaseModel, Extra, root_validator
from typing import Any, List, Mapping, Optional
from pydantic import BaseModel, Extra
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
VALID_TASKS = ("text2text-generation", "text-generation")
@ -15,7 +13,8 @@ VALID_TASKS = ("text2text-generation", "text-generation")
class SagemakerEndpoint(LLM, BaseModel):
"""Wrapper around custom Sagemaker Inference Endpoints.
To use, you should pass the AWS IAM Role and Role Session Name as named parameters to the constructor.
To use, you should pass the AWS IAM Role and Role Session Name as
named parameters to the constructor.
Only supports `text-generation` and `text2text-generation` for now.
"""
@ -36,7 +35,7 @@ class SagemakerEndpoint(LLM, BaseModel):
"""
endpoint_name: str = ""
"""# The name of the endpoint. The name must be unique within an AWS Region in your AWS account."""
"""# The name of the endpoint. Must be unique within an AWS Region."""
task: Optional[str] = None
"""Task to call the model with. Should be a task that returns `generated_text`."""
model_kwargs: Optional[dict] = None
@ -79,7 +78,8 @@ class SagemakerEndpoint(LLM, BaseModel):
response = se("Tell me a joke.")
"""
import boto3
session = boto3.Session(profile_name="test-profile-name")
sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
@ -100,31 +100,33 @@ class SagemakerEndpoint(LLM, BaseModel):
# aws_session_token=response['Credentials']['SessionToken'])
# session = role_arn_to_session(RoleArn="$role-arn",
# RoleSessionName="test-role-session-name")
# sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
# RoleSessionName="test-role-session-name")
# sagemaker_runtime = session.client(
# "sagemaker-runtime", region_name="us-west-2"
# )
_model_kwargs = self.model_kwargs or {}
# payload samples
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
input_en = json.dumps(parameter_payload).encode('utf-8')
input_en = json.dumps(parameter_payload).encode("utf-8")
# send request
try:
response = sagemaker_runtime.invoke_endpoint(
EndpointName=self.endpoint_name,
Body=input_en,
ContentType='application/json'
)
except Exception as e:
EndpointName=self.endpoint_name,
Body=input_en,
ContentType="application/json",
)
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
response_json = json.loads(response["Body"].read().decode("utf-8"))
text = response_json[0]["generated_text"]
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
response_json = json.loads(response['Body'].read().decode('utf-8'))
return response_json[0]["generated_text"]
return text