feat: implement Top-K sampling for improved user control (#1110)

This commit adds Top-K sampling, a feature that allows users to control
the randomness of the generated text by specifying the number of most
probable next words considered by the model. This enhances user control
and potentially improves the quality of the generated outputs.

Fixes: https://github.com/k8sgpt-ai/k8sgpt/issues/1105

Signed-off-by: VaibhavMalik4187 <vaibhavmalik2018@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
Vaibhav Malik
2024-05-16 20:11:07 +05:30
committed by GitHub
parent 882c6f5225
commit eda52312ae
8 changed files with 46 additions and 0 deletions

View File

@@ -2,6 +2,7 @@ package ai
import (
"context"
"github.com/hupe1980/go-huggingface"
"k8s.io/utils/ptr"
)
@@ -14,6 +15,7 @@ type HuggingfaceClient struct {
client *huggingface.InferenceClient
model string
topP float32
topK int32
temperature float32
maxTokens int
}
@@ -26,6 +28,7 @@ func (c *HuggingfaceClient) Configure(config IAIConfig) error {
c.client = client
c.model = config.GetModel()
c.topP = config.GetTopP()
c.topK = config.GetTopK()
c.temperature = config.GetTemperature()
if config.GetMaxTokens() > 500 {
c.maxTokens = 500
@@ -43,6 +46,7 @@ func (c *HuggingfaceClient) GetCompletion(ctx context.Context, prompt string) (s
Model: c.model,
Parameters: huggingface.ConversationalParameters{
TopP: ptr.To[float64](float64(c.topP)),
TopK: ptr.To[int](int(c.topK)),
Temperature: ptr.To[float64](float64(c.temperature)),
MaxLength: &c.maxTokens,
},