feat: add proxysettings for azureopenai and openai (#987)

Signed-off-by: tanujd11 <dwiveditanuj41@gmail.com>
Co-authored-by: Aris Boutselis <arisboutselis08@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
Tanuj Dwivedi 2024-02-28 21:40:42 +05:30 committed by GitHub
parent aab8d77feb
commit 307710eddc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 43 additions and 0 deletions

View File

@ -73,6 +73,7 @@ var ServeCmd = &cobra.Command{
model := os.Getenv("K8SGPT_MODEL")
baseURL := os.Getenv("K8SGPT_BASEURL")
engine := os.Getenv("K8SGPT_ENGINE")
proxyEndpoint := os.Getenv("K8SGPT_PROXY_ENDPOINT")
// If the envs are set, allocate in place to the aiProvider
// else exit with error
envIsSet := backend != "" || password != "" || model != ""
@ -83,6 +84,7 @@ var ServeCmd = &cobra.Command{
Model: model,
BaseURL: baseURL,
Engine: engine,
ProxyEndpoint: proxyEndpoint,
Temperature: temperature(),
}

View File

@ -3,6 +3,8 @@ package ai
import (
"context"
"errors"
"net/http"
"net/url"
"github.com/sashabaranov/go-openai"
)
@ -21,6 +23,7 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
token := config.GetPassword()
baseURL := config.GetBaseURL()
engine := config.GetEngine()
proxyEndpoint := config.GetProxyEndpoint()
defaultConfig := openai.DefaultAzureConfig(token, baseURL)
defaultConfig.AzureModelMapperFunc = func(model string) string {
@ -31,6 +34,20 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
return azureModelMapping[model]
}
if proxyEndpoint != "" {
proxyUrl, err := url.Parse(proxyEndpoint)
if err != nil {
return err
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyUrl),
}
defaultConfig.HTTPClient = &http.Client{
Transport: transport,
}
}
client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating Azure OpenAI client")

View File

@ -64,6 +64,7 @@ type IAIConfig interface {
GetPassword() string
GetModel() string
GetBaseURL() string
GetProxyEndpoint() string
GetEndpointName() string
GetEngine() string
GetTemperature() float32
@ -92,6 +93,8 @@ type AIProvider struct {
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"`
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
@ -104,6 +107,10 @@ func (p *AIProvider) GetBaseURL() string {
return p.BaseURL
}
func (p *AIProvider) GetProxyEndpoint() string {
return p.ProxyEndpoint
}
func (p *AIProvider) GetEndpointName() string {
return p.EndpointName
}

View File

@ -16,6 +16,8 @@ package ai
import (
"context"
"errors"
"net/http"
"net/url"
"github.com/sashabaranov/go-openai"
)
@ -41,12 +43,27 @@ const (
func (c *OpenAIClient) Configure(config IAIConfig) error {
token := config.GetPassword()
defaultConfig := openai.DefaultConfig(token)
proxyEndpoint := config.GetProxyEndpoint()
baseURL := config.GetBaseURL()
if baseURL != "" {
defaultConfig.BaseURL = baseURL
}
if proxyEndpoint != "" {
proxyUrl, err := url.Parse(proxyEndpoint)
if err != nil {
return err
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyUrl),
}
defaultConfig.HTTPClient = &http.Client{
Transport: transport,
}
}
client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating OpenAI client")