diff --git a/README.md b/README.md index 70f9b903..de677b47 100644 --- a/README.md +++ b/README.md @@ -500,6 +500,21 @@ k8sgpt auth default -p azureopenai Default provider set to azureopenai ``` +_Using Amazon Bedrock Converse with inference profiles_ + +_System Inference Profile_ + +``` +k8sgpt auth add --backend amazonbedrockconverse --providerRegion us-east-1 --model arn:aws:bedrock:us-east-1:123456789012:inference-profile/my-inference-profile + +``` + +_Application Inference Profile_ + +``` +k8sgpt auth add --backend amazonbedrockconverse --providerRegion us-east-1 --model arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/2uzp4s0w39t6 + +``` _Using Amazon Bedrock with inference profiles_ _System Inference Profile_ diff --git a/SUPPORTED_MODELS.md b/SUPPORTED_MODELS.md index aff895ab..19bd20c5 100644 --- a/SUPPORTED_MODELS.md +++ b/SUPPORTED_MODELS.md @@ -24,6 +24,9 @@ K8sGPT supports a variety of AI/LLM providers (backends). Some providers have a ### Cohere - **Model:** User-configurable (any model supported by Cohere) +### Amazon Bedrock Converse +- **Model:** User-configurable (any model supported by [Amazon Bedrock Converse](https://docs.aws.amazon.com/bedrock/latest/userguide/models-api-compatibility.html)) + ### Amazon Bedrock - **Supported Models:** - anthropic.claude-sonnet-4-20250514-v1:0 @@ -80,4 +83,4 @@ K8sGPT supports a variety of AI/LLM providers (backends). Some providers have a --- -For more details on configuring each provider and model, refer to the official K8sGPT documentation and the provider's own documentation. \ No newline at end of file +For more details on configuring each provider and model, refer to the official K8sGPT documentation and the provider's own documentation. 
diff --git a/cmd/auth/add.go b/cmd/auth/add.go index 415c4571..8564a924 100644 --- a/cmd/auth/add.go +++ b/cmd/auth/add.go @@ -48,6 +48,9 @@ var addCmd = &cobra.Command{ if strings.ToLower(backend) == "amazonbedrock" { _ = cmd.MarkFlagRequired("providerRegion") } + if strings.ToLower(backend) == "amazonbedrockconverse" { + _ = cmd.MarkFlagRequired("providerRegion") + } if strings.ToLower(backend) == "ibmwatsonxai" { _ = cmd.MarkFlagRequired("providerId") } @@ -140,6 +143,7 @@ var addCmd = &cobra.Command{ TopP: topP, TopK: topK, MaxTokens: maxTokens, + StopSequences: stopSequences, OrganizationId: organizationId, } @@ -173,12 +177,14 @@ func init() { addCmd.Flags().Int32VarP(&topK, "topk", "c", 50, "Sampling Cutoff: Set a threshold (1-100) to restrict the sampling process to the top K most probable words at each step. Higher values lead to greater variability, lower values increases predictability.") // max tokens addCmd.Flags().IntVarP(&maxTokens, "maxtokens", "l", 2048, "Specify a maximum output length. Adjust (1-...) to control text length. 
Higher values produce longer output, lower values limit length") + // stop sequences + addCmd.Flags().StringSliceVarP(&stopSequences, "stopsequences", "s", []string{}, "Stop Sequences: Define specific tokens or phrases that signal the model to stop generating text.") // add flag for temperature addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)") // add flag for azure open ai engine/deployment name addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)") //add flag for amazonbedrock region name - addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, googlevertexai backend)") + addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, amazonbedrockconverse, googlevertexai backend)") //add flag for vertexAI/WatsonxAI Project ID addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. 
project (only for googlevertexai/ibmwatsonxai backend)") //add flag for OCI Compartment ID diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index c8f4e209..2ea3c12e 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -32,6 +32,7 @@ var ( topP float32 topK int32 maxTokens int + stopSequences []string organizationId string )
diff --git a/pkg/ai/amazonbedrockconverse.go b/pkg/ai/amazonbedrockconverse.go
new file mode 100644
index 00000000..651e0881
--- /dev/null
+++ b/pkg/ai/amazonbedrockconverse.go
@@ -0,0 +1,204 @@
+package ai
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"math"
+	"os"
+	"strings"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	awsconfig "github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
+	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
+)
+
+// amazonBedrockConverseClientName is the backend name reported by GetName.
+const amazonBedrockConverseClientName = "amazonbedrockconverse"
+
+// bedrockConverseAPI is the single Bedrock runtime operation this client
+// calls; it is an interface so tests can inject a mock implementation.
+type bedrockConverseAPI interface {
+	Converse(ctx context.Context, input *bedrockruntime.ConverseInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.ConverseOutput, error)
+}
+
+// AmazonBedrockConverseClient is an AI backend that sends prompts to Amazon
+// Bedrock through the Converse API.
+type AmazonBedrockConverseClient struct {
+	nopCloser
+
+	client        bedrockConverseAPI
+	model         string
+	temperature   float32
+	topP          float32
+	maxTokens     int
+	stopSequences []string
+}
+
+// getRegion returns the AWS region to use. A non-empty AWS_DEFAULT_REGION
+// environment variable takes precedence over the supplied provider region.
+func getRegion(region string) string {
+	if env := os.Getenv("AWS_DEFAULT_REGION"); env != "" {
+		return env
+	}
+	return region
+}
+
+// getModelFromString returns model with surrounding whitespace removed, or an
+// error when nothing is left after trimming.
+func (a *AmazonBedrockConverseClient) getModelFromString(model string) (string, error) {
+	// Trim before validating so a whitespace-only value is rejected as well.
+	model = strings.TrimSpace(model)
+	if model == "" {
+		return "", errors.New("model name cannot be empty")
+	}
+	return model, nil
+}
+
+// processError converts an error from the Converse call into a message with
+// troubleshooting hints, selected by substring match on the error text. A nil
+// err yields nil.
+func processError(err error, modelId string) error {
+	if err == nil {
+		return nil
+	}
+	errMsg := err.Error()
+	switch {
+	case strings.Contains(errMsg, "no such host"):
+		return errors.New("the bedrock service is not available in the selected region. " +
+			"please double-check the service availability for your region at " +
+			"https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/")
+	case strings.Contains(errMsg, "Could not resolve the foundation model"):
+		// %q quotes the identifier; the earlier raw string literal put literal
+		// backslashes (\"...\") into the message.
+		return fmt.Errorf("could not resolve the foundation model from model identifier: %q. "+
+			"please verify that the requested model exists and is accessible "+
+			"within the specified region", modelId)
+	default:
+		// %w keeps the underlying error available to errors.Is and errors.As.
+		return fmt.Errorf("could not invoke model: %q. here is why: %w", modelId, err)
+	}
+}
+
+// Configure validates the model identifier, creates the Bedrock runtime client
+// unless one has already been injected, and copies the inference parameters
+// from config.
+func (a *AmazonBedrockConverseClient) Configure(config IAIConfig) error {
+	modelInput := config.GetModel()
+
+	// Validate the model first so a bad configuration fails before any AWS
+	// configuration is loaded.
+	foundModel, err := a.getModelFromString(modelInput)
+	if err != nil {
+		return fmt.Errorf("failed to find model configuration for %s: %w", modelInput, err)
+	}
+
+	// Only create AWS clients if they haven't been injected (for testing)
+	if a.client == nil {
+		region := getRegion(config.GetProviderRegion())
+		cfg, loadErr := awsconfig.LoadDefaultConfig(context.Background(),
+			awsconfig.WithRegion(region),
+		)
+		if loadErr != nil {
+			msg := loadErr.Error()
+			if strings.Contains(msg, "InvalidAccessKeyId") || strings.Contains(msg, "SignatureDoesNotMatch") || strings.Contains(msg, "NoCredentialProviders") {
+				return fmt.Errorf("aws credentials are invalid or missing. Please check your environment variables or aws config. details: %w", loadErr)
+			}
+			return fmt.Errorf("failed to load aws config for region %s: %w", region, loadErr)
+		}
+
+		a.client = bedrockruntime.NewFromConfig(cfg)
+	}
+
+	a.model = foundModel
+
+	// Set common configuration parameters
+	a.temperature = config.GetTemperature()
+	a.topP = config.GetTopP()
+	a.maxTokens = config.GetMaxTokens()
+	a.stopSequences = config.GetStopSequences()
+
+	return nil
+}
+
+// extractTextFromConverseOutput concatenates every text block of a Converse
+// message reply. It returns an error when output is nil, is not a message, or
+// contains no text.
+func extractTextFromConverseOutput(output types.ConverseOutput, modelId string) (string, error) {
+	if output == nil {
+		return "", fmt.Errorf("empty response from model: %s", modelId)
+	}
+
+	msg, ok := output.(*types.ConverseOutputMemberMessage)
+	if !ok || msg == nil {
+		return "", fmt.Errorf("unexpected response type from model: %s", modelId)
+	}
+
+	if len(msg.Value.Content) == 0 {
+		return "", fmt.Errorf("no content returned from model: %s", modelId)
+	}
+
+	var builder strings.Builder
+	for _, block := range msg.Value.Content {
+		// Non-text blocks are skipped.
+		if textBlock, ok := block.(*types.ContentBlockMemberText); ok && textBlock != nil {
+			builder.WriteString(textBlock.Value)
+		}
+	}
+
+	if builder.Len() == 0 {
+		return "", fmt.Errorf("no text content returned from model: %s", modelId)
+	}
+
+	return builder.String(), nil
+}
+
+// GetCompletion sends prompt as a single user message through the Converse API
+// and returns the concatenated text of the reply.
+func (a *AmazonBedrockConverseClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
+	if a.client == nil {
+		return "", errors.New("amazon bedrock converse client is not configured")
+	}
+
+	inference := &types.InferenceConfiguration{
+		Temperature:   aws.Float32(a.temperature),
+		TopP:          aws.Float32(a.topP),
+		StopSequences: a.stopSequences,
+	}
+	// Bedrock documents a minimum of 1 for maxTokens, so a non-positive value is
+	// left unset and the model default applies; large values are clamped so the
+	// int32 conversion cannot wrap.
+	if a.maxTokens > 0 {
+		limit := a.maxTokens
+		if limit > math.MaxInt32 {
+			limit = math.MaxInt32
+		}
+		inference.MaxTokens = aws.Int32(int32(limit))
+	}
+
+	message := types.Message{
+		Role:    types.ConversationRoleUser,
+		Content: []types.ContentBlock{&types.ContentBlockMemberText{Value: prompt}},
+	}
+	converseInput := &bedrockruntime.ConverseInput{
+		ModelId:         aws.String(a.model),
+		Messages:        []types.Message{message},
+		InferenceConfig: inference,
+	}
+
+	response, err := a.client.Converse(ctx, converseInput)
+	if err != nil {
+		return "", processError(err, a.model)
+	}
+	if response == nil {
+		return "", fmt.Errorf("empty response from model: %s", a.model)
+	}
+
+	return extractTextFromConverseOutput(response.Output, a.model)
+}
+
+// GetName returns the backend name of this client.
+func (a *AmazonBedrockConverseClient) GetName() string {
+	return amazonBedrockConverseClientName
+}
diff --git a/pkg/ai/amazonbedrockconverse_mock_test.go b/pkg/ai/amazonbedrockconverse_mock_test.go new file mode 100644 index 00000000..98164d02 --- /dev/null +++ b/pkg/ai/amazonbedrockconverse_mock_test.go @@ -0,0 +1,250 @@ +package ai + +import ( + "context" + "errors" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/stretchr/testify/assert" + "testing" +) + +// ---- Mock Wrapper ---- +type mockConverseClient struct { + converseFunc func(ctx context.Context, input *bedrockruntime.ConverseInput) (*bedrockruntime.ConverseOutput, error) +} + +func (m *mockConverseClient) Converse(ctx context.Context, input *bedrockruntime.ConverseInput, _ ...func(*bedrockruntime.Options)) (*bedrockruntime.ConverseOutput, error) { + return m.converseFunc(ctx, input) +} + +// ---- Tests ---- +func TestGetCompletion_Success(t *testing.T) { + mock := &mockConverseClient{ + converseFunc: func(ctx context.Context, input *bedrockruntime.ConverseInput) (*bedrockruntime.ConverseOutput, error) { + return &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{ + Value: "mock response", + }, + }, + }, + }, + }, nil + }, + } + + client := &AmazonBedrockConverseClient{ + client: mock, + model: "test-model", + } + + result, err := client.GetCompletion(context.Background(), "hello") + + assert.NoError(t, err) + assert.Equal(t, "mock response", result) +} + +func TestGetCompletion_Error(t *testing.T) { + mock := &mockConverseClient{ + converseFunc: func(ctx context.Context,
input *bedrockruntime.ConverseInput) (*bedrockruntime.ConverseOutput, error) { + return nil, errors.New("some error") + }, + } + + client := &AmazonBedrockConverseClient{ + client: mock, + model: "test-model", + } + + _, err := client.GetCompletion(context.Background(), "hello") + + assert.Error(t, err) +} + +func TestConfigure_WithInjectedClient(t *testing.T) { + mock := &mockConverseClient{} + + cfg := &AIProvider{ + Model: "test-model", + ProviderRegion: "us-west-2", + Temperature: 0.5, + TopP: 0.9, + MaxTokens: 100, + StopSequences: []string{"stop"}, + } + + client := &AmazonBedrockConverseClient{ + client: mock, + } + + err := client.Configure(cfg) + + assert.NoError(t, err) + assert.Equal(t, "test-model", client.model) + assert.Equal(t, float32(0.5), client.temperature) + assert.Equal(t, float32(0.9), client.topP) + assert.Equal(t, 100, client.maxTokens) + assert.Equal(t, []string{"stop"}, client.stopSequences) +} + +func TestConfigure_InvalidModel(t *testing.T) { + mock := &mockConverseClient{} + + cfg := &AIProvider{ + Model: "", + } + + client := &AmazonBedrockConverseClient{ + client: mock, + } + + err := client.Configure(cfg) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "model name cannot be empty") +} + +func TestGetRegion(t *testing.T) { + t.Run("uses provided region when env not set", func(t *testing.T) { + t.Setenv("AWS_DEFAULT_REGION", "") + + result := getRegion("us-west-2") + assert.Equal(t, "us-west-2", result) + }) + + t.Run("env overrides provided region", func(t *testing.T) { + t.Setenv("AWS_DEFAULT_REGION", "us-east-1") + + result := getRegion("us-west-2") + assert.Equal(t, "us-east-1", result) + }) +} + +func TestProcessError(t *testing.T) { + tests := []struct { + name string + err error + modelId string + contains string + }{ + { + name: "no such host", + err: errors.New("dial tcp: no such host"), + modelId: "test-model", + contains: "bedrock service is not available", + }, + { + name: "model not found", + err: 
errors.New("Could not resolve the foundation model"), + modelId: "test-model", + contains: "could not resolve the foundation model", + }, + { + name: "generic error", + err: errors.New("something else"), + modelId: "test-model", + contains: "could not invoke model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := processError(tt.err, tt.modelId) + assert.Contains(t, result.Error(), tt.contains) + }) + } +} + +func TestExtractTextFromConverseOutput(t *testing.T) { + tests := []struct { + name string + output types.ConverseOutput + expectError bool + expected string + }{ + { + name: "nil output", + output: nil, + expectError: true, + }, + { + name: "empty content", + output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{}, + }, + }, + expectError: true, + }, + { + name: "single text block", + output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{Value: "hello"}, + }, + }, + }, + expected: "hello", + }, + { + name: "multiple text blocks", + output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{Value: "hello "}, + &types.ContentBlockMemberText{Value: "world"}, + }, + }, + }, + expected: "hello world", + }, + { + name: "mixed content blocks", + output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{Value: "hello"}, + // simulate non-text block + &types.ContentBlockMemberImage{}, + &types.ContentBlockMemberText{Value: " world"}, + }, + }, + }, + expected: "hello world", + }, + { + name: "no text blocks", + output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberImage{}, + }, + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { 
+ result, err := extractTextFromConverseOutput(tt.output, "test-model") + + if tt.expectError { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGetName(t *testing.T) { + client := &AmazonBedrockConverseClient{} + assert.Equal(t, "amazonbedrockconverse", client.GetName()) +} diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 243fa546..e3fb6c73 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -27,6 +27,7 @@ var ( &NoOpAIClient{}, &CohereClient{}, &AmazonBedRockClient{}, + &AmazonBedrockConverseClient{}, &SageMakerAIClient{}, &GoogleGenAIClient{}, &HuggingfaceClient{}, @@ -43,6 +44,7 @@ var ( azureAIClientName, cohereAIClientName, amazonbedrockAIClientName, + amazonBedrockConverseClientName, amazonsagemakerAIClientName, googleAIClientName, noopAIClientName, @@ -85,6 +87,7 @@ type IAIConfig interface { GetTopP() float32 GetTopK() int32 GetMaxTokens() int + GetStopSequences() []string GetProviderId() string GetCompartmentId() string GetOrganizationId() string @@ -122,6 +125,7 @@ type AIProvider struct { TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"` TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"` MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"` + StopSequences []string `mapstructure:"stopsequences" yaml:"stopsequences,omitempty"` OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"` CustomHeaders []http.Header `mapstructure:"customHeaders"` } @@ -150,6 +154,10 @@ func (p *AIProvider) GetMaxTokens() int { return p.MaxTokens } +func (p *AIProvider) GetStopSequences() []string { + return p.StopSequences +} + func (p *AIProvider) GetPassword() string { return p.Password } @@ -185,7 +193,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header { return p.CustomHeaders } -var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "customrest"} +var passwordlessProviders = 
[]string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "amazonbedrockconverse", "googlevertexai", "oci", "customrest"} func NeedPassword(backend string) bool { for _, b := range passwordlessProviders {
diff --git a/pkg/ai/openai_header_transport_test.go b/pkg/ai/openai_header_transport_test.go
index 9d43f463..d64e77fd 100644
--- a/pkg/ai/openai_header_transport_test.go
+++ b/pkg/ai/openai_header_transport_test.go
@@ -61,6 +61,12 @@ func (m *mockConfig) GetMaxTokens() int {
 	return 0
 }
 
+// GetStopSequences returns no stop sequences, in line with the zero values
+// returned by the neighbouring mockConfig getters.
+func (m *mockConfig) GetStopSequences() []string {
+	return nil
+}
+
 func (m *mockConfig) GetEndpointName() string {
 	return ""
 }