feat: add support for Amazon Bedrock Inference Profiles (#1492)

Signed-off-by: rkarthikr <38294804+rkarthikr@users.noreply.github.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
Author: rkarthikr
Date: 2025-05-06 06:18:40 -04:00
Committed by: GitHub
Parent: d5341f3c00
Commit: 21bc76e5b7
7 changed files with 457 additions and 40 deletions


@@ -466,6 +466,22 @@ k8sgpt auth default -p azureopenai
Default provider set to azureopenai
```
_Using Amazon Bedrock with inference profiles_
_System Inference Profile_
```
k8sgpt auth add --backend amazonbedrock --providerRegion us-east-1 --model arn:aws:bedrock:us-east-1:123456789012:inference-profile/my-inference-profile
```
_Application Inference Profile_
```
k8sgpt auth add --backend amazonbedrock --providerRegion us-east-1 --model arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/2uzp4s0w39t6
```
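After adding the backend, you can make Amazon Bedrock the default provider, mirroring the `azureopenai` example above:
```
k8sgpt auth default -p amazonbedrock
```
You can then run an analysis against it, for example:
```
k8sgpt analyze --explain --backend amazonbedrock
```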
## Key Features
<details>

go.mod

@@ -36,7 +36,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.5.0
github.com/IBM/watsonx-go v1.0.1
github.com/agiledragon/gomonkey/v2 v2.13.0
github.com/aws/aws-sdk-go v1.55.6
github.com/aws/aws-sdk-go v1.55.7
github.com/cohere-ai/cohere-go/v2 v2.12.2
github.com/go-logr/zapr v1.3.0
github.com/google/generative-ai-go v0.19.0
@@ -78,8 +78,22 @@ require (
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 // indirect
github.com/Microsoft/hcsshim v0.12.4 // indirect
github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b // indirect
github.com/aws/aws-sdk-go-v2 v1.32.3 // indirect
github.com/aws/smithy-go v1.22.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/config v1.29.14 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect
github.com/aws/aws-sdk-go-v2/service/bedrock v1.33.0 // indirect
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.30.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect
github.com/aws/smithy-go v1.22.2 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect

go.sum

@@ -735,12 +735,44 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkY
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/atomicgo/cursor v0.0.1/go.mod h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkUiF/RuIk=
github.com/aws/aws-sdk-go v1.55.6 h1:cSg4pvZ3m8dgYcgqB97MrcdjUmZ1BeMYKUxMMB89IPk=
github.com/aws/aws-sdk-go v1.55.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/aws/aws-sdk-go v1.55.7 h1:UJrkFq7es5CShfBwlWAC8DA077vp8PyVbQd3lqLiztE=
github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/aws/aws-sdk-go-v2 v1.32.3 h1:T0dRlFBKcdaUPGNtkBSwHZxrtis8CQU17UpNBZYd0wk=
github.com/aws/aws-sdk-go-v2 v1.32.3/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo=
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10/go.mod h1:qqvMj6gHLR/EXWZw4ZbqlPbQUyenf4h82UQUlKc+l14=
github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM=
github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g=
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM=
github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo=
github.com/aws/aws-sdk-go-v2/service/bedrock v1.33.0 h1:2P70khV5KDzoRs8UuplU3rAzzyLaj5kzND33Jutwpbg=
github.com/aws/aws-sdk-go-v2/service/bedrock v1.33.0/go.mod h1:rZOgAxQVRg9v5ZEQHrrKw0Gkb9DBAASeeRiwUmmXcG0=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.30.0 h1:eMOwQ8ZZK+76+08RfxeaGUtRFN6wxmD1rvqovc2kq2w=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.30.0/go.mod h1:0b5Rq7rUvSQFYHI1UO0zFTV/S6j6DUyuykXA80C+YOI=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY=
github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8=
github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4=
github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM=
github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ=
github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=


@@ -8,22 +8,22 @@ import (
"regexp"
"strings"
"github.com/aws/aws-sdk-go/service/bedrockruntime/bedrockruntimeiface"
"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrock"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
)
const amazonbedrockAIClientName = "amazonbedrock"
// AmazonBedRockClient represents the client for interacting with the AmazonCompletion Bedrock service.
// AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service.
type AmazonBedRockClient struct {
nopCloser
client bedrockruntimeiface.BedrockRuntimeAPI
client BedrockRuntimeAPI
mgmtClient BedrockManagementAPI
model *bedrock_support.BedrockModel
temperature float32
topP float32
@@ -59,9 +59,33 @@ var BEDROCKER_SUPPORTED_REGION = []string{
var defaultModels = []bedrock_support.BedrockModel{
{
Name: "anthropic.claude-3-5-sonnet-20240620-v1:0",
Name: "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
Completion: &bedrock_support.CohereMessagesCompletion{},
Response: &bedrock_support.CohereMessagesResponse{},
Config: bedrock_support.BedrockModelConfig{
// sensible defaults
MaxTokens: 100,
Temperature: 0.5,
TopP: 0.9,
ModelName: "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
},
},
{
Name: "eu.anthropic.claude-3-7-sonnet-20250219-v1:0",
Completion: &bedrock_support.CohereMessagesCompletion{},
Response: &bedrock_support.CohereMessagesResponse{},
Config: bedrock_support.BedrockModelConfig{
// sensible defaults
MaxTokens: 100,
Temperature: 0.5,
TopP: 0.9,
ModelName: "eu.anthropic.claude-3-7-sonnet-20250219-v1:0",
},
},
{
Name: "anthropic.claude-3-5-sonnet-20240620-v1:0",
Completion: &bedrock_support.CohereCompletion{},
Response: &bedrock_support.CohereResponse{},
Config: bedrock_support.BedrockModelConfig{
// sensible defaults
MaxTokens: 100,
@@ -254,7 +278,6 @@ func NewAmazonBedRockClient(models []bedrock_support.BedrockModel) *AmazonBedRoc
// GetModelOrDefault check config region
func GetRegionOrDefault(region string) string {
if os.Getenv("AWS_DEFAULT_REGION") != "" {
region = os.Getenv("AWS_DEFAULT_REGION")
}
@@ -269,6 +292,17 @@ func GetRegionOrDefault(region string) string {
return BEDROCK_DEFAULT_REGION
}
func validateModelArn(model string) bool {
var re = regexp.MustCompile(`(?m)^arn:(?P<Partition>[^:\n]*):bedrock:(?P<Region>[^:\n]*):(?P<AccountID>[^:\n]*):(?P<Ignore>(?P<ResourceType>[^:\/\n]*)[:\/])?(?P<Resource>.*)$`)
return re.MatchString(model)
}
func validateInferenceProfileArn(inferenceProfile string) bool {
// Support both inference-profile and application-inference-profile formats
var re = regexp.MustCompile(`(?m)^arn:(?P<Partition>[^:\n]*):bedrock:(?P<Region>[^:\n]*):(?P<AccountID>[^:\n]*):(?:inference-profile|application-inference-profile)\/(?P<ProfileName>.+)$`)
return re.MatchString(inferenceProfile)
}
// Get model from string
func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support.BedrockModel, error) {
if model == "" {
@@ -310,11 +344,6 @@ func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support
return nil, fmt.Errorf("model '%s' not found in supported models", model)
}
func validateModelArn(model string) bool {
var re = regexp.MustCompile(`(?m)^arn:(?P<Partition>[^:\n]*):bedrock:(?P<Region>[^:\n]*):(?P<AccountID>[^:\n]*):(?P<Ignore>(?P<ResourceType>[^:\/\n]*)[:\/])?(?P<Resource>.*)$`)
return re.MatchString(model)
}
// Configure configures the AmazonBedRockClient with the provided configuration.
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// Initialize models if not already initialized
@@ -322,26 +351,77 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
a.models = defaultModels
}
// Create a new AWS session
providerRegion := GetRegionOrDefault(config.GetProviderRegion())
sess, err := session.NewSession(&aws.Config{
Region: aws.String(providerRegion),
})
if err != nil {
return err
// Get the model input
modelInput := config.GetModel()
// Determine the appropriate region to use
var region string
// Check if the model input is actually an inference profile ARN
if validateInferenceProfileArn(modelInput) {
// Extract the region from the inference profile ARN
arnParts := strings.Split(modelInput, ":")
if len(arnParts) >= 4 {
region = arnParts[3]
} else {
return fmt.Errorf("could not extract region from inference profile ARN: %s", modelInput)
}
} else {
// Use the provided region or default
region = GetRegionOrDefault(config.GetProviderRegion())
}
// Only create AWS clients if they haven't been injected (for testing)
if a.client == nil || a.mgmtClient == nil {
// Create a new AWS config with the determined region
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
awsconfig.WithRegion(region),
)
if err != nil {
return fmt.Errorf("failed to load AWS config for region %s: %w", region, err)
}
foundModel, err := a.getModelFromString(config.GetModel())
if err != nil {
return err
// Create clients with the config
a.client = bedrockruntime.NewFromConfig(cfg)
a.mgmtClient = bedrock.NewFromConfig(cfg)
}
// Create a new BedrockRuntime client
a.client = bedrockruntime.New(sess)
a.model = foundModel
a.model.Config.ModelName = foundModel.Config.ModelName
// Handle model selection based on input type
if validateInferenceProfileArn(modelInput) {
// Get the inference profile details
profile, err := a.getInferenceProfile(context.Background(), modelInput)
if err != nil {
// Instead of using a fallback model, throw an error
return fmt.Errorf("failed to get inference profile: %v", err)
} else {
// Extract the model ID from the inference profile
modelID, err := a.extractModelFromInferenceProfile(profile)
if err != nil {
return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
}
// Find the model configuration for the extracted model ID
foundModel, err := a.getModelFromString(modelID)
if err != nil {
// Instead of using a fallback model, throw an error
return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
}
a.model = foundModel
// Use the inference profile ARN as the model ID for API calls
a.model.Config.ModelName = modelInput
}
} else {
// Regular model ID provided
foundModel, err := a.getModelFromString(modelInput)
if err != nil {
return err
}
a.model = foundModel
a.model.Config.ModelName = foundModel.Config.ModelName
}
// Set common configuration parameters
a.temperature = config.GetTemperature()
a.topP = config.GetTopP()
a.maxTokens = config.GetMaxTokens()
@@ -349,9 +429,62 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
return nil
}
// getInferenceProfile retrieves the inference profile details from Amazon Bedrock
func (a *AmazonBedRockClient) getInferenceProfile(ctx context.Context, inferenceProfileARN string) (*bedrock.GetInferenceProfileOutput, error) {
// Extract the profile ID from the ARN
// ARN format: arn:aws:bedrock:region:account-id:inference-profile/profile-id
// or arn:aws:bedrock:region:account-id:application-inference-profile/profile-id
parts := strings.Split(inferenceProfileARN, "/")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid inference profile ARN format: %s", inferenceProfileARN)
}
profileID := parts[1]
// Create the input for the GetInferenceProfile API call
input := &bedrock.GetInferenceProfileInput{
InferenceProfileIdentifier: aws.String(profileID),
}
// Call the GetInferenceProfile API
output, err := a.mgmtClient.GetInferenceProfile(ctx, input)
if err != nil {
return nil, fmt.Errorf("failed to get inference profile: %w", err)
}
return output, nil
}
// extractModelFromInferenceProfile extracts the model ID from the inference profile
func (a *AmazonBedRockClient) extractModelFromInferenceProfile(profile *bedrock.GetInferenceProfileOutput) (string, error) {
if profile == nil || len(profile.Models) == 0 {
return "", fmt.Errorf("inference profile does not contain any models")
}
// Check if the first model has a non-nil ModelArn
if profile.Models[0].ModelArn == nil {
return "", fmt.Errorf("model information is missing in inference profile")
}
// Get the first model ARN from the profile
modelARN := aws.ToString(profile.Models[0].ModelArn)
if modelARN == "" {
return "", fmt.Errorf("model ARN is empty in inference profile")
}
// Extract the model ID from the ARN
// ARN format: arn:aws:bedrock:region::foundation-model/model-id
parts := strings.Split(modelARN, "/")
if len(parts) != 2 {
return "", fmt.Errorf("invalid model ARN format: %s", modelARN)
}
modelID := parts[1]
return modelID, nil
}
// GetCompletion sends a request to the model for generating completion based on the provided prompt.
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
// override config defaults
a.model.Config.MaxTokens = a.maxTokens
a.model.Config.Temperature = a.temperature
@@ -361,6 +494,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
if err != nil {
return "", err
}
// Build the parameters for the model invocation
params := &bedrockruntime.InvokeModelInput{
Body: body,
@@ -368,16 +502,15 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
ContentType: aws.String("application/json"),
Accept: aws.String("application/json"),
}
// Invoke the model
resp, err := a.client.InvokeModelWithContext(ctx, params)
resp, err := a.client.InvokeModel(ctx, params)
if err != nil {
return "", err
}
// Parse the response
return a.model.Response.ParseResponse(resp.Body)
}
// GetName returns the name of the AmazonBedRockClient.


@@ -0,0 +1,103 @@
package ai
import (
"context"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrock"
"github.com/aws/aws-sdk-go-v2/service/bedrock/types"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// Mock for Bedrock Management Client
type MockBedrockClient struct {
mock.Mock
}
func (m *MockBedrockClient) GetInferenceProfile(ctx context.Context, params *bedrock.GetInferenceProfileInput, optFns ...func(*bedrock.Options)) (*bedrock.GetInferenceProfileOutput, error) {
args := m.Called(ctx, params)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*bedrock.GetInferenceProfileOutput), args.Error(1)
}
// Mock for Bedrock Runtime Client
type MockBedrockRuntimeClient struct {
mock.Mock
}
func (m *MockBedrockRuntimeClient) InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) {
args := m.Called(ctx, params)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*bedrockruntime.InvokeModelOutput), args.Error(1)
}
// TestBedrockInferenceProfileARNWithMocks tests the inference profile ARN validation with mocks
func TestBedrockInferenceProfileARNWithMocks(t *testing.T) {
// Create test models
testModels := []bedrock_support.BedrockModel{
{
Name: "anthropic.claude-3-5-sonnet-20240620-v1:0",
Completion: &bedrock_support.CohereMessagesCompletion{},
Response: &bedrock_support.CohereMessagesResponse{},
Config: bedrock_support.BedrockModelConfig{
MaxTokens: 100,
Temperature: 0.5,
TopP: 0.9,
ModelName: "anthropic.claude-3-5-sonnet-20240620-v1:0",
},
},
}
// Create a client with test models
client := &AmazonBedRockClient{models: testModels}
// Create mock clients
mockMgmtClient := new(MockBedrockClient)
mockRuntimeClient := new(MockBedrockRuntimeClient)
// Inject mock clients into the AmazonBedRockClient
client.mgmtClient = mockMgmtClient
client.client = mockRuntimeClient
// Test with a valid inference profile ARN
inferenceProfileARN := "arn:aws:bedrock:us-east-1:123456789012:inference-profile/my-profile"
// Setup mock response for GetInferenceProfile
mockMgmtClient.On("GetInferenceProfile", mock.Anything, &bedrock.GetInferenceProfileInput{
InferenceProfileIdentifier: aws.String("my-profile"),
}).Return(&bedrock.GetInferenceProfileOutput{
Models: []types.InferenceProfileModel{
{
ModelArn: aws.String("arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20240620-v1:0"),
},
},
}, nil)
// Configure the client with the inference profile ARN
config := AIProvider{
Model: inferenceProfileARN,
ProviderRegion: "us-east-1",
}
// Test the Configure method with the inference profile ARN
err := client.Configure(&config)
// Verify that the configuration was successful
assert.NoError(t, err)
assert.Equal(t, inferenceProfileARN, client.model.Config.ModelName)
// Verify that the mock was called
mockMgmtClient.AssertExpectations(t)
}


@@ -31,6 +31,17 @@ var testModels = []bedrock_support.BedrockModel{
ModelName: "anthropic.claude-3-5-sonnet-20241022-v2:0",
},
},
{
Name: "anthropic.claude-3-7-sonnet-20250219-v1:0",
Completion: &bedrock_support.CohereCompletion{},
Response: &bedrock_support.CohereResponse{},
Config: bedrock_support.BedrockModelConfig{
MaxTokens: 100,
Temperature: 0.5,
TopP: 0.9,
ModelName: "anthropic.claude-3-7-sonnet-20250219-v1:0",
},
},
}
func TestBedrockModelConfig(t *testing.T) {
@@ -52,6 +63,45 @@ func TestBedrockInvalidModel(t *testing.T) {
assert.Equal(t, foundModel.Config.MaxTokens, 100)
}
func TestBedrockInferenceProfileARN(t *testing.T) {
// Create a mock client with test models
client := &AmazonBedRockClient{models: testModels}
// Test with a valid inference profile ARN
inferenceProfileARN := "arn:aws:bedrock:us-east-1:123456789012:inference-profile/my-profile"
config := AIProvider{
Model: inferenceProfileARN,
ProviderRegion: "us-east-1",
}
// This will fail in a real environment without mocks, but we're just testing the validation logic
err := client.Configure(&config)
// We expect an error since we can't actually call AWS in tests
assert.NotNil(t, err, "Error should not be nil without AWS mocks")
// Test with a valid application inference profile ARN
appInferenceProfileARN := "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/my-profile"
config = AIProvider{
Model: appInferenceProfileARN,
ProviderRegion: "us-east-1",
}
// This will fail in a real environment without mocks, but we're just testing the validation logic
err = client.Configure(&config)
// We expect an error since we can't actually call AWS in tests
assert.NotNil(t, err, "Error should not be nil without AWS mocks")
// Test with an invalid inference profile ARN format
invalidARN := "arn:aws:bedrock:us-east-1:123456789012:invalid-resource/my-profile"
config = AIProvider{
Model: invalidARN,
ProviderRegion: "us-east-1",
}
err = client.Configure(&config)
assert.NotNil(t, err, "Error should not be nil for invalid inference profile ARN format")
}
func TestBedrockGetCompletionInferenceProfile(t *testing.T) {
modelName := "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0"
var inferenceModelModels = []bedrock_support.BedrockModel{
@@ -162,3 +212,54 @@ func TestDefaultModels(t *testing.T) {
assert.NoError(t, err, "Should find the model")
assert.Equal(t, "anthropic.claude-v2", model.Name, "Should find the correct model")
}
func TestValidateInferenceProfileArn(t *testing.T) {
tests := []struct {
name string
arn string
valid bool
}{
{
name: "valid inference profile ARN",
arn: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/my-profile",
valid: true,
},
{
name: "valid application inference profile ARN",
arn: "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/my-profile",
valid: true,
},
{
name: "invalid service in ARN",
arn: "arn:aws:s3:us-east-1:123456789012:inference-profile/my-profile",
valid: false,
},
{
name: "invalid resource type in ARN",
arn: "arn:aws:bedrock:us-east-1:123456789012:model/my-profile",
valid: false,
},
{
name: "malformed ARN",
arn: "arn:aws:bedrock:us-east-1:inference-profile/my-profile",
valid: false,
},
{
name: "not an ARN",
arn: "not-an-arn",
valid: false,
},
{
name: "empty string",
arn: "",
valid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validateInferenceProfileArn(tt.arn)
assert.Equal(t, tt.valid, result, "validateInferenceProfileArn() result should match expected")
})
}
}


@@ -0,0 +1,18 @@
package ai
import (
"context"
"github.com/aws/aws-sdk-go-v2/service/bedrock"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
)
// BedrockManagementAPI defines the interface for Bedrock management operations
type BedrockManagementAPI interface {
GetInferenceProfile(ctx context.Context, params *bedrock.GetInferenceProfileInput, optFns ...func(*bedrock.Options)) (*bedrock.GetInferenceProfileOutput, error)
}
// BedrockRuntimeAPI defines the interface for Bedrock runtime operations
type BedrockRuntimeAPI interface {
InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error)
}