mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-08-09 11:47:23 +00:00
feat: add custom http headers to openai related api backends (#1174)
* feat: add custom http headers to openai related api backends Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> * ci: add custom headers test Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> * add error handling Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> * chore(deps): update docker/setup-buildx-action digest to 4fd8129 (#1173) Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> * fix(deps): update module buf.build/gen/go/k8sgpt-ai/k8sgpt/grpc-ecosystem/gateway/v2 to v2.20.0-20240406062209-1cc152efbf5c.1 (#1147) Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> * chore(deps): update anchore/sbom-action action to v0.16.0 (#1146) Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com> Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> * Update README.md Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> --------- Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
parent
fef853966f
commit
02e754ed59
@ -299,6 +299,12 @@ _Analysis with serve mode_
|
|||||||
```
|
```
|
||||||
grpcurl -plaintext -d '{"namespace": "k8sgpt", "explain": false}' localhost:8080 schema.v1.ServerService/Analyze
|
grpcurl -plaintext -d '{"namespace": "k8sgpt", "explain": false}' localhost:8080 schema.v1.ServerService/Analyze
|
||||||
```
|
```
|
||||||
|
|
||||||
|
_Analysis with custom headers_
|
||||||
|
|
||||||
|
```
|
||||||
|
k8sgpt analyze --explain --custom-headers CustomHeaderKey:CustomHeaderValue
|
||||||
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## LLM AI Backends
|
## LLM AI Backends
|
||||||
|
@ -38,6 +38,7 @@ var (
|
|||||||
withDoc bool
|
withDoc bool
|
||||||
interactiveMode bool
|
interactiveMode bool
|
||||||
customAnalysis bool
|
customAnalysis bool
|
||||||
|
customHeaders []string
|
||||||
)
|
)
|
||||||
|
|
||||||
// AnalyzeCmd represents the problems command
|
// AnalyzeCmd represents the problems command
|
||||||
@ -59,6 +60,7 @@ var AnalyzeCmd = &cobra.Command{
|
|||||||
maxConcurrency,
|
maxConcurrency,
|
||||||
withDoc,
|
withDoc,
|
||||||
interactiveMode,
|
interactiveMode,
|
||||||
|
customHeaders,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -138,5 +140,6 @@ func init() {
|
|||||||
AnalyzeCmd.Flags().BoolVarP(&interactiveMode, "interactive", "i", false, "Enable interactive mode that allows further conversation with LLM about the problem. Works only with --explain flag")
|
AnalyzeCmd.Flags().BoolVarP(&interactiveMode, "interactive", "i", false, "Enable interactive mode that allows further conversation with LLM about the problem. Works only with --explain flag")
|
||||||
// custom analysis flag
|
// custom analysis flag
|
||||||
AnalyzeCmd.Flags().BoolVarP(&customAnalysis, "custom-analysis", "z", false, "Enable custom analyzers")
|
AnalyzeCmd.Flags().BoolVarP(&customAnalysis, "custom-analysis", "z", false, "Enable custom analyzers")
|
||||||
|
// add custom headers flag
|
||||||
|
AnalyzeCmd.Flags().StringSliceVarP(&customHeaders, "custom-headers", "r", []string{}, "Custom Headers, <key>:<value> (e.g CustomHeaderKey:CustomHeaderValue AnotherHeader:AnotherValue)")
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@ package ai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -83,6 +84,7 @@ type IAIConfig interface {
|
|||||||
GetProviderId() string
|
GetProviderId() string
|
||||||
GetCompartmentId() string
|
GetCompartmentId() string
|
||||||
GetOrganizationId() string
|
GetOrganizationId() string
|
||||||
|
GetCustomHeaders() []http.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(provider string) IAI {
|
func NewClient(provider string) IAI {
|
||||||
@ -117,6 +119,7 @@ type AIProvider struct {
|
|||||||
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
|
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
|
||||||
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
|
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
|
||||||
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
|
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
|
||||||
|
CustomHeaders []http.Header `mapstructure:"customHeaders"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *AIProvider) GetBaseURL() string {
|
func (p *AIProvider) GetBaseURL() string {
|
||||||
@ -174,6 +177,10 @@ func (p *AIProvider) GetOrganizationId() string {
|
|||||||
return p.OrganizationId
|
return p.OrganizationId
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *AIProvider) GetCustomHeaders() []http.Header {
|
||||||
|
return p.CustomHeaders
|
||||||
|
}
|
||||||
|
|
||||||
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}
|
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}
|
||||||
|
|
||||||
func NeedPassword(backend string) bool {
|
func NeedPassword(backend string) bool {
|
||||||
|
@ -52,24 +52,27 @@ func (c *OpenAIClient) Configure(config IAIConfig) error {
|
|||||||
defaultConfig.BaseURL = baseURL
|
defaultConfig.BaseURL = baseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{}
|
||||||
if proxyEndpoint != "" {
|
if proxyEndpoint != "" {
|
||||||
proxyUrl, err := url.Parse(proxyEndpoint)
|
proxyUrl, err := url.Parse(proxyEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
transport := &http.Transport{
|
transport.Proxy = http.ProxyURL(proxyUrl)
|
||||||
Proxy: http.ProxyURL(proxyUrl),
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultConfig.HTTPClient = &http.Client{
|
|
||||||
Transport: transport,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if orgId != "" {
|
if orgId != "" {
|
||||||
defaultConfig.OrgID = orgId
|
defaultConfig.OrgID = orgId
|
||||||
}
|
}
|
||||||
|
|
||||||
|
customHeaders := config.GetCustomHeaders()
|
||||||
|
defaultConfig.HTTPClient = &http.Client{
|
||||||
|
Transport: &OpenAIHeaderTransport{
|
||||||
|
Origin: transport,
|
||||||
|
Headers: customHeaders,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
client := openai.NewClientWithConfig(defaultConfig)
|
client := openai.NewClientWithConfig(defaultConfig)
|
||||||
if client == nil {
|
if client == nil {
|
||||||
return errors.New("error creating OpenAI client")
|
return errors.New("error creating OpenAI client")
|
||||||
@ -106,3 +109,25 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string
|
|||||||
func (c *OpenAIClient) GetName() string {
|
func (c *OpenAIClient) GetName() string {
|
||||||
return openAIClientName
|
return openAIClientName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenAIHeaderTransport is an http.RoundTripper that adds the given headers to each request.
|
||||||
|
type OpenAIHeaderTransport struct {
|
||||||
|
Origin http.RoundTripper
|
||||||
|
Headers []http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip implements the http.RoundTripper interface.
|
||||||
|
func (t *OpenAIHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// Clone the request to avoid modifying the original request
|
||||||
|
clonedReq := req.Clone(req.Context())
|
||||||
|
for _, header := range t.Headers {
|
||||||
|
for key, values := range header {
|
||||||
|
// Possible values per header: RFC 2616
|
||||||
|
for _, value := range values {
|
||||||
|
clonedReq.Header.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.Origin.RoundTrip(clonedReq)
|
||||||
|
}
|
||||||
|
106
pkg/ai/openai_header_transport_test.go
Normal file
106
pkg/ai/openai_header_transport_test.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mock configuration
|
||||||
|
type mockConfig struct {
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetPassword() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetOrganizationId() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetProxyEndpoint() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetBaseURL() string {
|
||||||
|
return m.baseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetCustomHeaders() []http.Header {
|
||||||
|
return []http.Header{
|
||||||
|
{"X-Custom-Header-1": []string{"Value1"}},
|
||||||
|
{"X-Custom-Header-2": []string{"Value2"}},
|
||||||
|
{"X-Custom-Header-2": []string{"Value3"}}, // Testing multiple values for the same header
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetModel() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetTemperature() float32 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetTopP() float32 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
func (m *mockConfig) GetCompartmentId() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetTopK() int32 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetMaxTokens() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetEndpointName() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
func (m *mockConfig) GetEngine() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetProviderId() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConfig) GetProviderRegion() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIClient_CustomHeaders(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "Value1", r.Header.Get("X-Custom-Header-1"))
|
||||||
|
assert.ElementsMatch(t, []string{"Value2", "Value3"}, r.Header["X-Custom-Header-2"])
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
// Mock response for openai completion
|
||||||
|
mockResponse := `{"choices": [{"message": {"content": "test"}}]}`
|
||||||
|
n, err := w.Write([]byte(mockResponse))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error writing response: %v", err)
|
||||||
|
}
|
||||||
|
if n != len(mockResponse) {
|
||||||
|
t.Fatalf("expected to write %d bytes but wrote %d bytes", len(mockResponse), n)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
config := &mockConfig{baseURL: server.URL}
|
||||||
|
|
||||||
|
client := &OpenAIClient{}
|
||||||
|
err := client.Configure(config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Make a completion request to trigger the headers
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err = client.GetCompletion(ctx, "foo prompt")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
@ -79,6 +79,7 @@ func NewAnalysis(
|
|||||||
maxConcurrency int,
|
maxConcurrency int,
|
||||||
withDoc bool,
|
withDoc bool,
|
||||||
interactiveMode bool,
|
interactiveMode bool,
|
||||||
|
httpHeaders []string,
|
||||||
) (*Analysis, error) {
|
) (*Analysis, error) {
|
||||||
// Get kubernetes client from viper.
|
// Get kubernetes client from viper.
|
||||||
kubecontext := viper.GetString("kubecontext")
|
kubecontext := viper.GetString("kubecontext")
|
||||||
@ -146,6 +147,8 @@ func NewAnalysis(
|
|||||||
}
|
}
|
||||||
|
|
||||||
aiClient := ai.NewClient(aiProvider.Name)
|
aiClient := ai.NewClient(aiProvider.Name)
|
||||||
|
customHeaders := util.NewHeaders(httpHeaders)
|
||||||
|
aiProvider.CustomHeaders = customHeaders
|
||||||
if err := aiClient.Configure(&aiProvider); err != nil {
|
if err := aiClient.Configure(&aiProvider); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -30,6 +30,7 @@ func (h *handler) Analyze(ctx context.Context, i *schemav1.AnalyzeRequest) (
|
|||||||
int(i.MaxConcurrency),
|
int(i.MaxConcurrency),
|
||||||
false, // Kubernetes Doc disabled in server mode
|
false, // Kubernetes Doc disabled in server mode
|
||||||
false, // Interactive mode disabled in server mode
|
false, // Interactive mode disabled in server mode
|
||||||
|
[]string{}, //TODO: add custom http headers in server mode
|
||||||
)
|
)
|
||||||
config.Context = ctx // Replace context for correct timeouts.
|
config.Context = ctx // Replace context for correct timeouts.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@ -261,3 +262,36 @@ func FetchLatestEvent(ctx context.Context, kubernetesClient *kubernetes.Client,
|
|||||||
}
|
}
|
||||||
return latestEvent, nil
|
return latestEvent, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewHeaders parses a slice of strings in the format "key:value" into []http.Header
|
||||||
|
// It handles headers with the same key by appending values
|
||||||
|
func NewHeaders(customHeaders []string) []http.Header {
|
||||||
|
headers := make(map[string][]string)
|
||||||
|
|
||||||
|
for _, header := range customHeaders {
|
||||||
|
vals := strings.SplitN(header, ":", 2)
|
||||||
|
if len(vals) != 2 {
|
||||||
|
//TODO: Handle error instead of ignoring it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := strings.TrimSpace(vals[0])
|
||||||
|
value := strings.TrimSpace(vals[1])
|
||||||
|
|
||||||
|
if _, ok := headers[key]; !ok {
|
||||||
|
headers[key] = []string{}
|
||||||
|
}
|
||||||
|
headers[key] = append(headers[key], value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert map to []http.Header format
|
||||||
|
var result []http.Header
|
||||||
|
for key, values := range headers {
|
||||||
|
header := make(http.Header)
|
||||||
|
for _, value := range values {
|
||||||
|
header.Add(key, value)
|
||||||
|
}
|
||||||
|
result = append(result, header)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user