feat: rework how bedrock data models are structured and accessed (#1369)

* feat: rework how bedrock data models are structured and accessed

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>

* feat: rework how bedrock data models are structured and accessed

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>

---------

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>
Alex Jones authored 2025-02-24 03:03:19 -08:00, committed by GitHub
parent 3b85f09348, commit 7dadea2570
4 changed files with 281 additions and 120 deletions

View File

@@ -2,8 +2,8 @@ package ai

 import (
 	"context"
-	"encoding/json"
-	"fmt"
+	"errors"
+	"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
 	"os"

 	"github.com/aws/aws-sdk-go/aws"
@@ -13,18 +13,18 @@ import (

 const amazonbedrockAIClientName = "amazonbedrock"

-// AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service.
+// AmazonBedRockClient represents the client for interacting with the AmazonCompletion Bedrock service.
 type AmazonBedRockClient struct {
 	nopCloser

 	client      *bedrockruntime.BedrockRuntime
-	model       string
+	model       *bedrock_support.BedrockModel
 	temperature float32
 	topP        float32
 	maxTokens   int
 }

-// Amazon BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
+// AmazonCompletion BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
 // https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html#bedrock-regions
 const BEDROCK_DEFAULT_REGION = "us-east-1" // default use us-east-1 region
@@ -44,41 +44,109 @@ var BEDROCKER_SUPPORTED_REGION = []string{
 	EU_Central_1,
 }

-const (
-	ModelAnthropicClaudeSonnetV3_5    = "anthropic.claude-3-5-sonnet-20240620-v1:0"
-	ModelAnthropicClaudeSonnetV3_5_V2 = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
-	ModelAnthropicClaudeV2            = "anthropic.claude-v2"
-	ModelAnthropicClaudeV1            = "anthropic.claude-v1"
-	ModelAnthropicClaudeInstantV1     = "anthropic.claude-instant-v1"
-	ModelA21J2UltraV1                 = "ai21.j2-ultra-v1"
-	ModelA21J2JumboInstruct           = "ai21.j2-jumbo-instruct"
-	ModelAmazonTitanExpressV1         = "amazon.titan-text-express-v1"
-)
-
-var BEDROCK_MODELS = []string{
-	ModelAnthropicClaudeV2,
-	ModelAnthropicClaudeV1,
-	ModelAnthropicClaudeInstantV1,
-	ModelA21J2UltraV1,
-	ModelA21J2JumboInstruct,
-	ModelAmazonTitanExpressV1,
-}
-
-//const TOPP = 0.9 moved to config
-
-// GetModelOrDefault check config model
-func GetModelOrDefault(model string) string {
-
-	// Check if the provided model is in the list
-	for _, m := range BEDROCK_MODELS {
-		if m == model {
-			return model // Return the provided model
-		}
-	}
-
-	// Return the default model if the provided model is not in the list
-	return BEDROCK_MODELS[0]
-}
+var (
+	models = []bedrock_support.BedrockModel{
+		{
+			Name:       "anthropic.claude-3-5-sonnet-20240620-v1:0",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "anthropic.claude-v2",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "anthropic.claude-v1",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "anthropic.claude-instant-v1",
+			Completion: &bedrock_support.CohereCompletion{},
+			Response:   &bedrock_support.CohereResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "ai21.j2-ultra-v1",
+			Completion: &bedrock_support.AI21{},
+			Response:   &bedrock_support.AI21Response{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "ai21.j2-jumbo-instruct",
+			Completion: &bedrock_support.AI21{},
+			Response:   &bedrock_support.AI21Response{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+		{
+			Name:       "amazon.titan-text-express-v1",
+			Completion: &bedrock_support.AmazonCompletion{},
+			Response:   &bedrock_support.AmazonResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				// sensible defaults
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+			},
+		},
+	}
+)

 // GetModelOrDefault check config region
 func GetRegionOrDefault(region string) string {
@@ -97,6 +165,16 @@ func GetRegionOrDefault(region string) string {
 	return BEDROCK_DEFAULT_REGION
 }

+// Get model from string
+func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support.BedrockModel, error) {
+	for _, m := range models {
+		if model == m.Name {
+			return &m, nil
+		}
+	}
+	return nil, errors.New("model not found")
+}
+
 // Configure configures the AmazonBedRockClient with the provided configuration.
 func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
@@ -111,9 +189,15 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 		return err
 	}

+	foundModel, err := a.getModelFromString(config.GetModel())
+	if err != nil {
+		return err
+	}
+	// TODO: Override the completion config somehow
 	// Create a new BedrockRuntime client
 	a.client = bedrockruntime.New(sess)
-	a.model = GetModelOrDefault(config.GetModel())
+	a.model = foundModel
 	a.temperature = config.GetTemperature()
 	a.topP = config.GetTopP()
 	a.maxTokens = config.GetMaxTokens()
@@ -124,45 +208,19 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {

 // GetCompletion sends a request to the model for generating completion based on the provided prompt.
 func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
-	// Prepare the input data for the model invocation based on the model & the Response Body per model as well.
-	var request map[string]interface{}
-	switch a.model {
-	case ModelAnthropicClaudeSonnetV3_5, ModelAnthropicClaudeSonnetV3_5_V2, ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
-		request = map[string]interface{}{
-			"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
-			"max_tokens_to_sample": a.maxTokens,
-			"temperature":          a.temperature,
-			"top_p":                a.topP,
-		}
-	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
-		request = map[string]interface{}{
-			"prompt":      prompt,
-			"maxTokens":   a.maxTokens,
-			"temperature": a.temperature,
-			"topP":        a.topP,
-		}
-	case ModelAmazonTitanExpressV1:
-		request = map[string]interface{}{
-			"inputText": fmt.Sprintf("\n\nUser: %s", prompt),
-			"textGenerationConfig": map[string]interface{}{
-				"maxTokenCount": a.maxTokens,
-				"temperature":   a.temperature,
-				"topP":          a.topP,
-			},
-		}
-	default:
-		return "", fmt.Errorf("model %s not supported", a.model)
-	}
+	// override config defaults
+	a.model.Config.MaxTokens = a.maxTokens
+	a.model.Config.Temperature = a.temperature
+	a.model.Config.TopP = a.topP

-	body, err := json.Marshal(request)
+	body, err := a.model.Completion.GetCompletion(ctx, prompt, a.model.Config)
 	if err != nil {
 		return "", err
 	}

 	// Build the parameters for the model invocation
 	params := &bedrockruntime.InvokeModelInput{
 		Body:        body,
-		ModelId:     aws.String(a.model),
+		ModelId:     aws.String(a.model.Name),
 		ContentType: aws.String("application/json"),
 		Accept:      aws.String("application/json"),
 	}
@@ -173,54 +231,9 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
 		return "", err
 	}

-	// Response type changes as per model
-	switch a.model {
-	case ModelAnthropicClaudeSonnetV3_5, ModelAnthropicClaudeSonnetV3_5_V2, ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
-		type InvokeModelResponseBody struct {
-			Completion  string `json:"completion"`
-			Stop_reason string `json:"stop_reason"`
-		}
-		output := &InvokeModelResponseBody{}
-		err = json.Unmarshal(resp.Body, output)
-		if err != nil {
-			return "", err
-		}
-		return output.Completion, nil
-	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
-		type Data struct {
-			Text string `json:"text"`
-		}
-		type Completion struct {
-			Data Data `json:"data"`
-		}
-		type InvokeModelResponseBody struct {
-			Completions []Completion `json:"completions"`
-		}
-		output := &InvokeModelResponseBody{}
-		err = json.Unmarshal(resp.Body, output)
-		if err != nil {
-			return "", err
-		}
-		return output.Completions[0].Data.Text, nil
-	case ModelAmazonTitanExpressV1:
-		type Result struct {
-			TokenCount       int    `json:"tokenCount"`
-			OutputText       string `json:"outputText"`
-			CompletionReason string `json:"completionReason"`
-		}
-		type InvokeModelResponseBody struct {
-			InputTextTokenCount int      `json:"inputTextTokenCount"`
-			Results             []Result `json:"results"`
-		}
-		output := &InvokeModelResponseBody{}
-		err = json.Unmarshal(resp.Body, output)
-		if err != nil {
-			return "", err
-		}
-		return output.Results[0].OutputText, nil
-	default:
-		return "", fmt.Errorf("model %s not supported", a.model)
-	}
+	// Parse the response
+	return a.model.Response.ParseResponse(resp.Body)
 }

 // GetName returns the name of the AmazonBedRockClient.
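
To see how the reworked pieces compose, here is a minimal, self-contained sketch (not part of the commit) of the new flow: look a model up by name, build its provider-specific request body through its Completion, and parse a raw response through its Response. The local models slice, the prompt, and the sample response JSON are illustrative; only the bedrock_support types come from this change.

package main

import (
	"context"
	"fmt"

	"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
)

func main() {
	// Illustrative registry; the real one is the package-level `models` slice in pkg/ai.
	models := []bedrock_support.BedrockModel{
		{
			Name:       "anthropic.claude-v2",
			Completion: &bedrock_support.CohereCompletion{},
			Response:   &bedrock_support.CohereResponse{},
			Config:     bedrock_support.BedrockModelConfig{MaxTokens: 100, Temperature: 0.5, TopP: 0.9},
		},
	}

	// Same lookup the client performs in getModelFromString.
	var model *bedrock_support.BedrockModel
	for i := range models {
		if models[i].Name == "anthropic.claude-v2" {
			model = &models[i]
			break
		}
	}

	// Build the provider-specific request body passed to InvokeModel.
	body, err := model.Completion.GetCompletion(context.Background(), "Why is my pod crash-looping?", model.Config)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))

	// Parse an illustrative raw response matching the Anthropic/Claude body shape.
	raw := []byte(`{"completion":"The container exits because ...","stop_reason":"stop_sequence"}`)
	text, err := model.Response.ParseResponse(raw)
	if err != nil {
		panic(err)
	}
	fmt.Println(text)
}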

View File

@@ -0,0 +1,67 @@
package bedrock_support
import (
"context"
"encoding/json"
"fmt"
)
type ICompletion interface {
GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error)
}
type CohereCompletion struct {
completion ICompletion
}
func (a *CohereCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
request := map[string]interface{}{
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"max_tokens_to_sample": modelConfig.MaxTokens,
"temperature": modelConfig.Temperature,
"top_p": modelConfig.TopP,
}
body, err := json.Marshal(request)
if err != nil {
return []byte{}, err
}
return body, nil
}
type AI21 struct {
completion ICompletion
}
func (a *AI21) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
request := map[string]interface{}{
"prompt": prompt,
"maxTokens": modelConfig.MaxTokens,
"temperature": modelConfig.Temperature,
"topP": modelConfig.TopP,
}
body, err := json.Marshal(request)
if err != nil {
return []byte{}, err
}
return body, nil
}
type AmazonCompletion struct {
completion ICompletion
}
func (a *AmazonCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
request := map[string]interface{}{
"inputText": fmt.Sprintf("\n\nUser: %s", prompt),
"textGenerationConfig": map[string]interface{}{
"maxTokenCount": modelConfig.MaxTokens,
"temperature": modelConfig.Temperature,
"topP": modelConfig.TopP,
},
}
body, err := json.Marshal(request)
if err != nil {
return []byte{}, err
}
return body, nil
}
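
As a quick sanity check on the request shapes above, a small test-style sketch (not part of the commit) of the body AmazonCompletion builds for the Titan format; the prompt is arbitrary, and json.Marshal orders map keys alphabetically:

package bedrock_support

import (
	"context"
	"strings"
	"testing"
)

// Not part of the commit: checks the Titan-style request body produced by AmazonCompletion.
func TestAmazonCompletionBodyShape(t *testing.T) {
	cfg := BedrockModelConfig{MaxTokens: 100, Temperature: 0.5, TopP: 0.9}
	body, err := (&AmazonCompletion{}).GetCompletion(context.Background(), "Explain this error", cfg)
	if err != nil {
		t.Fatal(err)
	}
	// Expect something like:
	// {"inputText":"\n\nUser: Explain this error","textGenerationConfig":{"maxTokenCount":100,"temperature":0.5,"topP":0.9}}
	if !strings.Contains(string(body), `"textGenerationConfig"`) {
		t.Fatalf("unexpected body: %s", body)
	}
}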

View File

@@ -0,0 +1,13 @@
package bedrock_support
type BedrockModelConfig struct {
MaxTokens int
Temperature float32
TopP float32
}
type BedrockModel struct {
Name string
Completion ICompletion
Response IResponse
Config BedrockModelConfig
}
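
A BedrockModel entry ties a model id to its request builder, response parser, and default sampling settings; a hypothetical extra entry (the model id below is invented, purely for illustration) would look like this:

package bedrock_support

// Hypothetical entry, for illustration only; the real registry lives in pkg/ai.
var exampleModel = BedrockModel{
	Name:       "ai21.example-model-v1", // invented model id
	Completion: &AI21{},
	Response:   &AI21Response{},
	Config: BedrockModelConfig{
		MaxTokens:   100,
		Temperature: 0.5,
		TopP:        0.9,
	},
}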

View File

@@ -0,0 +1,68 @@
package bedrock_support
import "encoding/json"
type IResponse interface {
ParseResponse(rawResponse []byte) (string, error)
}
type CohereResponse struct {
response IResponse
}
func (a *CohereResponse) ParseResponse(rawResponse []byte) (string, error) {
type InvokeModelResponseBody struct {
Completion string `json:"completion"`
Stop_reason string `json:"stop_reason"`
}
output := &InvokeModelResponseBody{}
err := json.Unmarshal(rawResponse, output)
if err != nil {
return "", err
}
return output.Completion, nil
}
type AI21Response struct {
response IResponse
}
func (a *AI21Response) ParseResponse(rawResponse []byte) (string, error) {
type Data struct {
Text string `json:"text"`
}
type Completion struct {
Data Data `json:"data"`
}
type InvokeModelResponseBody struct {
Completions []Completion `json:"completions"`
}
output := &InvokeModelResponseBody{}
err := json.Unmarshal(rawResponse, output)
if err != nil {
return "", err
}
return output.Completions[0].Data.Text, nil
}
type AmazonResponse struct {
response IResponse
}
func (a *AmazonResponse) ParseResponse(rawResponse []byte) (string, error) {
type Result struct {
TokenCount int `json:"tokenCount"`
OutputText string `json:"outputText"`
CompletionReason string `json:"completionReason"`
}
type InvokeModelResponseBody struct {
InputTextTokenCount int `json:"inputTextTokenCount"`
Results []Result `json:"results"`
}
output := &InvokeModelResponseBody{}
err := json.Unmarshal(rawResponse, output)
if err != nil {
return "", err
}
return output.Results[0].OutputText, nil
}
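
To round this off, an example-style sketch (not part of the commit) of parsing an illustrative AI21 response body with AI21Response:

package bedrock_support

import "fmt"

// Not part of the commit: parses a made-up AI21 response body.
func ExampleAI21Response_ParseResponse() {
	raw := []byte(`{"completions":[{"data":{"text":"The deployment has no matching selector."}}]}`)
	var r AI21Response
	text, err := r.ParseResponse(raw)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(text)
	// Output: The deployment has no matching selector.
}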