mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-09-17 15:52:50 +00:00
feat: openAI explicit value for maxToken and temperature (#659)
* feat: openAI explicit value for maxToken and temp. When k8sgpt talks to vLLM, the default MaxToken is 16, which is far too small. Since most models support at least 2048 tokens (e.g. Llama1, etc.), 2048 is set here as a safe value. Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> * feat: make temperature a flag. Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> --------- Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
This commit is contained in:
@@ -302,7 +302,7 @@ To start the API server, follow the instruction in [LocalAI](https://github.com/
|
|||||||
To run k8sgpt, run `k8sgpt auth add` with the `localai` backend:
|
To run k8sgpt, run `k8sgpt auth add` with the `localai` backend:
|
||||||
|
|
||||||
```
|
```
|
||||||
k8sgpt auth add --backend localai --model <model_name> --baseurl http://localhost:8080/v1
|
k8sgpt auth add --backend localai --model <model_name> --baseurl http://localhost:8080/v1 --temperature 0.7
|
||||||
```
|
```
|
||||||
|
|
||||||
Now you can analyze with the `localai` backend:
|
Now you can analyze with the `localai` backend:
|
||||||
|
@@ -75,6 +75,10 @@ var addCmd = &cobra.Command{
|
|||||||
color.Red("Error: Model cannot be empty.")
|
color.Red("Error: Model cannot be empty.")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
if temperature > 1.0 || temperature < 0.0 {
|
||||||
|
color.Red("Error: temperature ranges from 0 to 1.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
if ai.NeedPassword(backend) && password == "" {
|
if ai.NeedPassword(backend) && password == "" {
|
||||||
fmt.Printf("Enter %s Key: ", backend)
|
fmt.Printf("Enter %s Key: ", backend)
|
||||||
@@ -89,11 +93,12 @@ var addCmd = &cobra.Command{
|
|||||||
|
|
||||||
// create new provider object
|
// create new provider object
|
||||||
newProvider := ai.AIProvider{
|
newProvider := ai.AIProvider{
|
||||||
Name: backend,
|
Name: backend,
|
||||||
Model: model,
|
Model: model,
|
||||||
Password: password,
|
Password: password,
|
||||||
BaseURL: baseURL,
|
BaseURL: baseURL,
|
||||||
Engine: engine,
|
Engine: engine,
|
||||||
|
Temperature: temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
if providerIndex == -1 {
|
if providerIndex == -1 {
|
||||||
@@ -121,6 +126,8 @@ func init() {
|
|||||||
addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
|
addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
|
||||||
// add flag for url
|
// add flag for url
|
||||||
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
|
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
|
||||||
|
// add flag for temperature
|
||||||
|
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
|
||||||
// add flag for azure open ai engine/deployment name
|
// add flag for azure open ai engine/deployment name
|
||||||
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
|
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
|
||||||
}
|
}
|
||||||
|
@@ -19,11 +19,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
backend string
|
backend string
|
||||||
password string
|
password string
|
||||||
baseURL string
|
baseURL string
|
||||||
model string
|
model string
|
||||||
engine string
|
engine string
|
||||||
|
temperature float32
|
||||||
)
|
)
|
||||||
|
|
||||||
var configAI ai.AIConfiguration
|
var configAI ai.AIConfiguration
|
||||||
|
@@ -49,6 +49,10 @@ var updateCmd = &cobra.Command{
|
|||||||
color.Red("Error: backend must be set.")
|
color.Red("Error: backend must be set.")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
if temperature > 1.0 || temperature < 0.0 {
|
||||||
|
color.Red("Error: temperature ranges from 0 to 1.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
for _, b := range inputBackends {
|
for _, b := range inputBackends {
|
||||||
foundBackend := false
|
foundBackend := false
|
||||||
@@ -74,6 +78,7 @@ var updateCmd = &cobra.Command{
|
|||||||
if engine != "" {
|
if engine != "" {
|
||||||
configAI.Providers[i].Engine = engine
|
configAI.Providers[i].Engine = engine
|
||||||
}
|
}
|
||||||
|
configAI.Providers[i].Temperature = temperature
|
||||||
color.Green("%s updated in the AI backend provider list", b)
|
color.Green("%s updated in the AI backend provider list", b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,6 +106,8 @@ func init() {
|
|||||||
updateCmd.Flags().StringVarP(&password, "password", "p", "", "Update backend AI password")
|
updateCmd.Flags().StringVarP(&password, "password", "p", "", "Update backend AI password")
|
||||||
// update flag for url
|
// update flag for url
|
||||||
updateCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "Update URL AI provider, (e.g `http://localhost:8080/v1`)")
|
updateCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "Update URL AI provider, (e.g `http://localhost:8080/v1`)")
|
||||||
|
// add flag for temperature
|
||||||
|
updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
|
||||||
// update flag for azure open ai engine/deployment name
|
// update flag for azure open ai engine/deployment name
|
||||||
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
|
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
|
||||||
}
|
}
|
||||||
|
@@ -16,9 +16,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AzureAIClient struct {
|
type AzureAIClient struct {
|
||||||
client *openai.Client
|
client *openai.Client
|
||||||
language string
|
language string
|
||||||
model string
|
model string
|
||||||
|
temperature float32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
|
func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
|
||||||
@@ -42,6 +43,7 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
|
|||||||
c.language = lang
|
c.language = lang
|
||||||
c.client = client
|
c.client = client
|
||||||
c.model = config.GetModel()
|
c.model = config.GetModel()
|
||||||
|
c.temperature = config.GetTemperature()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,6 +57,7 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, prompt
|
|||||||
Content: fmt.Sprintf(default_prompt, c.language, prompt),
|
Content: fmt.Sprintf(default_prompt, c.language, prompt),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Temperature: c.temperature,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@@ -28,9 +28,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type CohereClient struct {
|
type CohereClient struct {
|
||||||
client *cohere.Client
|
client *cohere.Client
|
||||||
language string
|
language string
|
||||||
model string
|
model string
|
||||||
|
temperature float32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CohereClient) Configure(config IAIConfig, language string) error {
|
func (c *CohereClient) Configure(config IAIConfig, language string) error {
|
||||||
@@ -52,6 +53,7 @@ func (c *CohereClient) Configure(config IAIConfig, language string) error {
|
|||||||
c.language = language
|
c.language = language
|
||||||
c.client = client
|
c.client = client
|
||||||
c.model = config.GetModel()
|
c.model = config.GetModel()
|
||||||
|
c.temperature = config.GetTemperature()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +66,7 @@ func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl str
|
|||||||
Model: c.model,
|
Model: c.model,
|
||||||
Prompt: fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
|
Prompt: fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
|
||||||
MaxTokens: cohere.Uint(2048),
|
MaxTokens: cohere.Uint(2048),
|
||||||
Temperature: cohere.Float64(0.75),
|
Temperature: cohere.Float64(float64(c.temperature)),
|
||||||
K: cohere.Int(0),
|
K: cohere.Int(0),
|
||||||
StopSequences: []string{},
|
StopSequences: []string{},
|
||||||
ReturnLikelihoods: "NONE",
|
ReturnLikelihoods: "NONE",
|
||||||
|
@@ -48,6 +48,7 @@ type IAIConfig interface {
|
|||||||
GetModel() string
|
GetModel() string
|
||||||
GetBaseURL() string
|
GetBaseURL() string
|
||||||
GetEngine() string
|
GetEngine() string
|
||||||
|
GetTemperature() float32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(provider string) IAI {
|
func NewClient(provider string) IAI {
|
||||||
@@ -66,11 +67,12 @@ type AIConfiguration struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AIProvider struct {
|
type AIProvider struct {
|
||||||
Name string `mapstructure:"name"`
|
Name string `mapstructure:"name"`
|
||||||
Model string `mapstructure:"model"`
|
Model string `mapstructure:"model"`
|
||||||
Password string `mapstructure:"password" yaml:"password,omitempty"`
|
Password string `mapstructure:"password" yaml:"password,omitempty"`
|
||||||
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
|
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
|
||||||
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
|
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
|
||||||
|
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *AIProvider) GetBaseURL() string {
|
func (p *AIProvider) GetBaseURL() string {
|
||||||
@@ -88,6 +90,9 @@ func (p *AIProvider) GetModel() string {
|
|||||||
func (p *AIProvider) GetEngine() string {
|
func (p *AIProvider) GetEngine() string {
|
||||||
return p.Engine
|
return p.Engine
|
||||||
}
|
}
|
||||||
|
func (p *AIProvider) GetTemperature() float32 {
|
||||||
|
return p.Temperature
|
||||||
|
}
|
||||||
|
|
||||||
func NeedPassword(backend string) bool {
|
func NeedPassword(backend string) bool {
|
||||||
return backend != "localai"
|
return backend != "localai"
|
||||||
|
@@ -29,11 +29,20 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIClient struct {
|
type OpenAIClient struct {
|
||||||
client *openai.Client
|
client *openai.Client
|
||||||
language string
|
language string
|
||||||
model string
|
model string
|
||||||
|
temperature float32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// OpenAI completion parameters
|
||||||
|
maxToken = 2048
|
||||||
|
presencePenalty = 0.0
|
||||||
|
frequencyPenalty = 0.0
|
||||||
|
topP = 1.0
|
||||||
|
)
|
||||||
|
|
||||||
func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
|
func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
|
||||||
token := config.GetPassword()
|
token := config.GetPassword()
|
||||||
defaultConfig := openai.DefaultConfig(token)
|
defaultConfig := openai.DefaultConfig(token)
|
||||||
@@ -50,6 +59,7 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
|
|||||||
c.language = language
|
c.language = language
|
||||||
c.client = client
|
c.client = client
|
||||||
c.model = config.GetModel()
|
c.model = config.GetModel()
|
||||||
|
c.temperature = config.GetTemperature()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,6 +76,11 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptT
|
|||||||
Content: fmt.Sprintf(promptTmpl, c.language, prompt),
|
Content: fmt.Sprintf(promptTmpl, c.language, prompt),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Temperature: c.temperature,
|
||||||
|
MaxTokens: maxToken,
|
||||||
|
PresencePenalty: presencePenalty,
|
||||||
|
FrequencyPenalty: frequencyPenalty,
|
||||||
|
TopP: topP,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
Reference in New Issue
Block a user