fix: add default maxToken value of watsonxai backend (#1209)

Signed-off-by: yanweili <yanweili@ibm.com>
Co-authored-by: yanweili <yanweili@ibm.com>
This commit is contained in:
Yanwei Li
2024-08-02 02:15:28 -07:00
committed by GitHub
parent a068310731
commit d43fd878ba

View File

@@ -1,10 +1,10 @@
package ai package ai
import ( import (
"os"
"fmt"
"context" "context"
"errors" "errors"
"fmt"
"os"
wx "github.com/IBM/watsonx-go/pkg/models" wx "github.com/IBM/watsonx-go/pkg/models"
) )
@@ -24,18 +24,23 @@ type WatsonxAIClient struct {
const ( const (
modelMetallama = "ibm/granite-13b-chat-v2" modelMetallama = "ibm/granite-13b-chat-v2"
maxTokens = 2048
) )
func (c *WatsonxAIClient) Configure(config IAIConfig) error { func (c *WatsonxAIClient) Configure(config IAIConfig) error {
if(config.GetModel() == "") { if config.GetModel() == "" {
c.model = config.GetModel()
} else {
c.model = modelMetallama c.model = modelMetallama
} else {
c.model = config.GetModel()
}
if config.GetMaxTokens() == 0 {
c.maxNewTokens = maxTokens
} else {
c.maxNewTokens = config.GetMaxTokens()
} }
c.temperature = config.GetTemperature() c.temperature = config.GetTemperature()
c.topP = config.GetTopP() c.topP = config.GetTopP()
c.topK = config.GetTopK() c.topK = config.GetTopK()
c.maxNewTokens = config.GetMaxTokens()
// WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY" // WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY"
// WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID" // WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID"
@@ -75,7 +80,6 @@ func (c *WatsonxAIClient) GetCompletion(ctx context.Context, prompt string) (str
if result.Text == "" { if result.Text == "" {
return "", errors.New("Expected a result, but got an empty string") return "", errors.New("Expected a result, but got an empty string")
} }
return result.Text, nil return result.Text, nil
} }