diff --git a/pkg/ai/watsonxai.go b/pkg/ai/watsonxai.go index f6ce81c..655b8b8 100644 --- a/pkg/ai/watsonxai.go +++ b/pkg/ai/watsonxai.go @@ -1,10 +1,10 @@ package ai import ( - "os" - "fmt" "context" "errors" + "fmt" + "os" wx "github.com/IBM/watsonx-go/pkg/models" ) @@ -14,28 +14,33 @@ const watsonxAIClientName = "watsonxai" type WatsonxAIClient struct { nopCloser - client *wx.Client - model string - temperature float32 - topP float32 - topK int32 - maxNewTokens int + client *wx.Client + model string + temperature float32 + topP float32 + topK int32 + maxNewTokens int } const ( modelMetallama = "ibm/granite-13b-chat-v2" + maxTokens = 2048 ) func (c *WatsonxAIClient) Configure(config IAIConfig) error { - if(config.GetModel() == "") { - c.model = config.GetModel() - } else { + if config.GetModel() == "" { c.model = modelMetallama + } else { + c.model = config.GetModel() + } + if config.GetMaxTokens() == 0 { + c.maxNewTokens = maxTokens + } else { + c.maxNewTokens = config.GetMaxTokens() } c.temperature = config.GetTemperature() c.topP = config.GetTopP() c.topK = config.GetTopK() - c.maxNewTokens = config.GetMaxTokens() // WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY" // WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID" @@ -75,7 +80,6 @@ func (c *WatsonxAIClient) GetCompletion(ctx context.Context, prompt string) (str if result.Text == "" { return "", errors.New("Expected a result, but got an empty string") } - return result.Text, nil }