diff --git a/pkg/ai/amzonbedrock.go b/pkg/ai/amzonbedrock.go new file mode 100644 index 0000000..473efa4 --- /dev/null +++ b/pkg/ai/amzonbedrock.go @@ -0,0 +1,151 @@ +package ai + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + + "github.com/fatih/color" + + "github.com/k8sgpt-ai/k8sgpt/pkg/cache" + "github.com/k8sgpt-ai/k8sgpt/pkg/util" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/bedrockruntime" +) + +// AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service. +type AmazonBedRockClient struct { + client *bedrockruntime.BedrockRuntime + language string + model string + temperature float32 +} + +// InvokeModelResponseBody represents the response body structure from the model invocation. +type InvokeModelResponseBody struct { + Completion string `json:"completion"` + Stop_reason string `json:"stop_reason"` +} + +const BEDROCK_REGION = "us-east-1" // default use us-east-1 region + +// GetModelOrDefault check config model +func GetModelOrDefault(model string) string { + modelList := []string{"anthropic.claude-v2", "anthropic.claude-v1", "anthropic.claude-instant-v1"} + + // Check if the provided model is in the list + for _, m := range modelList { + if m == model { + return model // Return the provided model + } + } + + // Return the default model if the provided model is not in the list + return modelList[0] +} + +// Configure configures the AmazonBedRockClient with the provided configuration and language. +func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error { + + // Create a new AWS session + sess, err := session.NewSession(&aws.Config{ + Region: aws.String(BEDROCK_REGION), + }) + + if err != nil { + return err + } + + // Create a new BedrockRuntime client + a.client = bedrockruntime.New(sess) + a.language = language + a.model = GetModelOrDefault(config.GetModel()) + a.temperature = config.GetTemperature() + + return nil +} + +// GetCompletion sends a request to the model for generating completion based on the provided prompt. +func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) { + + // Prepare the input data for the model invocation + request := map[string]interface{}{ + "prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt), + "max_tokens_to_sample": 1024, + "temperature": a.temperature, + "top_p": 0.9, + } + + body, err := json.Marshal(request) + if err != nil { + return "", err + } + + // Build the parameters for the model invocation + params := &bedrockruntime.InvokeModelInput{ + Body: body, + ModelId: aws.String(a.model), + ContentType: aws.String("application/json"), + Accept: aws.String("application/json"), + } + // Invoke the model + resp, err := a.client.InvokeModelWithContext(ctx, params) + + if err != nil { + return "", err + } + // Parse the response body + output := &InvokeModelResponseBody{} + err = json.Unmarshal(resp.Body, output) + if err != nil { + return "", err + } + return output.Completion, nil +} + +// Parse generates a completion for the provided prompt using the Amazon Bedrock model. +func (a *AmazonBedRockClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) { + inputKey := strings.Join(prompt, " ") + // Check for cached data + cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey) + + if !cache.IsCacheDisabled() && cache.Exists(cacheKey) { + response, err := cache.Load(cacheKey) + if err != nil { + return "", err + } + + if response != "" { + output, err := base64.StdEncoding.DecodeString(response) + if err != nil { + color.Red("error decoding cached data: %v", err) + return "", nil + } + return string(output), nil + } + } + + response, err := a.GetCompletion(ctx, inputKey, promptTmpl) + + if err != nil { + return "", err + } + + err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response))) + + if err != nil { + color.Red("error storing value to cache: %v", err) + return "", nil + } + + return response, nil +} + +// GetName returns the name of the AmazonBedRockClient. +func (a *AmazonBedRockClient) GetName() string { + return "amazonbedrock" +} diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index b8172d1..3c54511 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -26,6 +26,7 @@ var ( &LocalAIClient{}, &NoOpAIClient{}, &CohereClient{}, + &AmazonBedRockClient{}, } Backends = []string{ "openai", @@ -33,6 +34,7 @@ var ( "azureopenai", "noopai", "cohere", + "amazonbedrock", } )