feat: openAI explicit value for maxToken and temperature (#659)

* feat: openAI explicit value for maxToken and temp

Because when k8sgpt talks with vLLM, the default MaxToken is 16,
which is so small.
Given the most model supports 2048 token(like Llama1 ..etc), so
put here for a safe value.

Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>

* feat: make temperature a flag

Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>

---------

Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
This commit is contained in:
Peter Pan
2023-09-18 20:14:43 +08:00
committed by GitHub
parent 54caff837d
commit f55946d60e
8 changed files with 66 additions and 26 deletions

View File

@@ -302,7 +302,7 @@ To start the API server, follow the instruction in [LocalAI](https://github.com/
To run k8sgpt, run `k8sgpt auth add` with the `localai` backend:
```
k8sgpt auth add --backend localai --model <model_name> --baseurl http://localhost:8080/v1
k8sgpt auth add --backend localai --model <model_name> --baseurl http://localhost:8080/v1 --temperature 0.7
```
Now you can analyze with the `localai` backend:

View File

@@ -75,6 +75,10 @@ var addCmd = &cobra.Command{
color.Red("Error: Model cannot be empty.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}
if ai.NeedPassword(backend) && password == "" {
fmt.Printf("Enter %s Key: ", backend)
@@ -94,6 +98,7 @@ var addCmd = &cobra.Command{
Password: password,
BaseURL: baseURL,
Engine: engine,
Temperature: temperature,
}
if providerIndex == -1 {
@@ -121,6 +126,8 @@ func init() {
addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
// add flag for url
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for temperature
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// add flag for azure open ai engine/deployment name
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
}

View File

@@ -24,6 +24,7 @@ var (
baseURL string
model string
engine string
temperature float32
)
var configAI ai.AIConfiguration

View File

@@ -49,6 +49,10 @@ var updateCmd = &cobra.Command{
color.Red("Error: backend must be set.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}
for _, b := range inputBackends {
foundBackend := false
@@ -74,6 +78,7 @@ var updateCmd = &cobra.Command{
if engine != "" {
configAI.Providers[i].Engine = engine
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", b)
}
}
@@ -101,6 +106,8 @@ func init() {
updateCmd.Flags().StringVarP(&password, "password", "p", "", "Update backend AI password")
// update flag for url
updateCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "Update URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for temperature
updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// update flag for azure open ai engine/deployment name
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
}

View File

@@ -19,6 +19,7 @@ type AzureAIClient struct {
client *openai.Client
language string
model string
temperature float32
}
func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
@@ -42,6 +43,7 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
c.language = lang
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}
@@ -55,6 +57,7 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, prompt
Content: fmt.Sprintf(default_prompt, c.language, prompt),
},
},
Temperature: c.temperature,
})
if err != nil {
return "", err

View File

@@ -31,6 +31,7 @@ type CohereClient struct {
client *cohere.Client
language string
model string
temperature float32
}
func (c *CohereClient) Configure(config IAIConfig, language string) error {
@@ -52,6 +53,7 @@ func (c *CohereClient) Configure(config IAIConfig, language string) error {
c.language = language
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}
@@ -64,7 +66,7 @@ func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl str
Model: c.model,
Prompt: fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
MaxTokens: cohere.Uint(2048),
Temperature: cohere.Float64(0.75),
Temperature: cohere.Float64(float64(c.temperature)),
K: cohere.Int(0),
StopSequences: []string{},
ReturnLikelihoods: "NONE",

View File

@@ -48,6 +48,7 @@ type IAIConfig interface {
GetModel() string
GetBaseURL() string
GetEngine() string
GetTemperature() float32
}
func NewClient(provider string) IAI {
@@ -71,6 +72,7 @@ type AIProvider struct {
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
}
func (p *AIProvider) GetBaseURL() string {
@@ -88,6 +90,9 @@ func (p *AIProvider) GetModel() string {
func (p *AIProvider) GetEngine() string {
return p.Engine
}
func (p *AIProvider) GetTemperature() float32 {
return p.Temperature
}
func NeedPassword(backend string) bool {
return backend != "localai"

View File

@@ -32,8 +32,17 @@ type OpenAIClient struct {
client *openai.Client
language string
model string
temperature float32
}
const (
// OpenAI completion parameters
maxToken = 2048
presencePenalty = 0.0
frequencyPenalty = 0.0
topP = 1.0
)
func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
token := config.GetPassword()
defaultConfig := openai.DefaultConfig(token)
@@ -50,6 +59,7 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
c.language = language
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}
@@ -66,6 +76,11 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptT
Content: fmt.Sprintf(promptTmpl, c.language, prompt),
},
},
Temperature: c.temperature,
MaxTokens: maxToken,
PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty,
TopP: topP,
})
if err != nil {
return "", err