diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go
index ad048b5..30f3b11 100644
--- a/pkg/ai/amazonbedrock.go
+++ b/pkg/ai/amazonbedrock.go
@@ -2,8 +2,8 @@ package ai
 
 import (
 	"context"
-	"encoding/json"
-	"fmt"
+	"errors"
+	"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
 	"os"
 
 	"github.com/aws/aws-sdk-go/aws"
@@ -13,18 +13,18 @@ import (
 
 const amazonbedrockAIClientName = "amazonbedrock"
 
 // AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service.
 type AmazonBedRockClient struct {
 	nopCloser
 
 	client      *bedrockruntime.BedrockRuntime
-	model       string
+	model       *bedrock_support.BedrockModel
 	temperature float32
 	topP        float32
 	maxTokens   int
 }
 
 // Amazon BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
 // https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html#bedrock-regions
 const BEDROCK_DEFAULT_REGION = "us-east-1" // default use us-east-1 region
 
@@ -44,41 +44,98 @@ var BEDROCKER_SUPPORTED_REGION = []string{
 	EU_Central_1,
 }
 
-const (
-	ModelAnthropicClaudeSonnetV3_5    = "anthropic.claude-3-5-sonnet-20240620-v1:0"
-	ModelAnthropicClaudeSonnetV3_5_V2 = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
-	ModelAnthropicClaudeV2            = "anthropic.claude-v2"
-	ModelAnthropicClaudeV1            = "anthropic.claude-v1"
-	ModelAnthropicClaudeInstantV1     = "anthropic.claude-instant-v1"
-	ModelA21J2UltraV1                 = "ai21.j2-ultra-v1"
-	ModelA21J2JumboInstruct           = "ai21.j2-jumbo-instruct"
-	ModelAmazonTitanExpressV1         = "amazon.titan-text-express-v1"
-)
-
-var BEDROCK_MODELS = []string{
-	ModelAnthropicClaudeV2,
-	ModelAnthropicClaudeV1,
-	ModelAnthropicClaudeInstantV1,
-	ModelA21J2UltraV1,
-	ModelA21J2JumboInstruct,
-	ModelAmazonTitanExpressV1,
-}
-
-//const TOPP = 0.9 moved to config
-
-// GetModelOrDefault check config model
-func GetModelOrDefault(model string) string {
-
-	// Check if the provided model is in the list
-	for _, m := range BEDROCK_MODELS {
-		if m == model {
-			return model // Return the provided model
-		}
+var (
+	models = []bedrock_support.BedrockModel{
+		{
+			Name:       "anthropic.claude-3-5-sonnet-20240620-v1:0",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "anthropic.claude-v2",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "anthropic.claude-v1",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "anthropic.claude-instant-v1",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "ai21.j2-ultra-v1",
+			Completion: &bedrock_support.AI21{},
+			Response:   &bedrock_support.AI21Response{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "ai21.j2-jumbo-instruct",
+			Completion: &bedrock_support.AI21{},
+			Response:   &bedrock_support.AI21Response{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "amazon.titan-text-express-v1",
+			Completion: &bedrock_support.AmazonCompletion{},
+			Response:   &bedrock_support.AmazonResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
 	}
-
-	// Return the default model if the provided model is not in the list
-	return BEDROCK_MODELS[0]
-}
+)
 
 // GetModelOrDefault check config region
 func GetRegionOrDefault(region string) string {
@@ -97,6 +154,16 @@ func GetRegionOrDefault(region string) string {
 	return BEDROCK_DEFAULT_REGION
 }
 
+// getModelFromString looks up the named model in the registry.
+func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support.BedrockModel, error) {
+	for _, m := range models {
+		if model == m.Name {
+			return &m, nil
+		}
+	}
+	return nil, errors.New("model not found")
+}
+
 // Configure configures the AmazonBedRockClient with the provided configuration.
 func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 
@@ -111,9 +178,15 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 		return err
 	}
 
+	foundModel, err := a.getModelFromString(config.GetModel())
+	if err != nil {
+		return err
+	}
+	// TODO: Override the completion config somehow
+
 	// Create a new BedrockRuntime client
 	a.client = bedrockruntime.New(sess)
-	a.model = GetModelOrDefault(config.GetModel())
+	a.model = foundModel
 	a.temperature = config.GetTemperature()
 	a.topP = config.GetTopP()
 	a.maxTokens = config.GetMaxTokens()
@@ -124,45 +197,19 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 
 // GetCompletion sends a request to the model for generating completion based on the provided prompt.
 func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
-	// Prepare the input data for the model invocation based on the model & the Response Body per model as well.
-	var request map[string]interface{}
-	switch a.model {
-	case ModelAnthropicClaudeSonnetV3_5, ModelAnthropicClaudeSonnetV3_5_V2, ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
-		request = map[string]interface{}{
-			"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
-			"max_tokens_to_sample": a.maxTokens,
-			"temperature":          a.temperature,
-			"top_p":                a.topP,
-		}
-	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
-		request = map[string]interface{}{
-			"prompt":      prompt,
-			"maxTokens":   a.maxTokens,
-			"temperature": a.temperature,
-			"topP":        a.topP,
-		}
-	case ModelAmazonTitanExpressV1:
-		request = map[string]interface{}{
-			"inputText": fmt.Sprintf("\n\nUser: %s", prompt),
-			"textGenerationConfig": map[string]interface{}{
-				"maxTokenCount": a.maxTokens,
-				"temperature":   a.temperature,
-				"topP":          a.topP,
-			},
-		}
-	default:
-		return "", fmt.Errorf("model %s not supported", a.model)
-	}
+	// override config defaults
+	a.model.Config.MaxTokens = a.maxTokens
+	a.model.Config.Temperature = a.temperature
+	a.model.Config.TopP = a.topP
 
-	body, err := json.Marshal(request)
+	body, err := a.model.Completion.GetCompletion(ctx, prompt, a.model.Config)
 	if err != nil {
 		return "", err
 	}
 
-	// Build the parameters for the model invocation
 	params := &bedrockruntime.InvokeModelInput{
 		Body:        body,
-		ModelId:     aws.String(a.model),
+		ModelId:     aws.String(a.model.Name),
 		ContentType: aws.String("application/json"),
 		Accept:      aws.String("application/json"),
 	}
@@ -173,54 +220,9 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
 		return "", err
 	}
 
-	// Response type changes as per model
-	switch a.model {
-	case ModelAnthropicClaudeSonnetV3_5, ModelAnthropicClaudeSonnetV3_5_V2, ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
-		type InvokeModelResponseBody struct {
-			Completion  string `json:"completion"`
-			Stop_reason string `json:"stop_reason"`
-		}
-		output := &InvokeModelResponseBody{}
-		err = json.Unmarshal(resp.Body, output)
-		if err != nil {
-			return "", err
-		}
-		return output.Completion, nil
-	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
-		type Data struct {
-			Text string `json:"text"`
-		}
-		type Completion struct {
-			Data Data `json:"data"`
-		}
-		type InvokeModelResponseBody struct {
-			Completions []Completion `json:"completions"`
-		}
-		output := &InvokeModelResponseBody{}
-		err = json.Unmarshal(resp.Body, output)
-		if err != nil {
-			return "", err
-		}
-		return output.Completions[0].Data.Text, nil
-	case ModelAmazonTitanExpressV1:
-		type Result struct {
-			TokenCount       int    `json:"tokenCount"`
-			OutputText       string `json:"outputText"`
-			CompletionReason string `json:"completionReason"`
-		}
-		type InvokeModelResponseBody struct {
-			InputTextTokenCount int      `json:"inputTextTokenCount"`
-			Results             []Result `json:"results"`
-		}
-		output := &InvokeModelResponseBody{}
-		err = json.Unmarshal(resp.Body, output)
-		if err != nil {
-			return "", err
-		}
-		return output.Results[0].OutputText, nil
-	default:
-		return "", fmt.Errorf("model %s not supported", a.model)
-	}
+	// Parse the response
+	return a.model.Response.ParseResponse(resp.Body)
+
 }
 
 // GetName returns the name of the AmazonBedRockClient.
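Unlike the removed GetModelOrDefault, which silently fell back to the first entry in BEDROCK_MODELS, getModelFromString fails fast, so a typo in the configured model name now surfaces as an error at Configure time. A minimal test sketch of that behavior (not part of this diff; it assumes a _test.go file alongside the client in package ai):

```go
package ai

import "testing"

func TestGetModelFromString(t *testing.T) {
	client := &AmazonBedRockClient{}

	// A registered name resolves to its BedrockModel entry.
	m, err := client.getModelFromString("anthropic.claude-v2")
	if err != nil {
		t.Fatalf("expected model to be found: %v", err)
	}
	if m.Name != "anthropic.claude-v2" {
		t.Errorf("got %q, want %q", m.Name, "anthropic.claude-v2")
	}

	// Unknown names now fail fast instead of silently defaulting.
	if _, err := client.getModelFromString("not-a-model"); err == nil {
		t.Error("expected an error for an unsupported model")
	}
}
```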
diff --git a/pkg/ai/bedrock_support/completions.go b/pkg/ai/bedrock_support/completions.go
new file mode 100644
index 0000000..400658d
--- /dev/null
+++ b/pkg/ai/bedrock_support/completions.go
@@ -0,0 +1,67 @@
+package bedrock_support
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+)
+
+type ICompletion interface {
+	GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error)
+}
+
+type CohereCompletion struct {
+	completion ICompletion
+}
+
+func (a *CohereCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
+	request := map[string]interface{}{
+		"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
+		"max_tokens_to_sample": modelConfig.MaxTokens,
+		"temperature":          modelConfig.Temperature,
+		"top_p":                modelConfig.TopP,
+	}
+	body, err := json.Marshal(request)
+	if err != nil {
+		return []byte{}, err
+	}
+	return body, nil
+}
+
+type AI21 struct {
+	completion ICompletion
+}
+
+func (a *AI21) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
+	request := map[string]interface{}{
+		"prompt":      prompt,
+		"maxTokens":   modelConfig.MaxTokens,
+		"temperature": modelConfig.Temperature,
+		"topP":        modelConfig.TopP,
+	}
+	body, err := json.Marshal(request)
+	if err != nil {
+		return []byte{}, err
+	}
+	return body, nil
+}
+
+type AmazonCompletion struct {
+	completion ICompletion
+}
+
+func (a *AmazonCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
+	request := map[string]interface{}{
+		"inputText": fmt.Sprintf("\n\nUser: %s", prompt),
+		"textGenerationConfig": map[string]interface{}{
+			"maxTokenCount": modelConfig.MaxTokens,
+			"temperature":   modelConfig.Temperature,
+			"topP":          modelConfig.TopP,
+		},
+	}
+	body, err := json.Marshal(request)
+	if err != nil {
+		return []byte{}, err
+	}
+	return body, nil
+}
diff --git a/pkg/ai/bedrock_support/model.go b/pkg/ai/bedrock_support/model.go
new file mode 100644
index 0000000..dacfae7
--- /dev/null
+++ b/pkg/ai/bedrock_support/model.go
@@ -0,0 +1,13 @@
+package bedrock_support
+
+type BedrockModelConfig struct {
+	MaxTokens   int
+	Temperature float32
+	TopP        float32
+}
+type BedrockModel struct {
+	Name       string
+	Completion ICompletion
+	Response   IResponse
+	Config     BedrockModelConfig
+}
diff --git a/pkg/ai/bedrock_support/responses.go b/pkg/ai/bedrock_support/responses.go
new file mode 100644
index 0000000..3300e3e
--- /dev/null
+++ b/pkg/ai/bedrock_support/responses.go
@@ -0,0 +1,79 @@
+package bedrock_support
+
+import (
+	"encoding/json"
+	"errors"
+)
+
+type IResponse interface {
+	ParseResponse(rawResponse []byte) (string, error)
+}
+
+type CohereResponse struct {
+	response IResponse
+}
+
+func (a *CohereResponse) ParseResponse(rawResponse []byte) (string, error) {
+	type InvokeModelResponseBody struct {
+		Completion  string `json:"completion"`
+		Stop_reason string `json:"stop_reason"`
+	}
+	output := &InvokeModelResponseBody{}
+	err := json.Unmarshal(rawResponse, output)
+	if err != nil {
+		return "", err
+	}
+	return output.Completion, nil
+}
+
+type AI21Response struct {
+	response IResponse
+}
+
+func (a *AI21Response) ParseResponse(rawResponse []byte) (string, error) {
+	type Data struct {
+		Text string `json:"text"`
+	}
+	type Completion struct {
+		Data Data `json:"data"`
+	}
+	type InvokeModelResponseBody struct {
+		Completions []Completion `json:"completions"`
+	}
+	output := &InvokeModelResponseBody{}
+	err := json.Unmarshal(rawResponse, output)
+	if err != nil {
+		return "", err
+	}
+	// guard against an empty completions array to avoid an index-out-of-range panic
+	if len(output.Completions) == 0 {
+		return "", errors.New("no completions in response")
+	}
+	return output.Completions[0].Data.Text, nil
+}
+
+type AmazonResponse struct {
+	response IResponse
+}
+
+func (a *AmazonResponse) ParseResponse(rawResponse []byte) (string, error) {
+	type Result struct {
+		TokenCount       int    `json:"tokenCount"`
+		OutputText       string `json:"outputText"`
+		CompletionReason string `json:"completionReason"`
+	}
+	type InvokeModelResponseBody struct {
+		InputTextTokenCount int      `json:"inputTextTokenCount"`
+		Results             []Result `json:"results"`
+	}
+	output := &InvokeModelResponseBody{}
+	err := json.Unmarshal(rawResponse, output)
+	if err != nil {
+		return "", err
+	}
+	// guard against an empty results array to avoid an index-out-of-range panic
+	if len(output.Results) == 0 {
+		return "", errors.New("no results in response")
+	}
+	return output.Results[0].OutputText, nil
+}
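Taken together, the new package splits what used to be a pair of switch statements into per-model request builders (ICompletion) and parsers (IResponse). A minimal sketch of the round trip without calling AWS; the prompt and the canned response body are invented, everything else comes from the diff:

```go
package main

import (
	"context"
	"fmt"

	"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
)

func main() {
	cfg := bedrock_support.BedrockModelConfig{MaxTokens: 100, Temperature: 0.5, TopP: 0.9}

	// Request side: build the Anthropic-style body that GetCompletion
	// would pass to bedrockruntime.InvokeModel.
	var completion bedrock_support.CohereCompletion
	body, err := completion.GetCompletion(context.Background(), "Why is my pod pending?", cfg)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
	// {"max_tokens_to_sample":100,"prompt":"\n\nHuman: Why is my pod pending? \n\nAssistant:",...}

	// Response side: parse a canned InvokeModel response body the way
	// the client does once Bedrock answers.
	var response bedrock_support.CohereResponse
	text, err := response.ParseResponse([]byte(`{"completion":"Check node capacity.","stop_reason":"stop_sequence"}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(text) // Check node capacity.
}
```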
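The payoff is that supporting another Bedrock model no longer touches GetCompletion: implement the two small interfaces and register an entry. A hypothetical sketch (the MistralCompletion and MistralResponse names and both JSON shapes below are illustrative, not part of this diff):

```go
package bedrock_support

import (
	"context"
	"encoding/json"
	"errors"
)

// MistralCompletion builds the request body for a hypothetical new provider;
// it only needs to know its own wire format.
type MistralCompletion struct{}

func (m *MistralCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
	request := map[string]interface{}{
		"prompt":      prompt,
		"max_tokens":  modelConfig.MaxTokens,
		"temperature": modelConfig.Temperature,
		"top_p":       modelConfig.TopP,
	}
	return json.Marshal(request)
}

// MistralResponse parses that provider's (hypothetical) reply shape.
type MistralResponse struct{}

func (m *MistralResponse) ParseResponse(rawResponse []byte) (string, error) {
	var output struct {
		Outputs []struct {
			Text string `json:"text"`
		} `json:"outputs"`
	}
	if err := json.Unmarshal(rawResponse, &output); err != nil {
		return "", err
	}
	if len(output.Outputs) == 0 {
		return "", errors.New("no outputs in response")
	}
	return output.Outputs[0].Text, nil
}
```

Registering it would then be one more BedrockModel literal in the models slice in amazonbedrock.go, pairing the model ID with &MistralCompletion{}, &MistralResponse{}, and a default BedrockModelConfig.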