mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-14 05:56:40 +00:00
10634: Added the capability to inject boto3 client in SagemakerEndpointEmbeddings (#12146)
**Description: Allow to inject boto3 client for Cross account access type of scenarios in using SagemakerEndpointEmbeddings and also updated the documentation for same in the sample notebook** **Issue:SagemakerEndpointEmbeddings cross account capability #10634 #10184** Dependencies: None Tag maintainer: Twitter handle:lethargicoder Co-authored-by: Vikram(VS) <vssht@amazon.com>
This commit is contained in:
@@ -46,8 +46,18 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
region_name=region_name,
|
||||
credentials_profile_name=credentials_profile_name
|
||||
)
|
||||
|
||||
#Use with boto3 client
|
||||
client = boto3.client(
|
||||
"sagemaker-runtime",
|
||||
region_name=region_name
|
||||
)
|
||||
se = SagemakerEndpointEmbeddings(
|
||||
endpoint_name=endpoint_name,
|
||||
client=client
|
||||
)
|
||||
"""
|
||||
client: Any #: :meta private:
|
||||
client: Any = None
|
||||
|
||||
endpoint_name: str = ""
|
||||
"""The name of the endpoint from the deployed Sagemaker model.
|
||||
@@ -106,6 +116,10 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Dont do anything if client provided externally"""
|
||||
if values.get("client") is not None:
|
||||
return values
|
||||
|
||||
"""Validate that AWS credentials to and python package exists in environment."""
|
||||
try:
|
||||
import boto3
|
||||
|
Reference in New Issue
Block a user