Files
k8sgpt/pkg/ai/watsonxai.go
Yanwei Li 7e3d452cdb fix: add default maxToken value of watsonxai backend (#1209)
Signed-off-by: yanweili <yanweili@ibm.com>
Co-authored-by: yanweili <yanweili@ibm.com>
Signed-off-by: AlexsJones <alexsimonjones@gmail.com>
2024-10-24 07:27:47 +01:00

89 lines
1.9 KiB
Go

package ai
import (
"context"
"errors"
"fmt"
"os"
wx "github.com/IBM/watsonx-go/pkg/models"
)
// watsonxAIClientName is the backend identifier returned by
// (*WatsonxAIClient).GetName.
const (
	watsonxAIClientName = "watsonxai"
)
// WatsonxAIClient is the IBM watsonx.ai backend. All of its settings are
// populated by Configure and forwarded on each GetCompletion request.
type WatsonxAIClient struct {
	nopCloser

	client       *wx.Client // created by Configure from the environment credentials
	model        string     // model ID passed to GenerateText
	temperature  float32    // forwarded via wx.WithTemperature
	topP         float32    // forwarded via wx.WithTopP
	topK         int32      // forwarded via wx.WithTopK
	maxNewTokens int        // forwarded via wx.WithMaxNewTokens
}
// Defaults applied by Configure when the backend config leaves a value unset.
const (
	// maxTokens is the default limit on newly generated tokens.
	maxTokens = 2048

	// modelMetallama is the default watsonx model ID. Despite its name it
	// selects IBM's Granite 13B chat model, not a Llama model; the identifier
	// is kept unchanged since other code may refer to it.
	modelMetallama = "ibm/granite-13b-chat-v2"
)
// Configure prepares the client from the backend configuration and the
// process environment.
//
// The model falls back to modelMetallama when config.GetModel() is empty.
// The generation limit falls back to maxTokens when config.GetMaxTokens() is
// not positive. Temperature, top-p and top-k are copied from config unchanged.
//
// Credentials are read from the environment variables named by
// wx.WatsonxAPIKeyEnvVarName and wx.WatsonxProjectIDEnvVarName. An error is
// returned if either variable is empty or if the watsonx client cannot be
// constructed; in those cases c.client is left untouched.
func (c *WatsonxAIClient) Configure(config IAIConfig) error {
	c.model = config.GetModel()
	if c.model == "" {
		c.model = modelMetallama
	}

	// GetCompletion converts this value to uint, so a negative limit would
	// wrap around to a huge number. Treat any non-positive value as unset.
	c.maxNewTokens = config.GetMaxTokens()
	if c.maxNewTokens <= 0 {
		c.maxNewTokens = maxTokens
	}

	c.temperature = config.GetTemperature()
	c.topP = config.GetTopP()
	c.topK = config.GetTopK()

	// WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY"
	// WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID"
	apiKey := os.Getenv(wx.WatsonxAPIKeyEnvVarName)
	if apiKey == "" {
		return errors.New("No watsonx API key provided")
	}
	projectID := os.Getenv(wx.WatsonxProjectIDEnvVarName)
	if projectID == "" {
		return errors.New("No watsonx project ID provided")
	}

	client, err := wx.NewClient(
		wx.WithWatsonxAPIKey(apiKey),
		wx.WithWatsonxProjectID(projectID),
	)
	if err != nil {
		// %w keeps the underlying cause inspectable via errors.Is/As.
		return fmt.Errorf("creating watsonx client: %w", err)
	}
	c.client = client
	return nil
}
// GetCompletion sends prompt to the configured watsonx model and returns the
// generated text.
//
// The GenerateText call below is not given ctx, so an in-flight request
// cannot be cancelled. ctx is therefore checked once up front, so that an
// already-cancelled or expired context fails without a network call.
//
// An error is returned in any of these cases:
//   - ctx is already done;
//   - Configure has not successfully created a client;
//   - generation fails;
//   - the model returns an empty string.
func (c *WatsonxAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
	if err := ctx.Err(); err != nil {
		return "", err
	}
	// Guard against use before (or after a failed) Configure, which would
	// otherwise panic on a nil client.
	if c.client == nil {
		return "", errors.New("watsonx client is not configured")
	}

	result, err := c.client.GenerateText(
		c.model,
		prompt,
		wx.WithTemperature(float64(c.temperature)),
		wx.WithTopP(float64(c.topP)),
		// NOTE(review): a negative topK wraps around in this uint conversion;
		// presumably callers supply a non-negative value — confirm.
		wx.WithTopK(uint(c.topK)),
		wx.WithMaxNewTokens(uint(c.maxNewTokens)),
	)
	if err != nil {
		return "", fmt.Errorf("watsonx text generation: %w", err)
	}
	if result.Text == "" {
		return "", errors.New("watsonx returned an empty completion")
	}
	return result.Text, nil
}
// GetName reports this backend's identifier, watsonxAIClientName.
func (*WatsonxAIClient) GetName() string { return watsonxAIClientName }