mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-20 01:49:51 +00:00
cr
This commit is contained in:
parent
5cba2a1ecc
commit
f5879e73cb
@ -6,10 +6,10 @@ from langchain.llms.anthropic import Anthropic
|
|||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.llms.cohere import Cohere
|
from langchain.llms.cohere import Cohere
|
||||||
from langchain.llms.huggingface_hub import HuggingFaceHub
|
from langchain.llms.huggingface_hub import HuggingFaceHub
|
||||||
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
|
|
||||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||||
from langchain.llms.nlpcloud import NLPCloud
|
from langchain.llms.nlpcloud import NLPCloud
|
||||||
from langchain.llms.openai import AzureOpenAI, OpenAI
|
from langchain.llms.openai import AzureOpenAI, OpenAI
|
||||||
|
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Anthropic",
|
"Anthropic",
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
|
||||||
|
|
||||||
import boto3
|
|
||||||
import json
|
import json
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from typing import Any, List, Mapping, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.llms.utils import enforce_stop_tokens
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
from langchain.utils import get_from_dict_or_env
|
|
||||||
|
|
||||||
VALID_TASKS = ("text2text-generation", "text-generation")
|
VALID_TASKS = ("text2text-generation", "text-generation")
|
||||||
|
|
||||||
@ -15,7 +13,8 @@ VALID_TASKS = ("text2text-generation", "text-generation")
|
|||||||
class SagemakerEndpoint(LLM, BaseModel):
|
class SagemakerEndpoint(LLM, BaseModel):
|
||||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||||
|
|
||||||
To use, you should pass the AWS IAM Role and Role Session Name as named parameters to the constructor.
|
To use, you should pass the AWS IAM Role and Role Session Name as
|
||||||
|
named parameters to the constructor.
|
||||||
|
|
||||||
Only supports `text-generation` and `text2text-generation` for now.
|
Only supports `text-generation` and `text2text-generation` for now.
|
||||||
"""
|
"""
|
||||||
@ -36,7 +35,7 @@ class SagemakerEndpoint(LLM, BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
endpoint_name: str = ""
|
endpoint_name: str = ""
|
||||||
"""# The name of the endpoint. The name must be unique within an AWS Region in your AWS account."""
|
"""# The name of the endpoint. Must be unique within an AWS Region."""
|
||||||
task: Optional[str] = None
|
task: Optional[str] = None
|
||||||
"""Task to call the model with. Should be a task that returns `generated_text`."""
|
"""Task to call the model with. Should be a task that returns `generated_text`."""
|
||||||
model_kwargs: Optional[dict] = None
|
model_kwargs: Optional[dict] = None
|
||||||
@ -79,6 +78,7 @@ class SagemakerEndpoint(LLM, BaseModel):
|
|||||||
|
|
||||||
response = se("Tell me a joke.")
|
response = se("Tell me a joke.")
|
||||||
"""
|
"""
|
||||||
|
import boto3
|
||||||
|
|
||||||
session = boto3.Session(profile_name="test-profile-name")
|
session = boto3.Session(profile_name="test-profile-name")
|
||||||
sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
|
sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
|
||||||
@ -101,30 +101,32 @@ class SagemakerEndpoint(LLM, BaseModel):
|
|||||||
|
|
||||||
# session = role_arn_to_session(RoleArn="$role-arn",
|
# session = role_arn_to_session(RoleArn="$role-arn",
|
||||||
# RoleSessionName="test-role-session-name")
|
# RoleSessionName="test-role-session-name")
|
||||||
# sagemaker_runtime = session.client("sagemaker-runtime", region_name="us-west-2")
|
# sagemaker_runtime = session.client(
|
||||||
|
# "sagemaker-runtime", region_name="us-west-2"
|
||||||
|
# )
|
||||||
|
|
||||||
_model_kwargs = self.model_kwargs or {}
|
_model_kwargs = self.model_kwargs or {}
|
||||||
|
|
||||||
# payload samples
|
# payload samples
|
||||||
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
|
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
|
||||||
|
|
||||||
input_en = json.dumps(parameter_payload).encode('utf-8')
|
input_en = json.dumps(parameter_payload).encode("utf-8")
|
||||||
|
|
||||||
# send request
|
# send request
|
||||||
try:
|
try:
|
||||||
response = sagemaker_runtime.invoke_endpoint(
|
response = sagemaker_runtime.invoke_endpoint(
|
||||||
EndpointName=self.endpoint_name,
|
EndpointName=self.endpoint_name,
|
||||||
Body=input_en,
|
Body=input_en,
|
||||||
ContentType='application/json'
|
ContentType="application/json",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||||
|
|
||||||
|
response_json = json.loads(response["Body"].read().decode("utf-8"))
|
||||||
|
text = response_json[0]["generated_text"]
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
# This is a bit hacky, but I can't figure out a better way to enforce
|
||||||
# stop tokens when making calls to huggingface_hub.
|
# stop tokens when making calls to huggingface_hub.
|
||||||
text = enforce_stop_tokens(text, stop)
|
text = enforce_stop_tokens(text, stop)
|
||||||
|
|
||||||
response_json = json.loads(response['Body'].read().decode('utf-8'))
|
return text
|
||||||
|
|
||||||
return response_json[0]["generated_text"]
|
|
||||||
|
Loading…
Reference in New Issue
Block a user