mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-05-30 10:37:00 +00:00
This commit adds Top-K sampling, a feature that allows users to control the randomness of the generated text by specifying the number of most probable next words considered by the model. This enhances user control and potentially improves the quality of the generated outputs. Fixes: https://github.com/k8sgpt-ai/k8sgpt/issues/1105 Signed-off-by: VaibhavMalik4187 <vaibhavmalik2018@gmail.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
64 lines
1.4 KiB
Go
64 lines
1.4 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/hupe1980/go-huggingface"
|
|
"k8s.io/utils/ptr"
|
|
)
|
|
|
|
// huggingfaceAIClientName is the backend name reported by
// HuggingfaceClient.GetName.
const (
	huggingfaceAIClientName = "huggingface"
)
|
|
|
|
type HuggingfaceClient struct {
|
|
nopCloser
|
|
|
|
client *huggingface.InferenceClient
|
|
model string
|
|
topP float32
|
|
topK int32
|
|
temperature float32
|
|
maxTokens int
|
|
}
|
|
|
|
func (c *HuggingfaceClient) Configure(config IAIConfig) error {
|
|
token := config.GetPassword()
|
|
|
|
client := huggingface.NewInferenceClient(token)
|
|
|
|
c.client = client
|
|
c.model = config.GetModel()
|
|
c.topP = config.GetTopP()
|
|
c.topK = config.GetTopK()
|
|
c.temperature = config.GetTemperature()
|
|
if config.GetMaxTokens() > 500 {
|
|
c.maxTokens = 500
|
|
} else {
|
|
c.maxTokens = config.GetMaxTokens()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *HuggingfaceClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
|
|
resp, err := c.client.Conversational(ctx, &huggingface.ConversationalRequest{
|
|
Inputs: huggingface.ConverstationalInputs{
|
|
Text: prompt,
|
|
},
|
|
Model: c.model,
|
|
Parameters: huggingface.ConversationalParameters{
|
|
TopP: ptr.To[float64](float64(c.topP)),
|
|
TopK: ptr.To[int](int(c.topK)),
|
|
Temperature: ptr.To[float64](float64(c.temperature)),
|
|
MaxLength: &c.maxTokens,
|
|
},
|
|
Options: huggingface.Options{
|
|
WaitForModel: ptr.To[bool](true),
|
|
},
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return resp.GeneratedText, nil
|
|
}
|
|
|
|
// GetName returns the identifier of this AI backend, "huggingface".
func (c *HuggingfaceClient) GetName() string {
	return huggingfaceAIClientName
}
|