From 78ffa5904addf71caf04554966437b14351f21e5 Mon Sep 17 00:00:00 2001 From: ju187 <tony_chen@discovery.com> Date: Thu, 10 Apr 2025 01:47:58 -0700 Subject: [PATCH] feat: add a naive support of bedrock inference profile (#1446) * feat: add a naive support of bedrock inference profile Signed-off-by: Tony Chen <tony_chen@discovery.com> * feat: improving the tests Signed-off-by: Alex Jones <alexsimonjones@gmail.com> --------- Signed-off-by: Tony Chen <tony_chen@discovery.com> Signed-off-by: Alex Jones <alexsimonjones@gmail.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com> --- pkg/ai/amazonbedrock.go | 393 ++++++++++++++------------ pkg/ai/amazonbedrock_test.go | 131 +++++++++ pkg/ai/bedrock_support/completions.go | 7 +- 3 files changed, 354 insertions(+), 177 deletions(-) create mode 100644 pkg/ai/amazonbedrock_test.go diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go index 5a6f203..c086a78 100644 --- a/pkg/ai/amazonbedrock.go +++ b/pkg/ai/amazonbedrock.go @@ -3,8 +3,11 @@ package ai import ( "context" "errors" - "github.com/aws/aws-sdk-go/service/bedrockruntime/bedrockruntimeiface" + "fmt" "os" + "strings" + + "github.com/aws/aws-sdk-go/service/bedrockruntime/bedrockruntimeiface" "github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support" @@ -24,6 +27,7 @@ type AmazonBedRockClient struct { temperature float32 topP float32 maxTokens int + models []bedrock_support.BedrockModel } // AmazonCompletion BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt) @@ -48,192 +52,200 @@ var BEDROCKER_SUPPORTED_REGION = []string{ AP_South_1, } -var ( - models = []bedrock_support.BedrockModel{ - { - Name: "anthropic.claude-3-5-sonnet-20240620-v1:0", - Completion: &bedrock_support.CohereMessagesCompletion{}, - Response: &bedrock_support.CohereMessagesResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, - Temperature: 0.5, - TopP: 0.9, - ModelName: "anthropic.claude-3-5-sonnet-20240620-v1:0", - }, +var defaultModels = []bedrock_support.BedrockModel{ + { + Name: "anthropic.claude-3-5-sonnet-20240620-v1:0", + Completion: &bedrock_support.CohereMessagesCompletion{}, + Response: &bedrock_support.CohereMessagesResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "anthropic.claude-3-5-sonnet-20240620-v1:0", }, - { - Name: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", - Completion: &bedrock_support.CohereCompletion{}, - Response: &bedrock_support.CohereResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, - Temperature: 0.5, - TopP: 0.9, - ModelName: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", - }, + }, + { + Name: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + Completion: &bedrock_support.CohereCompletion{}, + Response: &bedrock_support.CohereResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", }, - { - Name: "anthropic.claude-v2", - Completion: &bedrock_support.CohereCompletion{}, - Response: &bedrock_support.CohereResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, - Temperature: 0.5, - TopP: 0.9, - ModelName: "anthropic.claude-v2", - }, + }, + { + Name: "anthropic.claude-v2", + Completion: &bedrock_support.CohereCompletion{}, + Response: &bedrock_support.CohereResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "anthropic.claude-v2", }, - { - Name: "anthropic.claude-v1", - Completion: &bedrock_support.CohereCompletion{}, - Response: &bedrock_support.CohereResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, - Temperature: 0.5, - TopP: 0.9, - ModelName: "anthropic.claude-v1", - }, + }, + { + Name: "anthropic.claude-v1", + Completion: &bedrock_support.CohereCompletion{}, + Response: &bedrock_support.CohereResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "anthropic.claude-v1", }, - { - Name: "anthropic.claude-instant-v1", - Completion: &bedrock_support.CohereCompletion{}, - Response: &bedrock_support.CohereResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, - Temperature: 0.5, - TopP: 0.9, - ModelName: "anthropic.claude-instant-v1", - }, + }, + { + Name: "anthropic.claude-instant-v1", + Completion: &bedrock_support.CohereCompletion{}, + Response: &bedrock_support.CohereResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "anthropic.claude-instant-v1", }, - { - Name: "ai21.j2-ultra-v1", - Completion: &bedrock_support.AI21{}, - Response: &bedrock_support.AI21Response{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, - Temperature: 0.5, - TopP: 0.9, - ModelName: "ai21.j2-ultra-v1", - }, + }, + { + Name: "ai21.j2-ultra-v1", + Completion: &bedrock_support.AI21{}, + Response: &bedrock_support.AI21Response{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "ai21.j2-ultra-v1", }, - { - Name: "ai21.j2-jumbo-instruct", - Completion: &bedrock_support.AI21{}, - Response: &bedrock_support.AI21Response{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, - Temperature: 0.5, - TopP: 0.9, - ModelName: "ai21.j2-jumbo-instruct", - }, + }, + { + Name: "ai21.j2-jumbo-instruct", + Completion: &bedrock_support.AI21{}, + Response: &bedrock_support.AI21Response{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "ai21.j2-jumbo-instruct", }, - { - Name: "amazon.titan-text-express-v1", - Completion: &bedrock_support.AmazonCompletion{}, - Response: &bedrock_support.AmazonResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, - Temperature: 0.5, - TopP: 0.9, - ModelName: "amazon.titan-text-express-v1", - }, + }, + { + Name: "amazon.titan-text-express-v1", + Completion: &bedrock_support.AmazonCompletion{}, + Response: &bedrock_support.AmazonResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "amazon.titan-text-express-v1", }, - { - Name: "amazon.nova-pro-v1:0", - Completion: &bedrock_support.AmazonCompletion{}, - Response: &bedrock_support.NovaResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - // https://docs.aws.amazon.com/nova/latest/userguide/getting-started-api.html - MaxTokens: 100, // max of 300k tokens - Temperature: 0.5, - TopP: 0.9, - ModelName: "amazon.nova-pro-v1:0", - }, + }, + { + Name: "amazon.nova-pro-v1:0", + Completion: &bedrock_support.AmazonCompletion{}, + Response: &bedrock_support.NovaResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + // https://docs.aws.amazon.com/nova/latest/userguide/getting-started-api.html + MaxTokens: 100, // max of 300k tokens + Temperature: 0.5, + TopP: 0.9, + ModelName: "amazon.nova-pro-v1:0", }, - { - Name: "eu.amazon.nova-pro-v1:0", - Completion: &bedrock_support.AmazonCompletion{}, - Response: &bedrock_support.NovaResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - // https://docs.aws.amazon.com/nova/latest/userguide/getting-started-api.html - MaxTokens: 100, // max of 300k tokens - Temperature: 0.5, - TopP: 0.9, - ModelName: "eu.wamazon.nova-pro-v1:0", - }, + }, + { + Name: "eu.amazon.nova-pro-v1:0", + Completion: &bedrock_support.AmazonCompletion{}, + Response: &bedrock_support.NovaResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + // https://docs.aws.amazon.com/nova/latest/userguide/getting-started-api.html + MaxTokens: 100, // max of 300k tokens + Temperature: 0.5, + TopP: 0.9, + ModelName: "eu.amazon.nova-pro-v1:0", }, - { - Name: "us.amazon.nova-pro-v1:0", - Completion: &bedrock_support.AmazonCompletion{}, - Response: &bedrock_support.NovaResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - // https://docs.aws.amazon.com/nova/latest/userguide/getting-started-api.html - MaxTokens: 100, // max of 300k tokens - Temperature: 0.5, - TopP: 0.9, - ModelName: "us.amazon.nova-pro-v1:0", - }, + }, + { + Name: "us.amazon.nova-pro-v1:0", + Completion: &bedrock_support.AmazonCompletion{}, + Response: &bedrock_support.NovaResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + // https://docs.aws.amazon.com/nova/latest/userguide/getting-started-api.html + MaxTokens: 100, // max of 300k tokens + Temperature: 0.5, + TopP: 0.9, + ModelName: "us.amazon.nova-pro-v1:0", }, - { - Name: "amazon.nova-lite-v1:0", - Completion: &bedrock_support.AmazonCompletion{}, - Response: &bedrock_support.NovaResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, // max of 300k tokens - Temperature: 0.5, - TopP: 0.9, - ModelName: "amazon.nova-lite-v1:0", - }, + }, + { + Name: "amazon.nova-lite-v1:0", + Completion: &bedrock_support.AmazonCompletion{}, + Response: &bedrock_support.NovaResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, // max of 300k tokens + Temperature: 0.5, + TopP: 0.9, + ModelName: "amazon.nova-lite-v1:0", }, - { - Name: "eu.amazon.nova-lite-v1:0", - Completion: &bedrock_support.AmazonCompletion{}, - Response: &bedrock_support.NovaResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, // max of 300k tokens - Temperature: 0.5, - TopP: 0.9, - ModelName: "eu.amazon.nova-lite-v1:0", - }, + }, + { + Name: "eu.amazon.nova-lite-v1:0", + Completion: &bedrock_support.AmazonCompletion{}, + Response: &bedrock_support.NovaResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, // max of 300k tokens + Temperature: 0.5, + TopP: 0.9, + ModelName: "eu.amazon.nova-lite-v1:0", }, - { - Name: "us.amazon.nova-lite-v1:0", - Completion: &bedrock_support.AmazonCompletion{}, - Response: &bedrock_support.NovaResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, // max of 300k tokens - Temperature: 0.5, - TopP: 0.9, - ModelName: "us.amazon.nova-lite-v1:0", - }, + }, + { + Name: "us.amazon.nova-lite-v1:0", + Completion: &bedrock_support.AmazonCompletion{}, + Response: &bedrock_support.NovaResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, // max of 300k tokens + Temperature: 0.5, + TopP: 0.9, + ModelName: "us.amazon.nova-lite-v1:0", }, - { - Name: "anthropic.claude-3-haiku-20240307-v1:0", - Completion: &bedrock_support.CohereCompletion{}, - Response: &bedrock_support.CohereResponse{}, - Config: bedrock_support.BedrockModelConfig{ - // sensible defaults - MaxTokens: 100, - Temperature: 0.5, - TopP: 0.9, - }, + }, + { + Name: "anthropic.claude-3-haiku-20240307-v1:0", + Completion: &bedrock_support.CohereCompletion{}, + Response: &bedrock_support.CohereResponse{}, + Config: bedrock_support.BedrockModelConfig{ + // sensible defaults + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, }, + }, +} + +// NewAmazonBedRockClient creates a new AmazonBedRockClient with the given models +func NewAmazonBedRockClient(models []bedrock_support.BedrockModel) *AmazonBedRockClient { + if models == nil { + models = defaultModels // Use default models if none provided } -) + return &AmazonBedRockClient{ + models: models, + } +} // GetModelOrDefault check config region func GetRegionOrDefault(region string) string { @@ -254,16 +266,46 @@ func GetRegionOrDefault(region string) string { // Get model from string func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support.BedrockModel, error) { - for _, m := range models { - if model == m.Name { - return &m, nil + if model == "" { + return nil, errors.New("model name cannot be empty") + } + + // Trim spaces from the model name + model = strings.TrimSpace(model) + modelLower := strings.ToLower(model) + + // Try to find an exact match first + for i := range a.models { + if strings.EqualFold(model, a.models[i].Name) || strings.EqualFold(model, a.models[i].Config.ModelName) { + // Create a copy to avoid returning a pointer to a loop variable + modelCopy := a.models[i] + return &modelCopy, nil } } - return nil, errors.New("model not found") + + // If no exact match, try partial match + for i := range a.models { + modelNameLower := strings.ToLower(a.models[i].Name) + modelConfigNameLower := strings.ToLower(a.models[i].Config.ModelName) + + // Check if the input string contains the model name or vice versa + if strings.Contains(modelNameLower, modelLower) || strings.Contains(modelLower, modelNameLower) || + strings.Contains(modelConfigNameLower, modelLower) || strings.Contains(modelLower, modelConfigNameLower) { + // Create a copy to avoid returning a pointer to a loop variable + modelCopy := a.models[i] + return &modelCopy, nil + } + } + + return nil, fmt.Errorf("model '%s' not found in supported models", model) } // Configure configures the AmazonBedRockClient with the provided configuration. func (a *AmazonBedRockClient) Configure(config IAIConfig) error { + // Initialize models if not already initialized + if a.models == nil { + a.models = defaultModels + } // Create a new AWS session providerRegion := GetRegionOrDefault(config.GetProviderRegion()) @@ -280,7 +322,6 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error { if err != nil { return err } - // TODO: Override the completion config somehow // Create a new BedrockRuntime client a.client = bedrockruntime.New(sess) diff --git a/pkg/ai/amazonbedrock_test.go b/pkg/ai/amazonbedrock_test.go new file mode 100644 index 0000000..cb39373 --- /dev/null +++ b/pkg/ai/amazonbedrock_test.go @@ -0,0 +1,131 @@ +package ai + +import ( + "testing" + + "github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support" + "github.com/stretchr/testify/assert" +) + +// Test models for unit testing +var testModels = []bedrock_support.BedrockModel{ + { + Name: "anthropic.claude-3-5-sonnet-20240620-v1:0", + Completion: &bedrock_support.CohereMessagesCompletion{}, + Response: &bedrock_support.CohereMessagesResponse{}, + Config: bedrock_support.BedrockModelConfig{ + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "anthropic.claude-3-5-sonnet-20240620-v1:0", + }, + }, + { + Name: "anthropic.claude-3-5-sonnet-20241022-v2:0", + Completion: &bedrock_support.CohereCompletion{}, + Response: &bedrock_support.CohereResponse{}, + Config: bedrock_support.BedrockModelConfig{ + MaxTokens: 100, + Temperature: 0.5, + TopP: 0.9, + ModelName: "anthropic.claude-3-5-sonnet-20241022-v2:0", + }, + }, +} + +func TestBedrockModelConfig(t *testing.T) { + client := &AmazonBedRockClient{models: testModels} + + foundModel, err := client.getModelFromString("arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0") + assert.Nil(t, err, "Error should be nil") + assert.Equal(t, foundModel.Config.MaxTokens, 100) + assert.Equal(t, foundModel.Config.Temperature, float32(0.5)) + assert.Equal(t, foundModel.Config.TopP, float32(0.9)) + assert.Equal(t, foundModel.Config.ModelName, "anthropic.claude-3-5-sonnet-20240620-v1:0") +} + +func TestGetModelFromString(t *testing.T) { + client := &AmazonBedRockClient{models: testModels} + + tests := []struct { + name string + model string + wantModel string + wantErr bool + }{ + { + name: "exact model name match", + model: "anthropic.claude-3-5-sonnet-20240620-v1:0", + wantModel: "anthropic.claude-3-5-sonnet-20240620-v1:0", + wantErr: false, + }, + { + name: "partial model name match", + model: "claude-3-5-sonnet", + wantModel: "anthropic.claude-3-5-sonnet-20240620-v1:0", + wantErr: false, + }, + { + name: "model name with different version", + model: "anthropic.claude-3-5-sonnet-20241022-v2:0", + wantModel: "anthropic.claude-3-5-sonnet-20241022-v2:0", + wantErr: false, + }, + { + name: "non-existent model", + model: "non-existent-model", + wantModel: "", + wantErr: true, + }, + { + name: "empty model name", + model: "", + wantModel: "", + wantErr: true, + }, + { + name: "model name with extra spaces", + model: " anthropic.claude-3-5-sonnet-20240620-v1:0 ", + wantModel: "anthropic.claude-3-5-sonnet-20240620-v1:0", + wantErr: false, + }, + { + name: "case insensitive match", + model: "ANTHROPIC.CLAUDE-3-5-SONNET-20240620-V1:0", + wantModel: "anthropic.claude-3-5-sonnet-20240620-v1:0", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotModel, err := client.getModelFromString(tt.model) + if (err != nil) != tt.wantErr { + t.Errorf("getModelFromString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && gotModel.Name != tt.wantModel { + t.Errorf("getModelFromString() = %v, want %v", gotModel.Name, tt.wantModel) + } + }) + } +} + +// TestDefaultModels tests that the client works with default models +func TestDefaultModels(t *testing.T) { + client := &AmazonBedRockClient{} + + // Configure should initialize default models + err := client.Configure(&AIProvider{ + Model: "anthropic.claude-v2", + }) + + assert.NoError(t, err, "Configure should not return an error") + assert.NotNil(t, client.models, "Models should be initialized") + assert.NotEmpty(t, client.models, "Models should not be empty") + + // Test finding a default model + model, err := client.getModelFromString("anthropic.claude-v2") + assert.NoError(t, err, "Should find the model") + assert.Equal(t, "anthropic.claude-v2", model.Name, "Should find the correct model") +} diff --git a/pkg/ai/bedrock_support/completions.go b/pkg/ai/bedrock_support/completions.go index ae1ca29..1484304 100644 --- a/pkg/ai/bedrock_support/completions.go +++ b/pkg/ai/bedrock_support/completions.go @@ -17,7 +17,12 @@ var SUPPPORTED_BEDROCK_MODELS = []string{ "ai21.j2-jumbo-instruct", "amazon.titan-text-express-v1", "amazon.nova-pro-v1:0", + "eu.amazon.nova-pro-v1:0", + "us.amazon.nova-pro-v1:0", + "amazon.nova-lite-v1:0", "eu.amazon.nova-lite-v1:0", + "us.amazon.nova-lite-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", } type ICompletion interface { @@ -91,7 +96,7 @@ type AmazonCompletion struct { func isModelSupported(modelName string) bool { for _, supportedModel := range SUPPPORTED_BEDROCK_MODELS { - if modelName == supportedModel { + if strings.Contains(modelName, supportedModel) { return true } }