mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-08-02 08:06:18 +00:00
fix: config ai provider in query (#1457)
Signed-off-by: Guoxun Wei <guwe@microsoft.com>
This commit is contained in:
parent
80904e3063
commit
df17e3e728
1
go.mod
1
go.mod
@ -122,6 +122,7 @@ require (
|
|||||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
||||||
github.com/sony/gobreaker v0.5.0 // indirect
|
github.com/sony/gobreaker v0.5.0 // indirect
|
||||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||||
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
github.com/x448/float16 v0.8.4 // indirect
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
go.opencensus.io v0.24.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
|
87
pkg/ai/factory.go
Normal file
87
pkg/ai/factory.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2023 The K8sGPT Authors.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AIClientFactory is an interface for creating AI clients
|
||||||
|
type AIClientFactory interface {
|
||||||
|
NewClient(provider string) IAI
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultAIClientFactory is the default implementation of AIClientFactory
|
||||||
|
type DefaultAIClientFactory struct{}
|
||||||
|
|
||||||
|
// NewClient creates a new AI client using the default implementation
|
||||||
|
func (f *DefaultAIClientFactory) NewClient(provider string) IAI {
|
||||||
|
return NewClient(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigProvider is an interface for accessing configuration
|
||||||
|
type ConfigProvider interface {
|
||||||
|
UnmarshalKey(key string, rawVal interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ViperConfigProvider is the default implementation of ConfigProvider using Viper
|
||||||
|
type ViperConfigProvider struct{}
|
||||||
|
|
||||||
|
// UnmarshalKey unmarshals a key from the configuration using Viper
|
||||||
|
func (p *ViperConfigProvider) UnmarshalKey(key string, rawVal interface{}) error {
|
||||||
|
return viper.UnmarshalKey(key, rawVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default instances to be used
|
||||||
|
var (
|
||||||
|
DefaultClientFactory = &DefaultAIClientFactory{}
|
||||||
|
DefaultConfigProvider = &ViperConfigProvider{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// For testing - these variables can be overridden in tests
|
||||||
|
var (
|
||||||
|
testAIClientFactory AIClientFactory = nil
|
||||||
|
testConfigProvider ConfigProvider = nil
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAIClientFactory returns the test factory if set, otherwise the default
|
||||||
|
func GetAIClientFactory() AIClientFactory {
|
||||||
|
if testAIClientFactory != nil {
|
||||||
|
return testAIClientFactory
|
||||||
|
}
|
||||||
|
return DefaultClientFactory
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfigProvider returns the test provider if set, otherwise the default
|
||||||
|
func GetConfigProvider() ConfigProvider {
|
||||||
|
if testConfigProvider != nil {
|
||||||
|
return testConfigProvider
|
||||||
|
}
|
||||||
|
return DefaultConfigProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
// For testing - set the test implementations
|
||||||
|
func SetTestAIClientFactory(factory AIClientFactory) {
|
||||||
|
testAIClientFactory = factory
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetTestConfigProvider(provider ConfigProvider) {
|
||||||
|
testConfigProvider = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset test implementations
|
||||||
|
func ResetTestImplementations() {
|
||||||
|
testAIClientFactory = nil
|
||||||
|
testConfigProvider = nil
|
||||||
|
}
|
@ -1,8 +1,10 @@
|
|||||||
package query
|
package query
|
||||||
|
|
||||||
import (
|
import (
|
||||||
schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1"
|
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1"
|
||||||
"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
|
"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -10,9 +12,50 @@ func (h *Handler) Query(ctx context.Context, i *schemav1.QueryRequest) (
|
|||||||
*schemav1.QueryResponse,
|
*schemav1.QueryResponse,
|
||||||
error,
|
error,
|
||||||
) {
|
) {
|
||||||
aiClient := ai.NewClient(i.Backend)
|
// Create client factory and config provider
|
||||||
|
factory := ai.GetAIClientFactory()
|
||||||
|
configProvider := ai.GetConfigProvider()
|
||||||
|
|
||||||
|
// Use the factory to create the client
|
||||||
|
aiClient := factory.NewClient(i.Backend)
|
||||||
defer aiClient.Close()
|
defer aiClient.Close()
|
||||||
|
|
||||||
|
var configAI ai.AIConfiguration
|
||||||
|
if err := configProvider.UnmarshalKey("ai", &configAI); err != nil {
|
||||||
|
return &schemav1.QueryResponse{
|
||||||
|
Response: "",
|
||||||
|
Error: &schemav1.QueryError{
|
||||||
|
Message: fmt.Sprintf("Failed to unmarshal AI configuration: %v", err),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var aiProvider ai.AIProvider
|
||||||
|
for _, provider := range configAI.Providers {
|
||||||
|
if i.Backend == provider.Name {
|
||||||
|
aiProvider = provider
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if aiProvider.Name == "" {
|
||||||
|
return &schemav1.QueryResponse{
|
||||||
|
Response: "",
|
||||||
|
Error: &schemav1.QueryError{
|
||||||
|
Message: fmt.Sprintf("AI provider %s not found in configuration", i.Backend),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure the AI client
|
||||||
|
if err := aiClient.Configure(&aiProvider); err != nil {
|
||||||
|
return &schemav1.QueryResponse{
|
||||||
|
Response: "",
|
||||||
|
Error: &schemav1.QueryError{
|
||||||
|
Message: fmt.Sprintf("Failed to configure AI client: %v", err),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := aiClient.GetCompletion(ctx, i.Query)
|
resp, err := aiClient.GetCompletion(ctx, i.Query)
|
||||||
var errMessage string = ""
|
var errMessage string = ""
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
310
pkg/server/query/query_test.go
Normal file
310
pkg/server/query/query_test.go
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1"
|
||||||
|
"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockAI is a mock implementation of the ai.IAI interface for testing
|
||||||
|
type MockAI struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockAI) Configure(config ai.IAIConfig) error {
|
||||||
|
args := m.Called(config)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockAI) GetCompletion(ctx context.Context, prompt string) (string, error) {
|
||||||
|
args := m.Called(ctx, prompt)
|
||||||
|
return args.String(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockAI) GetName() string {
|
||||||
|
args := m.Called()
|
||||||
|
return args.String(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockAI) Close() {
|
||||||
|
m.Called()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockAIClientFactory is a mock implementation of AIClientFactory
|
||||||
|
type MockAIClientFactory struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockAIClientFactory) NewClient(provider string) ai.IAI {
|
||||||
|
args := m.Called(provider)
|
||||||
|
return args.Get(0).(ai.IAI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockConfigProvider is a mock implementation of ConfigProvider
|
||||||
|
type MockConfigProvider struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockConfigProvider) UnmarshalKey(key string, rawVal interface{}) error {
|
||||||
|
args := m.Called(key, rawVal)
|
||||||
|
|
||||||
|
// If we want to set the rawVal (which is a pointer)
|
||||||
|
if fn, ok := args.Get(0).(func(interface{})); ok && fn != nil {
|
||||||
|
fn(rawVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the error as the first return value
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery_Success(t *testing.T) {
|
||||||
|
// Setup mocks
|
||||||
|
mockAI := new(MockAI)
|
||||||
|
mockFactory := new(MockAIClientFactory)
|
||||||
|
mockConfig := new(MockConfigProvider)
|
||||||
|
|
||||||
|
// Set test implementations
|
||||||
|
ai.SetTestAIClientFactory(mockFactory)
|
||||||
|
ai.SetTestConfigProvider(mockConfig)
|
||||||
|
defer ai.ResetTestImplementations()
|
||||||
|
|
||||||
|
// Define test data
|
||||||
|
testBackend := "test-backend"
|
||||||
|
testQuery := "test query"
|
||||||
|
testResponse := "test response"
|
||||||
|
|
||||||
|
// Setup expectations
|
||||||
|
mockFactory.On("NewClient", testBackend).Return(mockAI)
|
||||||
|
mockAI.On("Close").Return()
|
||||||
|
|
||||||
|
// Set up configuration with a valid provider
|
||||||
|
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
config := args.Get(1).(*ai.AIConfiguration)
|
||||||
|
*config = ai.AIConfiguration{
|
||||||
|
Providers: []ai.AIProvider{
|
||||||
|
{
|
||||||
|
Name: testBackend,
|
||||||
|
Password: "test-password",
|
||||||
|
Model: "test-model",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}).Return(nil)
|
||||||
|
|
||||||
|
mockAI.On("Configure", mock.AnythingOfType("*ai.AIProvider")).Return(nil)
|
||||||
|
mockAI.On("GetCompletion", mock.Anything, testQuery).Return(testResponse, nil)
|
||||||
|
|
||||||
|
// Create handler and call Query
|
||||||
|
handler := &Handler{}
|
||||||
|
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
|
||||||
|
Backend: testBackend,
|
||||||
|
Query: testQuery,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, testResponse, response.Response)
|
||||||
|
assert.Equal(t, "", response.Error.Message)
|
||||||
|
|
||||||
|
// Verify mocks
|
||||||
|
mockAI.AssertExpectations(t)
|
||||||
|
mockFactory.AssertExpectations(t)
|
||||||
|
mockConfig.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery_UnmarshalError(t *testing.T) {
|
||||||
|
// Setup mocks
|
||||||
|
mockAI := new(MockAI)
|
||||||
|
mockFactory := new(MockAIClientFactory)
|
||||||
|
mockConfig := new(MockConfigProvider)
|
||||||
|
|
||||||
|
// Set test implementations
|
||||||
|
ai.SetTestAIClientFactory(mockFactory)
|
||||||
|
ai.SetTestConfigProvider(mockConfig)
|
||||||
|
defer ai.ResetTestImplementations()
|
||||||
|
|
||||||
|
// Setup expectations
|
||||||
|
mockFactory.On("NewClient", "test-backend").Return(mockAI)
|
||||||
|
mockAI.On("Close").Return()
|
||||||
|
|
||||||
|
// Mock unmarshal error
|
||||||
|
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Return(errors.New("unmarshal error"))
|
||||||
|
|
||||||
|
// Create handler and call Query
|
||||||
|
handler := &Handler{}
|
||||||
|
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
|
||||||
|
Backend: "test-backend",
|
||||||
|
Query: "test query",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "", response.Response)
|
||||||
|
assert.Contains(t, response.Error.Message, "Failed to unmarshal AI configuration")
|
||||||
|
|
||||||
|
// Verify mocks
|
||||||
|
mockAI.AssertExpectations(t)
|
||||||
|
mockFactory.AssertExpectations(t)
|
||||||
|
mockConfig.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery_ProviderNotFound(t *testing.T) {
|
||||||
|
// Setup mocks
|
||||||
|
mockAI := new(MockAI)
|
||||||
|
mockFactory := new(MockAIClientFactory)
|
||||||
|
mockConfig := new(MockConfigProvider)
|
||||||
|
|
||||||
|
// Set test implementations
|
||||||
|
ai.SetTestAIClientFactory(mockFactory)
|
||||||
|
ai.SetTestConfigProvider(mockConfig)
|
||||||
|
defer ai.ResetTestImplementations()
|
||||||
|
|
||||||
|
// Define test data
|
||||||
|
testBackend := "test-backend"
|
||||||
|
|
||||||
|
// Setup expectations
|
||||||
|
mockFactory.On("NewClient", testBackend).Return(mockAI)
|
||||||
|
mockAI.On("Close").Return()
|
||||||
|
|
||||||
|
// Set up configuration with no matching provider
|
||||||
|
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
config := args.Get(1).(*ai.AIConfiguration)
|
||||||
|
*config = ai.AIConfiguration{
|
||||||
|
Providers: []ai.AIProvider{
|
||||||
|
{
|
||||||
|
Name: "other-backend",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}).Return(nil)
|
||||||
|
|
||||||
|
// Create handler and call Query
|
||||||
|
handler := &Handler{}
|
||||||
|
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
|
||||||
|
Backend: testBackend,
|
||||||
|
Query: "test query",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "", response.Response)
|
||||||
|
assert.Contains(t, response.Error.Message, "AI provider test-backend not found in configuration")
|
||||||
|
|
||||||
|
// Verify mocks
|
||||||
|
mockAI.AssertExpectations(t)
|
||||||
|
mockFactory.AssertExpectations(t)
|
||||||
|
mockConfig.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery_ConfigureError(t *testing.T) {
|
||||||
|
// Setup mocks
|
||||||
|
mockAI := new(MockAI)
|
||||||
|
mockFactory := new(MockAIClientFactory)
|
||||||
|
mockConfig := new(MockConfigProvider)
|
||||||
|
|
||||||
|
// Set test implementations
|
||||||
|
ai.SetTestAIClientFactory(mockFactory)
|
||||||
|
ai.SetTestConfigProvider(mockConfig)
|
||||||
|
defer ai.ResetTestImplementations()
|
||||||
|
|
||||||
|
// Define test data
|
||||||
|
testBackend := "test-backend"
|
||||||
|
|
||||||
|
// Setup expectations
|
||||||
|
mockFactory.On("NewClient", testBackend).Return(mockAI)
|
||||||
|
mockAI.On("Close").Return()
|
||||||
|
|
||||||
|
// Set up configuration with a valid provider
|
||||||
|
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
config := args.Get(1).(*ai.AIConfiguration)
|
||||||
|
*config = ai.AIConfiguration{
|
||||||
|
Providers: []ai.AIProvider{
|
||||||
|
{
|
||||||
|
Name: testBackend,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}).Return(nil)
|
||||||
|
|
||||||
|
// Mock configure error
|
||||||
|
mockAI.On("Configure", mock.AnythingOfType("*ai.AIProvider")).Return(errors.New("configure error"))
|
||||||
|
|
||||||
|
// Create handler and call Query
|
||||||
|
handler := &Handler{}
|
||||||
|
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
|
||||||
|
Backend: testBackend,
|
||||||
|
Query: "test query",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "", response.Response)
|
||||||
|
assert.Contains(t, response.Error.Message, "Failed to configure AI client")
|
||||||
|
|
||||||
|
// Verify mocks
|
||||||
|
mockAI.AssertExpectations(t)
|
||||||
|
mockFactory.AssertExpectations(t)
|
||||||
|
mockConfig.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery_GetCompletionError(t *testing.T) {
|
||||||
|
// Setup mocks
|
||||||
|
mockAI := new(MockAI)
|
||||||
|
mockFactory := new(MockAIClientFactory)
|
||||||
|
mockConfig := new(MockConfigProvider)
|
||||||
|
|
||||||
|
// Set test implementations
|
||||||
|
ai.SetTestAIClientFactory(mockFactory)
|
||||||
|
ai.SetTestConfigProvider(mockConfig)
|
||||||
|
defer ai.ResetTestImplementations()
|
||||||
|
|
||||||
|
// Define test data
|
||||||
|
testBackend := "test-backend"
|
||||||
|
testQuery := "test query"
|
||||||
|
|
||||||
|
// Setup expectations
|
||||||
|
mockFactory.On("NewClient", testBackend).Return(mockAI)
|
||||||
|
mockAI.On("Close").Return()
|
||||||
|
|
||||||
|
// Set up configuration with a valid provider
|
||||||
|
mockConfig.On("UnmarshalKey", "ai", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
config := args.Get(1).(*ai.AIConfiguration)
|
||||||
|
*config = ai.AIConfiguration{
|
||||||
|
Providers: []ai.AIProvider{
|
||||||
|
{
|
||||||
|
Name: testBackend,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}).Return(nil)
|
||||||
|
|
||||||
|
mockAI.On("Configure", mock.AnythingOfType("*ai.AIProvider")).Return(nil)
|
||||||
|
mockAI.On("GetCompletion", mock.Anything, testQuery).Return("", errors.New("completion error"))
|
||||||
|
|
||||||
|
// Create handler and call Query
|
||||||
|
handler := &Handler{}
|
||||||
|
response, err := handler.Query(context.Background(), &schemav1.QueryRequest{
|
||||||
|
Backend: testBackend,
|
||||||
|
Query: testQuery,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "", response.Response)
|
||||||
|
assert.Equal(t, "completion error", response.Error.Message)
|
||||||
|
|
||||||
|
// Verify mocks
|
||||||
|
mockAI.AssertExpectations(t)
|
||||||
|
mockFactory.AssertExpectations(t)
|
||||||
|
mockConfig.AssertExpectations(t)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user