feat: implement Top-K sampling for improved user control (#1110)
This commit adds Top-K sampling, a feature that lets users control the randomness of the generated text by specifying the number of most probable next tokens the model considers. This enhances user control and can improve the quality of the generated output.

Fixes: https://github.com/k8sgpt-ai/k8sgpt/issues/1105

Signed-off-by: VaibhavMalik4187 <vaibhavmalik2018@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
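For readers unfamiliar with the technique, here is a minimal, self-contained Go sketch of what Top-K sampling does conceptually. It is illustrative only (the function name sampleTopK and the toy vocabulary are hypothetical, not k8sgpt code): the model's probability distribution over next tokens is truncated to the k most probable candidates, renormalized, and then sampled.

package main

import (
	"fmt"
	"math/rand"
	"sort"
)

// sampleTopK is an illustrative sketch (not k8sgpt code): it keeps only the
// k most probable next tokens, renormalizes their probabilities, and draws
// one token from that truncated distribution.
func sampleTopK(probs map[string]float64, k int, rng *rand.Rand) string {
	type cand struct {
		token string
		p     float64
	}
	cands := make([]cand, 0, len(probs))
	for t, p := range probs {
		cands = append(cands, cand{t, p})
	}
	// Sort descending by probability and truncate to the top k candidates.
	sort.Slice(cands, func(i, j int) bool { return cands[i].p > cands[j].p })
	if k > 0 && k < len(cands) {
		cands = cands[:k]
	}
	// Renormalize so the surviving probabilities sum to 1, then sample.
	var total float64
	for _, c := range cands {
		total += c.p
	}
	r := rng.Float64() * total
	for _, c := range cands {
		r -= c.p
		if r <= 0 {
			return c.token
		}
	}
	return cands[len(cands)-1].token
}

func main() {
	rng := rand.New(rand.NewSource(42))
	// Hypothetical next-token distribution, for illustration only.
	probs := map[string]float64{"pod": 0.5, "node": 0.3, "service": 0.15, "quota": 0.05}
	// With k=2, only "pod" and "node" can ever be chosen.
	fmt.Println(sampleTopK(probs, 2, rng))
}

Smaller k makes the output more deterministic; larger k admits more of the distribution's tail and therefore more variety. The diff below threads exactly this k (topK) from the user-facing config into the Hugging Face inference parameters.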
@@ -2,6 +2,7 @@ package ai

 import (
 	"context"

 	"github.com/hupe1980/go-huggingface"
+	"k8s.io/utils/ptr"
 )
@@ -14,6 +15,7 @@ type HuggingfaceClient struct {
 	client      *huggingface.InferenceClient
 	model       string
 	topP        float32
+	topK        int32
 	temperature float32
 	maxTokens   int
 }
@@ -26,6 +28,7 @@ func (c *HuggingfaceClient) Configure(config IAIConfig) error {
 	c.client = client
 	c.model = config.GetModel()
 	c.topP = config.GetTopP()
+	c.topK = config.GetTopK()
 	c.temperature = config.GetTemperature()
 	if config.GetMaxTokens() > 500 {
 		c.maxTokens = 500
@@ -43,6 +46,7 @@ func (c *HuggingfaceClient) GetCompletion(ctx context.Context, prompt string) (s
 		Model: c.model,
 		Parameters: huggingface.ConversationalParameters{
 			TopP:        ptr.To[float64](float64(c.topP)),
+			TopK:        ptr.To[int](int(c.topK)),
 			Temperature: ptr.To[float64](float64(c.temperature)),
 			MaxLength:   &c.maxTokens,
 		},
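For context on the ptr.To calls in the last hunk: k8s.io/utils/ptr provides a generic helper, ptr.To, that returns a pointer to a copy of its argument, which lets optional pointer-typed parameters such as TopK be set inline without declaring temporary variables. A minimal sketch (the value 40 is arbitrary, for illustration):

package main

import (
	"fmt"

	"k8s.io/utils/ptr"
)

func main() {
	// ptr.To returns a pointer to a copy of its argument; handy for
	// optional struct fields that an API models as pointers so that
	// "unset" can be distinguished from a zero value.
	topK := ptr.To[int](40)
	fmt.Println(*topK) // prints 40
}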