fix: set topP from config (#1053)

* fix: set topP from config

Signed-off-by: “Guido <muscionig@gmail.com>

* style: correct format of openai ai provider

Signed-off-by: “Guido <muscionig@gmail.com>

* feat: set topP from the environment

Signed-off-by: “Guido <muscionig@gmail.com>

---------

Signed-off-by: “Guido <muscionig@gmail.com>
This commit is contained in:
Guido Muscioni
2024-04-19 10:38:52 -05:00
committed by GitHub
parent 1ae4e75196
commit c162cc22ee
2 changed files with 21 additions and 2 deletions

View File

@@ -27,6 +27,7 @@ import (
const ( const (
defaultTemperature float32 = 0.7 defaultTemperature float32 = 0.7
defaultTopP float32 = 1.0
) )
var ( var (
@@ -67,6 +68,22 @@ var ServeCmd = &cobra.Command{
} }
return float32(temperature) return float32(temperature)
} }
topP := func() float32 {
env := os.Getenv("K8SGPT_TOP_P")
if env == "" {
return defaultTopP
}
topP, err := strconv.ParseFloat(env, 32)
if err != nil {
color.Red("Unable to convert topP value: %v", err)
os.Exit(1)
}
if topP > 1.0 || topP < 0.0 {
color.Red("Error: topP ranges from 0 to 1.")
os.Exit(1)
}
return float32(topP)
}
// Check for env injection // Check for env injection
backend = os.Getenv("K8SGPT_BACKEND") backend = os.Getenv("K8SGPT_BACKEND")
password := os.Getenv("K8SGPT_PASSWORD") password := os.Getenv("K8SGPT_PASSWORD")
@@ -86,6 +103,7 @@ var ServeCmd = &cobra.Command{
Engine: engine, Engine: engine,
ProxyEndpoint: proxyEndpoint, ProxyEndpoint: proxyEndpoint,
Temperature: temperature(), Temperature: temperature(),
TopP: topP(),
} }
configAI.Providers = append(configAI.Providers, *aiProvider) configAI.Providers = append(configAI.Providers, *aiProvider)

View File

@@ -30,6 +30,7 @@ type OpenAIClient struct {
client *openai.Client client *openai.Client
model string model string
temperature float32 temperature float32
topP float32
} }
const ( const (
@@ -37,7 +38,6 @@ const (
maxToken = 2048 maxToken = 2048
presencePenalty = 0.0 presencePenalty = 0.0
frequencyPenalty = 0.0 frequencyPenalty = 0.0
topP = 1.0
) )
func (c *OpenAIClient) Configure(config IAIConfig) error { func (c *OpenAIClient) Configure(config IAIConfig) error {
@@ -71,6 +71,7 @@ func (c *OpenAIClient) Configure(config IAIConfig) error {
c.client = client c.client = client
c.model = config.GetModel() c.model = config.GetModel()
c.temperature = config.GetTemperature() c.temperature = config.GetTemperature()
c.topP = config.GetTopP()
return nil return nil
} }
@@ -88,7 +89,7 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string
MaxTokens: maxToken, MaxTokens: maxToken,
PresencePenalty: presencePenalty, PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty, FrequencyPenalty: frequencyPenalty,
TopP: topP, TopP: c.topP,
}) })
if err != nil { if err != nil {
return "", err return "", err