mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-08-17 15:26:37 +00:00
Signed-off-by: rkarthikr <38294804+rkarthikr@users.noreply.github.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
104 lines
3.2 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/service/bedrock"
|
|
"github.com/aws/aws-sdk-go-v2/service/bedrock/types"
|
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
|
"github.com/k8sgpt-ai/k8sgpt/pkg/ai/bedrock_support"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
)
|
|
|
|
// Mock for Bedrock Management Client
|
|
type MockBedrockClient struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockBedrockClient) GetInferenceProfile(ctx context.Context, params *bedrock.GetInferenceProfileInput, optFns ...func(*bedrock.Options)) (*bedrock.GetInferenceProfileOutput, error) {
|
|
args := m.Called(ctx, params)
|
|
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
|
|
return args.Get(0).(*bedrock.GetInferenceProfileOutput), args.Error(1)
|
|
}
|
|
|
|
// Mock for Bedrock Runtime Client
|
|
type MockBedrockRuntimeClient struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockBedrockRuntimeClient) InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) {
|
|
args := m.Called(ctx, params)
|
|
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
|
|
return args.Get(0).(*bedrockruntime.InvokeModelOutput), args.Error(1)
|
|
}
|
|
|
|
// TestBedrockInferenceProfileARNWithMocks tests the inference profile ARN validation with mocks
|
|
func TestBedrockInferenceProfileARNWithMocks(t *testing.T) {
|
|
// Create test models
|
|
testModels := []bedrock_support.BedrockModel{
|
|
{
|
|
Name: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
|
Completion: &bedrock_support.CohereMessagesCompletion{},
|
|
Response: &bedrock_support.CohereMessagesResponse{},
|
|
Config: bedrock_support.BedrockModelConfig{
|
|
MaxTokens: 100,
|
|
Temperature: 0.5,
|
|
TopP: 0.9,
|
|
ModelName: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
|
},
|
|
},
|
|
}
|
|
|
|
// Create a client with test models
|
|
client := &AmazonBedRockClient{models: testModels}
|
|
|
|
// Create mock clients
|
|
mockMgmtClient := new(MockBedrockClient)
|
|
mockRuntimeClient := new(MockBedrockRuntimeClient)
|
|
|
|
// Inject mock clients into the AmazonBedRockClient
|
|
client.mgmtClient = mockMgmtClient
|
|
client.client = mockRuntimeClient
|
|
|
|
// Test with a valid inference profile ARN
|
|
inferenceProfileARN := "arn:aws:bedrock:us-east-1:123456789012:inference-profile/my-profile"
|
|
|
|
// Setup mock response for GetInferenceProfile
|
|
mockMgmtClient.On("GetInferenceProfile", mock.Anything, &bedrock.GetInferenceProfileInput{
|
|
InferenceProfileIdentifier: aws.String("my-profile"),
|
|
}).Return(&bedrock.GetInferenceProfileOutput{
|
|
Models: []types.InferenceProfileModel{
|
|
{
|
|
ModelArn: aws.String("arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20240620-v1:0"),
|
|
},
|
|
},
|
|
}, nil)
|
|
|
|
// Configure the client with the inference profile ARN
|
|
config := AIProvider{
|
|
Model: inferenceProfileARN,
|
|
ProviderRegion: "us-east-1",
|
|
}
|
|
|
|
// Test the Configure method with the inference profile ARN
|
|
err := client.Configure(&config)
|
|
|
|
// Verify that the configuration was successful
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, inferenceProfileARN, client.model.Config.ModelName)
|
|
|
|
// Verify that the mock was called
|
|
mockMgmtClient.AssertExpectations(t)
|
|
}
|