fix: config ai provider in query (#1457)

Signed-off-by: Guoxun Wei <guwe@microsoft.com>
This commit is contained in:
gossion 2025-04-15 18:11:40 +08:00 committed by GitHub
parent 80904e3063
commit df17e3e728
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 443 additions and 2 deletions

1
go.mod
View File

@ -122,6 +122,7 @@ require (
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
github.com/sony/gobreaker v0.5.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.opencensus.io v0.24.0 // indirect

87
pkg/ai/factory.go Normal file
View File

@ -0,0 +1,87 @@
/*
Copyright 2023 The K8sGPT Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ai
import (
"github.com/spf13/viper"
)
// AIClientFactory is an interface for creating AI clients
type AIClientFactory interface {
NewClient(provider string) IAI
}
// DefaultAIClientFactory is the default implementation of AIClientFactory
type DefaultAIClientFactory struct{}
// NewClient creates a new AI client using the default implementation
func (f *DefaultAIClientFactory) NewClient(provider string) IAI {
return NewClient(provider)
}
// ConfigProvider is an interface for accessing configuration
type ConfigProvider interface {
UnmarshalKey(key string, rawVal interface{}) error
}
// ViperConfigProvider is the default implementation of ConfigProvider using Viper
type ViperConfigProvider struct{}
// UnmarshalKey unmarshals a key from the configuration using Viper
func (p *ViperConfigProvider) UnmarshalKey(key string, rawVal interface{}) error {
return viper.UnmarshalKey(key, rawVal)
}
// Default instances to be used
var (
DefaultClientFactory = &DefaultAIClientFactory{}
DefaultConfigProvider = &ViperConfigProvider{}
)
// For testing - these variables can be overridden in tests
var (
testAIClientFactory AIClientFactory = nil
testConfigProvider ConfigProvider = nil
)
// GetAIClientFactory returns the test factory if set, otherwise the default
func GetAIClientFactory() AIClientFactory {
if testAIClientFactory != nil {
return testAIClientFactory
}
return DefaultClientFactory
}
// GetConfigProvider returns the test provider if set, otherwise the default
func GetConfigProvider() ConfigProvider {
if testConfigProvider != nil {
return testConfigProvider
}
return DefaultConfigProvider
}
// For testing - set the test implementations
func SetTestAIClientFactory(factory AIClientFactory) {
testAIClientFactory = factory
}
func SetTestConfigProvider(provider ConfigProvider) {
testConfigProvider = provider
}
// Reset test implementations
func ResetTestImplementations() {
testAIClientFactory = nil
testConfigProvider = nil
}

View File

@ -1,8 +1,10 @@
package query
import (
schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1"
"context"
"fmt"
schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1"
"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
)
@ -10,9 +12,50 @@ func (h *Handler) Query(ctx context.Context, i *schemav1.QueryRequest) (
*schemav1.QueryResponse,
error,
) {
aiClient := ai.NewClient(i.Backend)
// Create client factory and config provider
factory := ai.GetAIClientFactory()
configProvider := ai.GetConfigProvider()
// Use the factory to create the client
aiClient := factory.NewClient(i.Backend)
defer aiClient.Close()
var configAI ai.AIConfiguration
if err := configProvider.UnmarshalKey("ai", &configAI); err != nil {
return &schemav1.QueryResponse{
Response: "",
Error: &schemav1.QueryError{
Message: fmt.Sprintf("Failed to unmarshal AI configuration: %v", err),
},
}, nil
}
var aiProvider ai.AIProvider
for _, provider := range configAI.Providers {
if i.Backend == provider.Name {
aiProvider = provider
break
}
}
if aiProvider.Name == "" {
return &schemav1.QueryResponse{
Response: "",
Error: &schemav1.QueryError{
Message: fmt.Sprintf("AI provider %s not found in configuration", i.Backend),
},
}, nil
}
// Configure the AI client
if err := aiClient.Configure(&aiProvider); err != nil {
return &schemav1.QueryResponse{
Response: "",
Error: &schemav1.QueryError{
Message: fmt.Sprintf("Failed to configure AI client: %v", err),
},
}, nil
}
resp, err := aiClient.GetCompletion(ctx, i.Query)
var errMessage string = ""
if err != nil {

View File

@ -0,0 +1,310 @@
package query
import (
"context"
"errors"
"testing"
schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1"
"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// MockAI is a mock implementation of the ai.IAI interface for testing
type MockAI struct {
mock.Mock
}
func (m *MockAI) Configure(config ai.IAIConfig) error {
args := m.Called(config)
return args.Error(0)
}
func (m *MockAI) GetCompletion(ctx context.Context, prompt string) (string, error) {
args := m.Called(ctx, prompt)
return args.String(0), args.Error(1)
}
func (m *MockAI) GetName() string {
args := m.Called()
return args.String(0)
}
func (m *MockAI) Close() {
m.Called()
}
// MockAIClientFactory is a mock implementation of AIClientFactory
type MockAIClientFactory struct {
mock.Mock
}
func (m *MockAIClientFactory) NewClient(provider string) ai.IAI {
args := m.Called(provider)
return args.Get(0).(ai.IAI)
}
// MockConfigProvider is a mock implementation of ConfigProvider
type MockConfigProvider struct {
mock.Mock
}
func (m *MockConfigProvider) UnmarshalKey(key string, rawVal interface{}) error {
args := m.Called(key, rawVal)
// If we want to set the rawVal (which is a pointer)
if fn, ok := args.Get(0).(func(interface{})); ok && fn != nil {
fn(rawVal)
}
// Return the error as the first return value
return args.Error(0)
}
func TestQuery_Success(t *testing.T) {
// Setup mocks
mockAI := new(MockAI)
mockFactory := new(MockAIClientFactory)
mockConfig := new(MockConfigProvider)
// Set test implementations
ai.SetTestAIClientFactory(mockFactory)
ai.SetTestConfigProvider(mockConfig)
defer ai.ResetTestImplementations()
// Define test data
testBackend := "test-backend"
testQuery := "test query"
testResponse := "test response"
// Setup expectations
mockFactory.On("NewClient", testBackend).Return(mockAI)
mockAI.On("Close").Return()
// Set up configuration with a valid provider
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) {
config := args.Get(1).(*ai.AIConfiguration)
*config = ai.AIConfiguration{
Providers: []ai.AIProvider{
{
Name: testBackend,
Password: "test-password",
Model: "test-model",
},
},
}
}).Return(nil)
mockAI.On("Configure", mock.AnythingOfType("*ai.AIProvider")).Return(nil)
mockAI.On("GetCompletion", mock.Anything, testQuery).Return(testResponse, nil)
// Create handler and call Query
handler := &Handler{}
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
Backend: testBackend,
Query: testQuery,
})
// Assertions
assert.NoError(t, err)
assert.NotNil(t, response)
assert.Equal(t, testResponse, response.Response)
assert.Equal(t, "", response.Error.Message)
// Verify mocks
mockAI.AssertExpectations(t)
mockFactory.AssertExpectations(t)
mockConfig.AssertExpectations(t)
}
func TestQuery_UnmarshalError(t *testing.T) {
// Setup mocks
mockAI := new(MockAI)
mockFactory := new(MockAIClientFactory)
mockConfig := new(MockConfigProvider)
// Set test implementations
ai.SetTestAIClientFactory(mockFactory)
ai.SetTestConfigProvider(mockConfig)
defer ai.ResetTestImplementations()
// Setup expectations
mockFactory.On("NewClient", "test-backend").Return(mockAI)
mockAI.On("Close").Return()
// Mock unmarshal error
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Return(errors.New("unmarshal error"))
// Create handler and call Query
handler := &Handler{}
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
Backend: "test-backend",
Query: "test query",
})
// Assertions
assert.NoError(t, err)
assert.NotNil(t, response)
assert.Equal(t, "", response.Response)
assert.Contains(t, response.Error.Message, "Failed to unmarshal AI configuration")
// Verify mocks
mockAI.AssertExpectations(t)
mockFactory.AssertExpectations(t)
mockConfig.AssertExpectations(t)
}
func TestQuery_ProviderNotFound(t *testing.T) {
// Setup mocks
mockAI := new(MockAI)
mockFactory := new(MockAIClientFactory)
mockConfig := new(MockConfigProvider)
// Set test implementations
ai.SetTestAIClientFactory(mockFactory)
ai.SetTestConfigProvider(mockConfig)
defer ai.ResetTestImplementations()
// Define test data
testBackend := "test-backend"
// Setup expectations
mockFactory.On("NewClient", testBackend).Return(mockAI)
mockAI.On("Close").Return()
// Set up configuration with no matching provider
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) {
config := args.Get(1).(*ai.AIConfiguration)
*config = ai.AIConfiguration{
Providers: []ai.AIProvider{
{
Name: "other-backend",
},
},
}
}).Return(nil)
// Create handler and call Query
handler := &Handler{}
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
Backend: testBackend,
Query: "test query",
})
// Assertions
assert.NoError(t, err)
assert.NotNil(t, response)
assert.Equal(t, "", response.Response)
assert.Contains(t, response.Error.Message, "AI provider test-backend not found in configuration")
// Verify mocks
mockAI.AssertExpectations(t)
mockFactory.AssertExpectations(t)
mockConfig.AssertExpectations(t)
}
func TestQuery_ConfigureError(t *testing.T) {
// Setup mocks
mockAI := new(MockAI)
mockFactory := new(MockAIClientFactory)
mockConfig := new(MockConfigProvider)
// Set test implementations
ai.SetTestAIClientFactory(mockFactory)
ai.SetTestConfigProvider(mockConfig)
defer ai.ResetTestImplementations()
// Define test data
testBackend := "test-backend"
// Setup expectations
mockFactory.On("NewClient", testBackend).Return(mockAI)
mockAI.On("Close").Return()
// Set up configuration with a valid provider
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) {
config := args.Get(1).(*ai.AIConfiguration)
*config = ai.AIConfiguration{
Providers: []ai.AIProvider{
{
Name: testBackend,
},
},
}
}).Return(nil)
// Mock configure error
mockAI.On("Configure", mock.AnythingOfType("*ai.AIProvider")).Return(errors.New("configure error"))
// Create handler and call Query
handler := &Handler{}
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
Backend: testBackend,
Query: "test query",
})
// Assertions
assert.NoError(t, err)
assert.NotNil(t, response)
assert.Equal(t, "", response.Response)
assert.Contains(t, response.Error.Message, "Failed to configure AI client")
// Verify mocks
mockAI.AssertExpectations(t)
mockFactory.AssertExpectations(t)
mockConfig.AssertExpectations(t)
}
func TestQuery_GetCompletionError(t *testing.T) {
// Setup mocks
mockAI := new(MockAI)
mockFactory := new(MockAIClientFactory)
mockConfig := new(MockConfigProvider)
// Set test implementations
ai.SetTestAIClientFactory(mockFactory)
ai.SetTestConfigProvider(mockConfig)
defer ai.ResetTestImplementations()
// Define test data
testBackend := "test-backend"
testQuery := "test query"
// Setup expectations
mockFactory.On("NewClient", testBackend).Return(mockAI)
mockAI.On("Close").Return()
// Set up configuration with a valid provider
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) {
config := args.Get(1).(*ai.AIConfiguration)
*config = ai.AIConfiguration{
Providers: []ai.AIProvider{
{
Name: testBackend,
},
},
}
}).Return(nil)
mockAI.On("Configure", mock.AnythingOfType("*ai.AIProvider")).Return(nil)
mockAI.On("GetCompletion", mock.Anything, testQuery).Return("", errors.New("completion error"))
// Create handler and call Query
handler := &Handler{}
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
Backend: testBackend,
Query: testQuery,
})
// Assertions
assert.NoError(t, err)
assert.NotNil(t, response)
assert.Equal(t, "", response.Response)
assert.Equal(t, "completion error", response.Error.Message)
// Verify mocks
mockAI.AssertExpectations(t)
mockFactory.AssertExpectations(t)
mockConfig.AssertExpectations(t)
}