k8sgpt/pkg/ai/amazonbedrock_test.go
ju187 f603948935
feat: using modelName will calling completion (#1469)
* using modelName will calling completion

Signed-off-by: Tony Chen <tony_chen@discovery.com>

* sign

Signed-off-by: Tony Chen <tony_chen@discovery.com>

---------

Signed-off-by: Tony Chen <tony_chen@discovery.com>
2025-04-24 09:15:17 +01:00

165 lines
5.0 KiB
Go

package ai
import (
"testing"
"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
"github.com/stretchr/testify/assert"
)
// testModels is a small, fixed model catalogue shared by the lookup tests in
// this file. Both entries use the same sampling settings (MaxTokens 100,
// Temperature 0.5, TopP 0.9) and set Config.ModelName equal to Name.
//
// NOTE(review): the Completion/Response implementations are Cohere types even
// though the names are Anthropic models. In this file these fixtures are only
// passed to getModelFromString, so request encoding is never exercised. Confirm
// before reusing them for completion tests.
var testModels = func() []bedrock_support.BedrockModel {
	// newConfig builds the sampling config shared by every fixture entry.
	newConfig := func(modelID string) bedrock_support.BedrockModelConfig {
		return bedrock_support.BedrockModelConfig{
			MaxTokens:   100,
			Temperature: 0.5,
			TopP:        0.9,
			ModelName:   modelID,
		}
	}

	const (
		sonnetV1 = "anthropic.claude-3-5-sonnet-20240620-v1:0"
		sonnetV2 = "anthropic.claude-3-5-sonnet-20241022-v2:0"
	)

	return []bedrock_support.BedrockModel{
		{
			Name:       sonnetV1,
			Completion: &bedrock_support.CohereMessagesCompletion{},
			Response:   &bedrock_support.CohereMessagesResponse{},
			Config:     newConfig(sonnetV1),
		},
		{
			Name:       sonnetV2,
			Completion: &bedrock_support.CohereCompletion{},
			Response:   &bedrock_support.CohereResponse{},
			Config:     newConfig(sonnetV2),
		},
	}
}()
// TestBedrockModelConfig checks that resolving a Bedrock inference-profile ARN
// against testModels returns the matching fixture's sampling settings. It also
// checks that Config.ModelName on the returned model is the full ARN that was
// passed in, rather than the fixture's bare model ID.
func TestBedrockModelConfig(t *testing.T) {
	const arn = "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0"

	client := &AmazonBedRockClient{models: testModels}
	foundModel, err := client.getModelFromString(arn)
	// Stop on failure: the returned model should not be used (and may be nil)
	// when err is non-nil. Dereferencing it below would panic instead of
	// reporting the lookup error.
	if !assert.NoError(t, err, "Error should be nil") {
		return
	}

	// testify's assert.Equal takes (expected, actual). Keep that order so
	// failure messages label the two values correctly.
	assert.Equal(t, 100, foundModel.Config.MaxTokens)
	assert.Equal(t, float32(0.5), foundModel.Config.Temperature)
	assert.Equal(t, float32(0.9), foundModel.Config.TopP)
	assert.Equal(t, arn, foundModel.Config.ModelName)
}
// TestBedrockInvalidModel checks an ARN for a non-Bedrock service
// (arn:aws:s3:...) whose resource part names a model in testModels. The lookup
// must still succeed and return that model's fixture (MaxTokens 100).
//
// NOTE(review): despite the test name, the expectation is success, not an
// error. Presumably getModelFromString falls back to matching the embedded
// model ID; confirm against its implementation.
func TestBedrockInvalidModel(t *testing.T) {
	client := &AmazonBedRockClient{models: testModels}
	foundModel, err := client.getModelFromString("arn:aws:s3:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0")
	// Bail out rather than dereference a model that should not be used on error.
	if !assert.NoError(t, err, "Error should be nil") {
		return
	}
	// (expected, actual) order per testify convention.
	assert.Equal(t, 100, foundModel.Config.MaxTokens)
}
// TestBedrockGetCompletionInferenceProfile builds a client whose only model
// carries an inference-profile ARN as its Config.ModelName. It checks that
// Configure accepts that ARN and leaves client.model with the same ARN in
// Config.ModelName.
func TestBedrockGetCompletionInferenceProfile(t *testing.T) {
	const profileARN = "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0"

	sonnet := bedrock_support.BedrockModel{
		Name:       "anthropic.claude-3-5-sonnet-20240620-v1:0",
		Completion: &bedrock_support.CohereMessagesCompletion{},
		Response:   &bedrock_support.CohereMessagesResponse{},
		Config: bedrock_support.BedrockModelConfig{
			MaxTokens:   100,
			Temperature: 0.5,
			TopP:        0.9,
			ModelName:   profileARN,
		},
	}
	client := &AmazonBedRockClient{models: []bedrock_support.BedrockModel{sonnet}}

	assert.Nil(t, client.Configure(&AIProvider{Model: profileARN}), "Error should be nil")
	assert.Equal(t, profileARN, client.model.Config.ModelName, "Model name should match")
}
// TestGetModelFromString runs getModelFromString over a table of lookup inputs
// against the testModels fixtures. The inputs cover exact IDs, a partial ID, a
// second model version, unknown and empty names, surrounding whitespace, and
// upper-case input. Each row says whether an error is expected and, if not,
// which fixture Name must be returned.
func TestGetModelFromString(t *testing.T) {
	client := &AmazonBedRockClient{models: testModels}

	const (
		sonnetV1 = "anthropic.claude-3-5-sonnet-20240620-v1:0"
		sonnetV2 = "anthropic.claude-3-5-sonnet-20241022-v2:0"
	)

	// Omitted fields take their zero values: wantModel "" and wantErr false.
	cases := []struct {
		name      string
		model     string
		wantModel string
		wantErr   bool
	}{
		{name: "exact model name match", model: sonnetV1, wantModel: sonnetV1},
		{name: "partial model name match", model: "claude-3-5-sonnet", wantModel: sonnetV1},
		{name: "model name with different version", model: sonnetV2, wantModel: sonnetV2},
		{name: "non-existent model", model: "non-existent-model", wantErr: true},
		{name: "empty model name", model: "", wantErr: true},
		{name: "model name with extra spaces", model: " anthropic.claude-3-5-sonnet-20240620-v1:0 ", wantModel: sonnetV1},
		{name: "case insensitive match", model: "ANTHROPIC.CLAUDE-3-5-SONNET-20240620-V1:0", wantModel: sonnetV1},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := client.getModelFromString(tc.model)
			gotErr := err != nil

			// First check the error expectation. Only when no error was
			// expected (and none occurred) is the returned Name compared.
			switch {
			case gotErr != tc.wantErr:
				t.Errorf("getModelFromString() error = %v, wantErr %v", err, tc.wantErr)
			case !tc.wantErr && got.Name != tc.wantModel:
				t.Errorf("getModelFromString() = %v, want %v", got.Name, tc.wantModel)
			}
		})
	}
}
// TestDefaultModels checks that Configure on a zero-value client (no models
// supplied) populates client.models. It then checks that "anthropic.claude-v2"
// can be resolved by name from that populated list.
func TestDefaultModels(t *testing.T) {
	client := &AmazonBedRockClient{}

	// Configure should initialize default models
	err := client.Configure(&AIProvider{
		Model: "anthropic.claude-v2",
	})
	assert.NoError(t, err, "Configure should not return an error")
	assert.NotNil(t, client.models, "Models should be initialized")
	assert.NotEmpty(t, client.models, "Models should not be empty")

	// Test finding a default model. Stop on error: the returned model should
	// not be used (and may be nil) when err is non-nil, so reading model.Name
	// would panic instead of reporting the failed lookup.
	model, err := client.getModelFromString("anthropic.claude-v2")
	if !assert.NoError(t, err, "Should find the model") {
		return
	}
	assert.Equal(t, "anthropic.claude-v2", model.Name, "Should find the correct model")
}