mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 16:39:52 +00:00
cr
This commit is contained in:
parent
5cba2a1ecc
commit
f5879e73cb
@ -6,10 +6,10 @@ from langchain.llms.anthropic import Anthropic
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.cohere import Cohere
|
||||
from langchain.llms.huggingface_hub import HuggingFaceHub
|
||||
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
from langchain.llms.nlpcloud import NLPCloud
|
||||
from langchain.llms.openai import AzureOpenAI, OpenAI
|
||||
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
|
||||
|
||||
__all__ = [
|
||||
"Anthropic",
|
||||
|
@ -1,13 +1,11 @@
|
||||
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import boto3
|
||||
import json
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
VALID_TASKS = ("text2text-generation", "text-generation")
|
||||
|
||||
@ -15,7 +13,8 @@ VALID_TASKS = ("text2text-generation", "text-generation")
|
||||
class SagemakerEndpoint(LLM, BaseModel):
|
||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||
|
||||
To use, you should pass the AWS IAM Role and Role Session Name as named parameters to the constructor.
|
||||
To use, you should pass the AWS IAM Role and Role Session Name as
|
||||
named parameters to the constructor.
|
||||
|
||||
Only supports `text-generation` and `text2text-generation` for now.
|
||||
"""
|
||||
@ -36,7 +35,7 @@ class SagemakerEndpoint(LLM, BaseModel):
|
||||
"""
|
||||
|
||||
endpoint_name: str = ""
|
||||
"""# The name of the endpoint. The name must be unique within an AWS Region in your AWS account."""
|
||||
"""# The name of the endpoint. Must be unique within an AWS Region."""
|
||||
task: Optional[str] = None
|
||||
"""Task to call the model with. Should be a task that returns `generated_text`."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
@ -79,7 +78,8 @@ class SagemakerEndpoint(LLM, BaseModel):
|
||||
|
||||
response = se("Tell me a joke.")
|
||||
"""
|
||||
|
||||
import boto3
|
||||
|
||||
session = boto3.Session(profile_name="test-profile-name")
|
||||
sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
|
||||
|
||||
@ -100,31 +100,33 @@ class SagemakerEndpoint(LLM, BaseModel):
|
||||
# aws_session_token=response['Credentials']['SessionToken'])
|
||||
|
||||
# session = role_arn_to_session(RoleArn="$role-arn",
|
||||
# RoleSessionName="test-role-session-name")
|
||||
# sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
|
||||
# RoleSessionName="test-role-session-name")
|
||||
# sagemaker_runtime = session.client(
|
||||
# "sagemaker-runtime", region_name="us-west-2"
|
||||
# )
|
||||
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
# payload samples
|
||||
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
|
||||
|
||||
input_en = json.dumps(parameter_payload).encode('utf-8')
|
||||
input_en = json.dumps(parameter_payload).encode("utf-8")
|
||||
|
||||
# send request
|
||||
try:
|
||||
response = sagemaker_runtime.invoke_endpoint(
|
||||
EndpointName=self.endpoint_name,
|
||||
Body=input_en,
|
||||
ContentType='application/json'
|
||||
)
|
||||
except Exception as e:
|
||||
EndpointName=self.endpoint_name,
|
||||
Body=input_en,
|
||||
ContentType="application/json",
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
|
||||
response_json = json.loads(response["Body"].read().decode("utf-8"))
|
||||
text = response_json[0]["generated_text"]
|
||||
if stop is not None:
|
||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
||||
# stop tokens when making calls to huggingface_hub.
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
response_json = json.loads(response['Body'].read().decode('utf-8'))
|
||||
|
||||
return response_json[0]["generated_text"]
|
||||
return text
|
||||
|
Loading…
Reference in New Issue
Block a user