k8sgpt/pkg/ai/huggingface.go
Vaibhav Malik eda52312ae
feat: implement Top-K sampling for improved user control (#1110)
This commit adds Top-K sampling, a feature that lets users control the
randomness of the generated text by limiting the model to the k most
probable next tokens at each step. This enhances user control and can
improve the quality of the generated output.

Fixes: https://github.com/k8sgpt-ai/k8sgpt/issues/1105

Signed-off-by: VaibhavMalik4187 <vaibhavmalik2018@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
2024-05-16 15:41:07 +01:00
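
For context on the feature this commit implements: top-k sampling truncates the model's next-token distribution to the k highest-probability candidates, renormalizes what survives, and samples from that reduced set. The standalone sketch below (not part of this file; all names are hypothetical) illustrates the idea over a toy probability slice:

package main

import (
	"fmt"
	"math/rand"
	"sort"
)

// topKSample keeps only the k most probable tokens, renormalizes their
// probabilities, and samples one index from the reduced distribution.
func topKSample(probs []float64, k int, rng *rand.Rand) int {
	// Rank token indices by descending probability.
	idx := make([]int, len(probs))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return probs[idx[a]] > probs[idx[b]] })

	if k > len(idx) {
		k = len(idx)
	}
	top := idx[:k]

	// Renormalize the surviving probability mass and draw from it.
	var total float64
	for _, i := range top {
		total += probs[i]
	}
	r := rng.Float64() * total
	for _, i := range top {
		r -= probs[i]
		if r <= 0 {
			return i
		}
	}
	return top[k-1]
}

func main() {
	rng := rand.New(rand.NewSource(42))
	probs := []float64{0.5, 0.2, 0.15, 0.1, 0.05}
	fmt.Println(topKSample(probs, 2, rng)) // only index 0 or 1 can be chosen
}

Smaller k makes output more deterministic; larger k admits more variety. In this file the parameter is simply forwarded to the Hugging Face inference endpoint, which performs the actual sampling server-side.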


package ai

import (
	"context"

	"github.com/hupe1980/go-huggingface"
	"k8s.io/utils/ptr"
)

const huggingfaceAIClientName = "huggingface"

// HuggingfaceClient talks to the Hugging Face Inference API via the
// go-huggingface SDK.
type HuggingfaceClient struct {
	nopCloser

	client      *huggingface.InferenceClient
	model       string
	topP        float32
	topK        int32
	temperature float32
	maxTokens   int
}
// Configure builds the inference client from the provided configuration and
// copies over the sampling parameters (top-p, top-k, temperature).
func (c *HuggingfaceClient) Configure(config IAIConfig) error {
	token := config.GetPassword()
	client := huggingface.NewInferenceClient(token)

	c.client = client
	c.model = config.GetModel()
	c.topP = config.GetTopP()
	c.topK = config.GetTopK()
	c.temperature = config.GetTemperature()
	// Cap the response length at 500 tokens.
	if config.GetMaxTokens() > 500 {
		c.maxTokens = 500
	} else {
		c.maxTokens = config.GetMaxTokens()
	}
	return nil
}
// GetCompletion sends the prompt to the model's conversational endpoint and
// returns the generated reply.
func (c *HuggingfaceClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
	resp, err := c.client.Conversational(ctx, &huggingface.ConversationalRequest{
		// ConverstationalInputs is the (misspelled) type name exported by the
		// go-huggingface library, so it must stay as-is.
		Inputs: huggingface.ConverstationalInputs{
			Text: prompt,
		},
		Model: c.model,
		Parameters: huggingface.ConversationalParameters{
			TopP:        ptr.To[float64](float64(c.topP)),
			TopK:        ptr.To[int](int(c.topK)),
			Temperature: ptr.To[float64](float64(c.temperature)),
			MaxLength:   &c.maxTokens,
		},
		Options: huggingface.Options{
			// Wait for the model to load rather than failing immediately.
			WaitForModel: ptr.To[bool](true),
		},
	})
	if err != nil {
		return "", err
	}
	return resp.GeneratedText, nil
}

func (c *HuggingfaceClient) GetName() string { return huggingfaceAIClientName }
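
For reference, here is a minimal standalone sketch of the same go-huggingface calls this file makes, with top-k set directly. The HF_TOKEN environment variable and the model name are placeholder assumptions for illustration, not k8sgpt defaults:

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/hupe1980/go-huggingface"
	"k8s.io/utils/ptr"
)

func main() {
	// Placeholder token source; k8sgpt supplies this via its own config.
	client := huggingface.NewInferenceClient(os.Getenv("HF_TOKEN"))

	resp, err := client.Conversational(context.Background(), &huggingface.ConversationalRequest{
		Inputs: huggingface.ConverstationalInputs{ // spelling comes from the library
			Text: "Why is my pod in CrashLoopBackOff?",
		},
		Model: "microsoft/DialoGPT-large", // placeholder model name
		Parameters: huggingface.ConversationalParameters{
			TopK:        ptr.To[int](50), // consider only the 50 most probable next tokens
			TopP:        ptr.To[float64](0.9),
			Temperature: ptr.To[float64](0.7),
			MaxLength:   ptr.To[int](500),
		},
		Options: huggingface.Options{WaitForModel: ptr.To[bool](true)},
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.GeneratedText)
}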