Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-17 16:39:52 +00:00)
Updates to Sagemaker Endpoint (#1217)
commit 1732103e19 (parent f5879e73cb)
docs/modules/document_loaders/examples/sagemaker.ipynb (new file, 183 lines added)
@@ -0,0 +1,183 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Collecting langchain\n",
      " Downloading langchain-0.0.80-py3-none-any.whl (222 kB)\n",
      "\u001b[K |████████████████████████████████| 222 kB 2.1 MB/s eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: numpy<2,>=1 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (1.24.1)\n",
      "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (3.8.3)\n",
      "Collecting pydantic<2,>=1\n",
      " Downloading pydantic-1.10.4-cp39-cp39-macosx_11_0_arm64.whl (2.6 MB)\n",
      "\u001b[K |████████████████████████████████| 2.6 MB 3.3 MB/s eta 0:00:01\n",
      "\u001b[?25hCollecting SQLAlchemy<2,>=1\n",
      " Downloading SQLAlchemy-1.4.46.tar.gz (8.5 MB)\n",
      "\u001b[K |████████████████████████████████| 8.5 MB 23.4 MB/s eta 0:00:01\n",
      "\u001b[?25hCollecting tenacity<9.0.0,>=8.1.0\n",
      " Downloading tenacity-8.2.0-py3-none-any.whl (24 kB)\n",
      "Requirement already satisfied: requests<3,>=2 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (2.28.2)\n",
      "Requirement already satisfied: PyYAML<7,>=6 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (6.0)\n",
      "Collecting dataclasses-json<0.6.0,>=0.5.7\n",
      " Downloading dataclasses_json-0.5.7-py3-none-any.whl (25 kB)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (4.0.2)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.4)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (22.2.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.3)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.8.2)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
      "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (2.1.1)\n",
      "Collecting marshmallow<4.0.0,>=3.3.0\n",
      " Downloading marshmallow-3.19.0-py3-none-any.whl (49 kB)\n",
      "\u001b[K |████████████████████████████████| 49 kB 26.9 MB/s eta 0:00:01\n",
      "\u001b[?25hCollecting marshmallow-enum<2.0.0,>=1.5.1\n",
      " Downloading marshmallow_enum-1.5.1-py2.py3-none-any.whl (4.2 kB)\n",
      "Collecting typing-inspect>=0.4.0\n",
      " Downloading typing_inspect-0.8.0-py3-none-any.whl (8.7 kB)\n",
      "Requirement already satisfied: packaging>=17.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from marshmallow<4.0.0,>=3.3.0->dataclasses-json<0.6.0,>=0.5.7->langchain) (23.0)\n",
      "Requirement already satisfied: typing-extensions>=4.2.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from pydantic<2,>=1->langchain) (4.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from requests<3,>=2->langchain) (3.4)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from requests<3,>=2->langchain) (1.26.14)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from requests<3,>=2->langchain) (2022.12.7)\n",
      "Collecting mypy-extensions>=0.3.0\n",
      " Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n",
      "Building wheels for collected packages: SQLAlchemy\n",
      " Building wheel for SQLAlchemy (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h Created wheel for SQLAlchemy: filename=SQLAlchemy-1.4.46-cp39-cp39-macosx_10_9_universal2.whl size=1578667 sha256=9991d70fde083b993d7fe1fd61fca33a279e921f1b8296b02037e24b8cac1097\n",
      " Stored in directory: /Users/nmehta/Library/Caches/pip/wheels/3c/99/65/57cf5a0ec6e7f3b803a68d31694501e168997e03e80adc903d\n",
      "Successfully built SQLAlchemy\n",
      "Installing collected packages: mypy-extensions, marshmallow, typing-inspect, marshmallow-enum, tenacity, SQLAlchemy, pydantic, dataclasses-json, langchain\n",
      "\u001b[33m WARNING: The script langchain-server is installed in '/Users/nmehta/Library/Python/3.9/bin' which is not on PATH.\n",
      " Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\n",
      "Successfully installed SQLAlchemy-1.4.46 dataclasses-json-0.5.7 langchain-0.0.80 marshmallow-3.19.0 marshmallow-enum-1.5.1 mypy-extensions-1.0.0 pydantic-1.10.4 tenacity-8.2.0 typing-inspect-0.8.0\n",
      "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.0 is available.\n",
      "You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n",
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Collecting html2text\n",
      " Downloading html2text-2020.1.16-py3-none-any.whl (32 kB)\n",
      "Installing collected packages: html2text\n",
      "\u001b[33m WARNING: The script html2text is installed in '/Users/nmehta/Library/Python/3.9/bin' which is not on PATH.\n",
      " Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\n",
      "Successfully installed html2text-2020.1.16\n",
      "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.0 is available.\n",
      "You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip3 install langchain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.docstore.document import Document"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_doc_1 = \"\"\"\n",
    "Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.\n",
    "Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.\n",
    "Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n",
    "\"\"\"\n",
    "\n",
    "docs = [\n",
    "    Document(\n",
    "        page_content=example_doc_1,\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'output_text': '3 days'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain import PromptTemplate, HuggingFaceHub, LLMChain, SagemakerEndpoint\n",
    "from langchain.chains.question_answering import load_qa_chain\n",
    "import json\n",
    "\n",
    "query = \"\"\"How long was Elizabeth hospitalized?\n",
    "\"\"\"\n",
    "\n",
    "prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
    "\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "Answer:\"\"\"\n",
    "PROMPT = PromptTemplate(\n",
    "    template=prompt_template, input_variables=[\"context\", \"question\"]\n",
    ")\n",
    "\n",
    "def model_input_transform_fn(prompt, model_kwargs):\n",
    "    parameter_payload = {\"inputs\": prompt, \"parameters\": model_kwargs}\n",
    "    return json.dumps(parameter_payload).encode(\"utf-8\") \n",
    "\n",
    "chain = load_qa_chain(llm=SagemakerEndpoint(\n",
    "        endpoint_name=\"my-sagemaker-model-endpoint\", \n",
    "        credentials_profile_name=\"credentials-profile-name\", \n",
    "        region_name=\"us-west-2\", \n",
    "        model_kwargs={\"temperature\":1e-10},\n",
    "        content_type=\"application/json\", \n",
    "        model_input_transform_fn=model_input_transform_fn), \n",
    "    prompt=PROMPT) \n",
    "\n",
    "chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
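The notebook's final cell assumes the deployed model accepts the {"inputs": ..., "parameters": ...} payload built by model_input_transform_fn. As a quick way to verify that contract before wiring the endpoint into a chain, here is a minimal sketch that invokes the endpoint once directly through boto3, reusing the notebook's placeholder endpoint name, credentials profile, and region:

import json

import boto3

# Placeholder profile, region, and endpoint name, matching the notebook.
session = boto3.Session(profile_name="credentials-profile-name")
client = session.client("sagemaker-runtime", region_name="us-west-2")

payload = {
    "inputs": "How long was Elizabeth hospitalized?",
    "parameters": {"temperature": 1e-10},
}
response = client.invoke_endpoint(
    EndpointName="my-sagemaker-model-endpoint",
    Body=json.dumps(payload).encode("utf-8"),
    ContentType="application/json",
)
print(json.loads(response["Body"].read().decode("utf-8")))

If the printed value is not a list whose first element contains a generated_text key, the wrapper's response parsing shown in the diff below will need further adaptation.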
langchain/llms/sagemaker_endpoint.py (modified)
@@ -1,22 +1,29 @@
 """Wrapper around Sagemaker InvokeEndpoint API."""
 import json
-from typing import Any, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional
 
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, root_validator
 
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 
-VALID_TASKS = ("text2text-generation", "text-generation")
-
 
 class SagemakerEndpoint(LLM, BaseModel):
     """Wrapper around custom Sagemaker Inference Endpoints.
 
-    To use, you should pass the AWS IAM Role and Role Session Name as
-    named parameters to the constructor.
+    To use, you must supply the endpoint name from your deployed
+    Sagemaker model & the region where it is deployed.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    To authenticate, the AWS client uses the following methods to
+    automatically load credentials:
+    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+    If a specific credential profile should be used, you must pass
+    the name of the profile from the ~/.aws/credentials file that is to be used.
+
+    Make sure the credentials / roles used have the required policies to
+    access the Sagemaker endpoint.
+    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
     """
 
     """
@@ -27,34 +34,104 @@ class SagemakerEndpoint(LLM, BaseModel):
             endpoint_name = (
                 "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/abcdefghijklmnop/invocations"
             )
+            region_name = (
+                "us-west-2"
+            )
+            credentials_profile_name = (
+                "default"
+            )
             se = SagemakerEndpoint(
                 endpoint_name=endpoint_name,
-                role_arn="role_arn",
-                role_session_name="role_session_name"
+                region_name=region_name,
+                credentials_profile_name=credentials_profile_name
             )
     """
+    client: Any  #: :meta private:
+
     endpoint_name: str = ""
-    """# The name of the endpoint. Must be unique within an AWS Region."""
-    task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
-    model_kwargs: Optional[dict] = None
-    """Key word arguments to pass to the model."""
+    """The name of the endpoint from the deployed Sagemaker model.
+    Must be unique within an AWS Region."""
 
-    role_arn: Optional[str] = None
-    role_session_name: Optional[str] = None
+    region_name: str = ""
+    """The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""
+
+    credentials_profile_name: Optional[str] = None
+    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
+    has either access keys or role information specified.
+    If not specified, the default credential profile or, if on an EC2 instance,
+    credentials from IMDS will be used.
+    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+    """
+
+    content_type: Optional[str] = "application/json"
+    """The MIME type of the input data in the request body to be used in the header
+    for the request to the Sagemaker invoke_endpoint API.
+    Defaults to "application/json"."""
+
+    model_input_transform_fn: Callable[[str, Dict], bytes]
+    """
+    Function which takes the prompt (str) and model_kwargs (dict) and transforms
+    the input to the format which the model can accept as the request Body.
+    Should return bytes or seekable file-like object in the format specified in the
+    content_type request header.
+    """
+
+    """
+    Example:
+        .. code-block:: python
+
+            def model_input_transform_fn(prompt, model_kwargs):
+                parameter_payload = {"inputs": prompt, "parameters": model_kwargs}
+                return json.dumps(parameter_payload).encode("utf-8")
+    """
+
+    model_kwargs: Optional[Dict] = None
+    """Key word arguments to pass to the model."""
 
     class Config:
         """Configuration for this pydantic object."""
 
         extra = Extra.forbid
 
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that AWS credentials to and python package exists in environment."""
+        try:
+            import boto3
+
+            try:
+                if values["credentials_profile_name"] is not None:
+                    session = boto3.Session(
+                        profile_name=values["credentials_profile_name"]
+                    )
+                else:
+                    # use default credentials
+                    session = boto3.Session()
+
+                values["client"] = session.client(
+                    "sagemaker-runtime", region_name=values["region_name"]
+                )
+
+            except Exception as e:
+                raise ValueError(
+                    "Could not load credentials to authenticate with AWS client. "
+                    "Please check that credentials in the specified "
+                    "profile name are valid."
+                ) from e
+
+        except ImportError:
+            raise ValueError(
+                "Could not import boto3 python package. "
+                "Please it install it with `pip install boto3`."
+            )
+        return values
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
         _model_kwargs = self.model_kwargs or {}
         return {
-            **{"endpoint_name": self.endpoint_name, "task": self.task},
+            **{"endpoint_name": self.endpoint_name},
             **{"model_kwargs": _model_kwargs},
         }
 
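The model_input_transform_fn field added above is the main extension point of this change: any container-specific request schema can be supported by swapping the serializer. A sketch under an assumed schema, here a model that expects a flat JSON object with a text_inputs key (the key name is hypothetical, not something the commit prescribes):

import json
from typing import Dict

from langchain import SagemakerEndpoint


def custom_transform(prompt: str, model_kwargs: Dict) -> bytes:
    # Flatten the generation kwargs next to the prompt instead of nesting them
    # under "parameters"; adjust the keys to whatever the container expects.
    return json.dumps({"text_inputs": prompt, **model_kwargs}).encode("utf-8")


llm = SagemakerEndpoint(
    endpoint_name="my-sagemaker-model-endpoint",  # placeholder values, as in the notebook
    credentials_profile_name="credentials-profile-name",
    region_name="us-west-2",
    content_type="application/json",
    model_input_transform_fn=custom_transform,
    model_kwargs={"temperature": 1e-10},
)

Because the Config sets extra = Extra.forbid, only the declared fields can be passed, and the root_validator builds the sagemaker-runtime client at construction time, so credential problems surface when the object is created rather than on the first call.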
@@ -78,46 +155,16 @@ class SagemakerEndpoint(LLM, BaseModel):
 
                 response = se("Tell me a joke.")
         """
-        import boto3
-
-        session = boto3.Session(profile_name="test-profile-name")
-        sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
-
-        # TODO: use AWS IAM assumed roles to authenticate from the EC2 instance
-        # def role_arn_to_session(**args):
-        #     """
-        #     Usage :
-        #         session = role_arn_to_session(
-        #             RoleArn='arn:aws:iam::012345678901:role/example-role',
-        #             RoleSessionName='ExampleSessionName')
-        #         client = session.client('sqs')
-        #     """
-        #     client = boto3.client('sts')
-        #     response = client.assume_role(**args)
-        #     return boto3.Session(
-        #         aws_access_key_id=response['Credentials']['AccessKeyId'],
-        #         aws_secret_access_key=response['Credentials']['SecretAccessKey'],
-        #         aws_session_token=response['Credentials']['SessionToken'])
-
-        # session = role_arn_to_session(RoleArn="$role-arn",
-        #                               RoleSessionName="test-role-session-name")
-        # sagemaker_runtime = session.client(
-        #     "sagemaker-runtime", region_name="us-west-2"
-        # )
-
         _model_kwargs = self.model_kwargs or {}
-        # payload samples
-        parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
-
-        input_en = json.dumps(parameter_payload).encode("utf-8")
-
+        if self.model_input_transform_fn is None:
+            raise NotImplementedError("model_input_transform_fn not implemented")
 
         # send request
         try:
-            response = sagemaker_runtime.invoke_endpoint(
+            response = self.client.invoke_endpoint(
                 EndpointName=self.endpoint_name,
-                Body=input_en,
-                ContentType="application/json",
+                Body=self.model_input_transform_fn(prompt, _model_kwargs),
+                ContentType=self.content_type,
             )
         except Exception as e:
             raise ValueError(f"Error raised by inference endpoint: {e}")
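With the transform shown in the docstring example, the bytes handed to invoke_endpoint for a single call are just the prompt and keyword arguments serialized at call time (a small worked example with illustrative values):

import json

prompt = "Tell me a joke."
_model_kwargs = {"temperature": 1e-10}
body = json.dumps({"inputs": prompt, "parameters": _model_kwargs}).encode("utf-8")
# The exact bytes sent as the request Body:
assert body == b'{"inputs": "Tell me a joke.", "parameters": {"temperature": 1e-10}}'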
@@ -126,7 +173,7 @@ class SagemakerEndpoint(LLM, BaseModel):
         text = response_json[0]["generated_text"]
         if stop is not None:
             # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
+            # stop tokens when making calls to the sagemaker endpoint.
             text = enforce_stop_tokens(text, stop)
 
         return text
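The output side is unchanged: the wrapper still expects the endpoint to return a JSON list whose first element carries generated_text, and only then trims at the stop sequences. A short sketch of that parsing step against an illustrative response body:

import json

from langchain.llms.utils import enforce_stop_tokens

raw_body = b'[{"generated_text": "3 days\\n\\nQuestion: Did Peter stay as well?"}]'  # illustrative
response_json = json.loads(raw_body.decode("utf-8"))
text = response_json[0]["generated_text"]
print(enforce_stop_tokens(text, ["\n\nQuestion:"]))  # prints: 3 days

A container that returns a different response schema would need its output adapted before the wrapper can consume it; this commit only makes the request side pluggable.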