diff --git a/go.mod b/go.mod index 1f2353a..2534d3c 100644 --- a/go.mod +++ b/go.mod @@ -122,6 +122,7 @@ require ( github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/sony/gobreaker v0.5.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect go.opencensus.io v0.24.0 // indirect diff --git a/pkg/ai/factory.go b/pkg/ai/factory.go new file mode 100644 index 0000000..f1af83a --- /dev/null +++ b/pkg/ai/factory.go @@ -0,0 +1,87 @@ +/* +Copyright 2023 The K8sGPT Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ai + +import ( + "github.com/spf13/viper" +) + +// AIClientFactory is an interface for creating AI clients +type AIClientFactory interface { + NewClient(provider string) IAI +} + +// DefaultAIClientFactory is the default implementation of AIClientFactory +type DefaultAIClientFactory struct{} + +// NewClient creates a new AI client using the default implementation +func (f *DefaultAIClientFactory) NewClient(provider string) IAI { + return NewClient(provider) +} + +// ConfigProvider is an interface for accessing configuration +type ConfigProvider interface { + UnmarshalKey(key string, rawVal interface{}) error +} + +// ViperConfigProvider is the default implementation of ConfigProvider using Viper +type ViperConfigProvider struct{} + +// UnmarshalKey unmarshals a key from the configuration using Viper +func (p *ViperConfigProvider) UnmarshalKey(key string, rawVal interface{}) error { + return viper.UnmarshalKey(key, rawVal) +} + +// Default instances to be used +var ( + DefaultClientFactory = &DefaultAIClientFactory{} + DefaultConfigProvider = &ViperConfigProvider{} +) + +// For testing - these variables can be overridden in tests +var ( + testAIClientFactory AIClientFactory = nil + testConfigProvider ConfigProvider = nil +) + +// GetAIClientFactory returns the test factory if set, otherwise the default +func GetAIClientFactory() AIClientFactory { + if testAIClientFactory != nil { + return testAIClientFactory + } + return DefaultClientFactory +} + +// GetConfigProvider returns the test provider if set, otherwise the default +func GetConfigProvider() ConfigProvider { + if testConfigProvider != nil { + return testConfigProvider + } + return DefaultConfigProvider +} + +// For testing - set the test implementations +func SetTestAIClientFactory(factory AIClientFactory) { + testAIClientFactory = factory +} + +func SetTestConfigProvider(provider ConfigProvider) { + testConfigProvider = provider +} + +// Reset test implementations +func ResetTestImplementations() { + testAIClientFactory = nil + testConfigProvider = nil +} \ No newline at end of file diff --git a/pkg/server/query/query.go b/pkg/server/query/query.go index 640ab5e..9eabb7a 100644 --- a/pkg/server/query/query.go +++ b/pkg/server/query/query.go @@ -1,8 +1,10 @@ package query import ( - schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1" "context" + "fmt" + + schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1" "github.com/k8sgpt-ai/k8sgpt/pkg/ai" ) @@ -10,9 +12,50 @@ func (h *Handler) Query(ctx context.Context, i *schemav1.QueryRequest) ( *schemav1.QueryResponse, error, ) { - aiClient := ai.NewClient(i.Backend) + // Create client factory and config provider + factory := ai.GetAIClientFactory() + configProvider := ai.GetConfigProvider() + + // Use the factory to create the client + aiClient := factory.NewClient(i.Backend) defer aiClient.Close() + var configAI ai.AIConfiguration + if err := configProvider.UnmarshalKey("ai", &configAI); err != nil { + return &schemav1.QueryResponse{ + Response: "", + Error: &schemav1.QueryError{ + Message: fmt.Sprintf("Failed to unmarshal AI configuration: %v", err), + }, + }, nil + } + + var aiProvider ai.AIProvider + for _, provider := range configAI.Providers { + if i.Backend == provider.Name { + aiProvider = provider + break + } + } + if aiProvider.Name == "" { + return &schemav1.QueryResponse{ + Response: "", + Error: &schemav1.QueryError{ + Message: fmt.Sprintf("AI provider %s not found in configuration", i.Backend), + }, + }, nil + } + + // Configure the AI client + if err := aiClient.Configure(&aiProvider); err != nil { + return &schemav1.QueryResponse{ + Response: "", + Error: &schemav1.QueryError{ + Message: fmt.Sprintf("Failed to configure AI client: %v", err), + }, + }, nil + } + resp, err := aiClient.GetCompletion(ctx, i.Query) var errMessage string = "" if err != nil { diff --git a/pkg/server/query/query_test.go b/pkg/server/query/query_test.go new file mode 100644 index 0000000..2c29c1a --- /dev/null +++ b/pkg/server/query/query_test.go @@ -0,0 +1,310 @@ +package query + +import ( + "context" + "errors" + "testing" + + schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1" + "github.com/k8sgpt-ai/k8sgpt/pkg/ai" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockAI is a mock implementation of the ai.IAI interface for testing +type MockAI struct { + mock.Mock +} + +func (m *MockAI) Configure(config ai.IAIConfig) error { + args := m.Called(config) + return args.Error(0) +} + +func (m *MockAI) GetCompletion(ctx context.Context, prompt string) (string, error) { + args := m.Called(ctx, prompt) + return args.String(0), args.Error(1) +} + +func (m *MockAI) GetName() string { + args := m.Called() + return args.String(0) +} + +func (m *MockAI) Close() { + m.Called() +} + +// MockAIClientFactory is a mock implementation of AIClientFactory +type MockAIClientFactory struct { + mock.Mock +} + +func (m *MockAIClientFactory) NewClient(provider string) ai.IAI { + args := m.Called(provider) + return args.Get(0).(ai.IAI) +} + +// MockConfigProvider is a mock implementation of ConfigProvider +type MockConfigProvider struct { + mock.Mock +} + +func (m *MockConfigProvider) UnmarshalKey(key string, rawVal interface{}) error { + args := m.Called(key, rawVal) + + // If we want to set the rawVal (which is a pointer) + if fn, ok := args.Get(0).(func(interface{})); ok && fn != nil { + fn(rawVal) + } + + // Return the error as the first return value + return args.Error(0) +} + +func TestQuery_Success(t *testing.T) { + // Setup mocks + mockAI := new(MockAI) + mockFactory := new(MockAIClientFactory) + mockConfig := new(MockConfigProvider) + + // Set test implementations + ai.SetTestAIClientFactory(mockFactory) + ai.SetTestConfigProvider(mockConfig) + defer ai.ResetTestImplementations() + + // Define test data + testBackend := "test-backend" + testQuery := "test query" + testResponse := "test response" + + // Setup expectations + mockFactory.On("NewClient", testBackend).Return(mockAI) + mockAI.On("Close").Return() + + // Set up configuration with a valid provider + mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) { + config := args.Get(1).(*ai.AIConfiguration) + *config = ai.AIConfiguration{ + Providers: []ai.AIProvider{ + { + Name: testBackend, + Password: "test-password", + Model: "test-model", + }, + }, + } + }).Return(nil) + + mockAI.On("Configure", mock.AnythingOfType("*ai.AIProvider")).Return(nil) + mockAI.On("GetCompletion", mock.Anything, testQuery).Return(testResponse, nil) + + // Create handler and call Query + handler := &Handler{} + response, err := handler.Query(context.Background(), &schemav1.QueryRequest{ + Backend: testBackend, + Query: testQuery, + }) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, response) + assert.Equal(t, testResponse, response.Response) + assert.Equal(t, "", response.Error.Message) + + // Verify mocks + mockAI.AssertExpectations(t) + mockFactory.AssertExpectations(t) + mockConfig.AssertExpectations(t) +} + +func TestQuery_UnmarshalError(t *testing.T) { + // Setup mocks + mockAI := new(MockAI) + mockFactory := new(MockAIClientFactory) + mockConfig := new(MockConfigProvider) + + // Set test implementations + ai.SetTestAIClientFactory(mockFactory) + ai.SetTestConfigProvider(mockConfig) + defer ai.ResetTestImplementations() + + // Setup expectations + mockFactory.On("NewClient", "test-backend").Return(mockAI) + mockAI.On("Close").Return() + + // Mock unmarshal error + mockConfig.On("UnmarshalKey", "ai", mock.Anything).Return(errors.New("unmarshal error")) + + // Create handler and call Query + handler := &Handler{} + response, err := handler.Query(context.Background(), &schemav1.QueryRequest{ + Backend: "test-backend", + Query: "test query", + }) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, response) + assert.Equal(t, "", response.Response) + assert.Contains(t, response.Error.Message, "Failed to unmarshal AI configuration") + + // Verify mocks + mockAI.AssertExpectations(t) + mockFactory.AssertExpectations(t) + mockConfig.AssertExpectations(t) +} + +func TestQuery_ProviderNotFound(t *testing.T) { + // Setup mocks + mockAI := new(MockAI) + mockFactory := new(MockAIClientFactory) + mockConfig := new(MockConfigProvider) + + // Set test implementations + ai.SetTestAIClientFactory(mockFactory) + ai.SetTestConfigProvider(mockConfig) + defer ai.ResetTestImplementations() + + // Define test data + testBackend := "test-backend" + + // Setup expectations + mockFactory.On("NewClient", testBackend).Return(mockAI) + mockAI.On("Close").Return() + + // Set up configuration with no matching provider + mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) { + config := args.Get(1).(*ai.AIConfiguration) + *config = ai.AIConfiguration{ + Providers: []ai.AIProvider{ + { + Name: "other-backend", + }, + }, + } + }).Return(nil) + + // Create handler and call Query + handler := &Handler{} + response, err := handler.Query(context.Background(), &schemav1.QueryRequest{ + Backend: testBackend, + Query: "test query", + }) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, response) + assert.Equal(t, "", response.Response) + assert.Contains(t, response.Error.Message, "AI provider test-backend not found in configuration") + + // Verify mocks + mockAI.AssertExpectations(t) + mockFactory.AssertExpectations(t) + mockConfig.AssertExpectations(t) +} + +func TestQuery_ConfigureError(t *testing.T) { + // Setup mocks + mockAI := new(MockAI) + mockFactory := new(MockAIClientFactory) + mockConfig := new(MockConfigProvider) + + // Set test implementations + ai.SetTestAIClientFactory(mockFactory) + ai.SetTestConfigProvider(mockConfig) + defer ai.ResetTestImplementations() + + // Define test data + testBackend := "test-backend" + + // Setup expectations + mockFactory.On("NewClient", testBackend).Return(mockAI) + mockAI.On("Close").Return() + + // Set up configuration with a valid provider + mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) { + config := args.Get(1).(*ai.AIConfiguration) + *config = ai.AIConfiguration{ + Providers: []ai.AIProvider{ + { + Name: testBackend, + }, + }, + } + }).Return(nil) + + // Mock configure error + mockAI.On("Configure", mock.AnythingOfType("*ai.AIProvider")).Return(errors.New("configure error")) + + // Create handler and call Query + handler := &Handler{} + response, err := handler.Query(context.Background(), &schemav1.QueryRequest{ + Backend: testBackend, + Query: "test query", + }) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, response) + assert.Equal(t, "", response.Response) + assert.Contains(t, response.Error.Message, "Failed to configure AI client") + + // Verify mocks + mockAI.AssertExpectations(t) + mockFactory.AssertExpectations(t) + mockConfig.AssertExpectations(t) +} + +func TestQuery_GetCompletionError(t *testing.T) { + // Setup mocks + mockAI := new(MockAI) + mockFactory := new(MockAIClientFactory) + mockConfig := new(MockConfigProvider) + + // Set test implementations + ai.SetTestAIClientFactory(mockFactory) + ai.SetTestConfigProvider(mockConfig) + defer ai.ResetTestImplementations() + + // Define test data + testBackend := "test-backend" + testQuery := "test query" + + // Setup expectations + mockFactory.On("NewClient", testBackend).Return(mockAI) + mockAI.On("Close").Return() + + // Set up configuration with a valid provider + mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) { + config := args.Get(1).(*ai.AIConfiguration) + *config = ai.AIConfiguration{ + Providers: []ai.AIProvider{ + { + Name: testBackend, + }, + }, + } + }).Return(nil) + + mockAI.On("Configure", mock.AnythingOfType("*ai.AIProvider")).Return(nil) + mockAI.On("GetCompletion", mock.Anything, testQuery).Return("", errors.New("completion error")) + + // Create handler and call Query + handler := &Handler{} + response, err := handler.Query(context.Background(), &schemav1.QueryRequest{ + Backend: testBackend, + Query: testQuery, + }) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, response) + assert.Equal(t, "", response.Response) + assert.Equal(t, "completion error", response.Error.Message) + + // Verify mocks + mockAI.AssertExpectations(t) + mockFactory.AssertExpectations(t) + mockConfig.AssertExpectations(t) +} \ No newline at end of file