mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-08-29 20:52:55 +00:00
feat: add a naive support of bedrock inference profile (#1446)
* feat: add a naive support of bedrock inference profile

  Signed-off-by: Tony Chen <tony_chen@discovery.com>

* feat: improving the tests

  Signed-off-by: Alex Jones <alexsimonjones@gmail.com>

---------

Signed-off-by: Tony Chen <tony_chen@discovery.com>
Signed-off-by: Alex Jones <alexsimonjones@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
parent
dceda9a6a1
commit
78ffa5904a
@ -3,8 +3,11 @@ package ai
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/aws/aws-sdk-go/service/bedrockruntime/bedrockruntimeiface"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/service/bedrockruntime/bedrockruntimeiface"
|
||||||
|
|
||||||
"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
|
"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
|
||||||
|
|
||||||
@ -24,6 +27,7 @@ type AmazonBedRockClient struct {
|
|||||||
temperature float32
|
temperature float32
|
||||||
topP float32
|
topP float32
|
||||||
maxTokens int
|
maxTokens int
|
||||||
|
models []bedrock_support.BedrockModel
|
||||||
}
|
}
|
||||||
|
|
||||||
// AmazonCompletion BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
|
// AmazonCompletion BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
|
||||||
@ -48,8 +52,7 @@ var BEDROCKER_SUPPORTED_REGION = []string{
|
|||||||
AP_South_1,
|
AP_South_1,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var defaultModels = []bedrock_support.BedrockModel{
|
||||||
models = []bedrock_support.BedrockModel{
|
|
||||||
{
|
{
|
||||||
Name: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
Name: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
Completion: &bedrock_support.CohereMessagesCompletion{},
|
Completion: &bedrock_support.CohereMessagesCompletion{},
|
||||||
@ -169,7 +172,7 @@ var (
|
|||||||
MaxTokens: 100, // max of 300k tokens
|
MaxTokens: 100, // max of 300k tokens
|
||||||
Temperature: 0.5,
|
Temperature: 0.5,
|
||||||
TopP: 0.9,
|
TopP: 0.9,
|
||||||
ModelName: "eu.wamazon.nova-pro-v1:0",
|
ModelName: "eu.amazon.nova-pro-v1:0",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -232,8 +235,17 @@ var (
|
|||||||
TopP: 0.9,
|
TopP: 0.9,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAmazonBedRockClient creates a new AmazonBedRockClient with the given models
|
||||||
|
func NewAmazonBedRockClient(models []bedrock_support.BedrockModel) *AmazonBedRockClient {
|
||||||
|
if models == nil {
|
||||||
|
models = defaultModels // Use default models if none provided
|
||||||
}
|
}
|
||||||
)
|
return &AmazonBedRockClient{
|
||||||
|
models: models,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetRegionOrDefault check config region
|
// GetRegionOrDefault check config region
|
||||||
func GetRegionOrDefault(region string) string {
|
func GetRegionOrDefault(region string) string {
|
||||||
@ -254,16 +266,46 @@ func GetRegionOrDefault(region string) string {
|
|||||||
|
|
||||||
// Get model from string
|
// Get model from string
|
||||||
func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support.BedrockModel, error) {
|
func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support.BedrockModel, error) {
|
||||||
for _, m := range models {
|
if model == "" {
|
||||||
if model == m.Name {
|
return nil, errors.New("model name cannot be empty")
|
||||||
return &m, nil
|
}
|
||||||
|
|
||||||
|
// Trim spaces from the model name
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
modelLower := strings.ToLower(model)
|
||||||
|
|
||||||
|
// Try to find an exact match first
|
||||||
|
for i := range a.models {
|
||||||
|
if strings.EqualFold(model, a.models[i].Name) || strings.EqualFold(model, a.models[i].Config.ModelName) {
|
||||||
|
// Create a copy to avoid returning a pointer to a loop variable
|
||||||
|
modelCopy := a.models[i]
|
||||||
|
return &modelCopy, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("model not found")
|
|
||||||
|
// If no exact match, try partial match
|
||||||
|
for i := range a.models {
|
||||||
|
modelNameLower := strings.ToLower(a.models[i].Name)
|
||||||
|
modelConfigNameLower := strings.ToLower(a.models[i].Config.ModelName)
|
||||||
|
|
||||||
|
// Check if the input string contains the model name or vice versa
|
||||||
|
if strings.Contains(modelNameLower, modelLower) || strings.Contains(modelLower, modelNameLower) ||
|
||||||
|
strings.Contains(modelConfigNameLower, modelLower) || strings.Contains(modelLower, modelConfigNameLower) {
|
||||||
|
// Create a copy to avoid returning a pointer to a loop variable
|
||||||
|
modelCopy := a.models[i]
|
||||||
|
return &modelCopy, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("model '%s' not found in supported models", model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure configures the AmazonBedRockClient with the provided configuration.
|
// Configure configures the AmazonBedRockClient with the provided configuration.
|
||||||
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
||||||
|
// Initialize models if not already initialized
|
||||||
|
if a.models == nil {
|
||||||
|
a.models = defaultModels
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new AWS session
|
// Create a new AWS session
|
||||||
providerRegion := GetRegionOrDefault(config.GetProviderRegion())
|
providerRegion := GetRegionOrDefault(config.GetProviderRegion())
|
||||||
@ -280,7 +322,6 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: Override the completion config somehow
|
|
||||||
|
|
||||||
// Create a new BedrockRuntime client
|
// Create a new BedrockRuntime client
|
||||||
a.client = bedrockruntime.New(sess)
|
a.client = bedrockruntime.New(sess)
|
||||||
|
131
pkg/ai/amazonbedrock_test.go
Normal file
131
pkg/ai/amazonbedrock_test.go
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test models for unit testing
|
||||||
|
var testModels = []bedrock_support.BedrockModel{
|
||||||
|
{
|
||||||
|
Name: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
Completion: &bedrock_support.CohereMessagesCompletion{},
|
||||||
|
Response: &bedrock_support.CohereMessagesResponse{},
|
||||||
|
Config: bedrock_support.BedrockModelConfig{
|
||||||
|
MaxTokens: 100,
|
||||||
|
Temperature: 0.5,
|
||||||
|
TopP: 0.9,
|
||||||
|
ModelName: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
Completion: &bedrock_support.CohereCompletion{},
|
||||||
|
Response: &bedrock_support.CohereResponse{},
|
||||||
|
Config: bedrock_support.BedrockModelConfig{
|
||||||
|
MaxTokens: 100,
|
||||||
|
Temperature: 0.5,
|
||||||
|
TopP: 0.9,
|
||||||
|
ModelName: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBedrockModelConfig(t *testing.T) {
|
||||||
|
client := &AmazonBedRockClient{models: testModels}
|
||||||
|
|
||||||
|
foundModel, err := client.getModelFromString("arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||||
|
assert.Nil(t, err, "Error should be nil")
|
||||||
|
assert.Equal(t, foundModel.Config.MaxTokens, 100)
|
||||||
|
assert.Equal(t, foundModel.Config.Temperature, float32(0.5))
|
||||||
|
assert.Equal(t, foundModel.Config.TopP, float32(0.9))
|
||||||
|
assert.Equal(t, foundModel.Config.ModelName, "anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelFromString(t *testing.T) {
|
||||||
|
client := &AmazonBedRockClient{models: testModels}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
wantModel string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact model name match",
|
||||||
|
model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
wantModel: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial model name match",
|
||||||
|
model: "claude-3-5-sonnet",
|
||||||
|
wantModel: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model name with different version",
|
||||||
|
model: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
wantModel: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existent model",
|
||||||
|
model: "non-existent-model",
|
||||||
|
wantModel: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty model name",
|
||||||
|
model: "",
|
||||||
|
wantModel: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model name with extra spaces",
|
||||||
|
model: " anthropic.claude-3-5-sonnet-20240620-v1:0 ",
|
||||||
|
wantModel: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case insensitive match",
|
||||||
|
model: "ANTHROPIC.CLAUDE-3-5-SONNET-20240620-V1:0",
|
||||||
|
wantModel: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotModel, err := client.getModelFromString(tt.model)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("getModelFromString() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.wantErr && gotModel.Name != tt.wantModel {
|
||||||
|
t.Errorf("getModelFromString() = %v, want %v", gotModel.Name, tt.wantModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultModels tests that the client works with default models
|
||||||
|
func TestDefaultModels(t *testing.T) {
|
||||||
|
client := &AmazonBedRockClient{}
|
||||||
|
|
||||||
|
// Configure should initialize default models
|
||||||
|
err := client.Configure(&AIProvider{
|
||||||
|
Model: "anthropic.claude-v2",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NoError(t, err, "Configure should not return an error")
|
||||||
|
assert.NotNil(t, client.models, "Models should be initialized")
|
||||||
|
assert.NotEmpty(t, client.models, "Models should not be empty")
|
||||||
|
|
||||||
|
// Test finding a default model
|
||||||
|
model, err := client.getModelFromString("anthropic.claude-v2")
|
||||||
|
assert.NoError(t, err, "Should find the model")
|
||||||
|
assert.Equal(t, "anthropic.claude-v2", model.Name, "Should find the correct model")
|
||||||
|
}
|
@ -17,7 +17,12 @@ var SUPPPORTED_BEDROCK_MODELS = []string{
|
|||||||
"ai21.j2-jumbo-instruct",
|
"ai21.j2-jumbo-instruct",
|
||||||
"amazon.titan-text-express-v1",
|
"amazon.titan-text-express-v1",
|
||||||
"amazon.nova-pro-v1:0",
|
"amazon.nova-pro-v1:0",
|
||||||
|
"eu.amazon.nova-pro-v1:0",
|
||||||
|
"us.amazon.nova-pro-v1:0",
|
||||||
|
"amazon.nova-lite-v1:0",
|
||||||
"eu.amazon.nova-lite-v1:0",
|
"eu.amazon.nova-lite-v1:0",
|
||||||
|
"us.amazon.nova-lite-v1:0",
|
||||||
|
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
}
|
}
|
||||||
|
|
||||||
type ICompletion interface {
|
type ICompletion interface {
|
||||||
@ -91,7 +96,7 @@ type AmazonCompletion struct {
|
|||||||
|
|
||||||
func isModelSupported(modelName string) bool {
|
func isModelSupported(modelName string) bool {
|
||||||
for _, supportedModel := range SUPPPORTED_BEDROCK_MODELS {
|
for _, supportedModel := range SUPPPORTED_BEDROCK_MODELS {
|
||||||
if modelName == supportedModel {
|
if strings.Contains(modelName, supportedModel) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user