diff --git a/go.mod b/go.mod index 5697ead..1dba30e 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.1 github.com/aws/aws-sdk-go v1.51.21 - github.com/cohere-ai/cohere-go v0.2.0 + github.com/cohere-ai/cohere-go/v2 v2.7.1 github.com/google/generative-ai-go v0.10.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 github.com/hupe1980/go-huggingface v0.0.15 @@ -62,11 +62,9 @@ require ( github.com/Microsoft/hcsshim v0.11.4 // indirect github.com/alecthomas/units v0.0.0-20231202071711-9a357b53e9c9 // indirect github.com/anchore/go-struct-converter v0.0.0-20230627203149-c72ef8859ca9 // indirect - github.com/cohere-ai/tokenizer v1.1.1 // indirect github.com/containerd/console v1.0.3 // indirect github.com/containerd/log v0.1.0 // indirect github.com/distribution/reference v0.5.0 // indirect - github.com/dlclark/regexp2 v1.10.0 // indirect github.com/evanphx/json-patch/v5 v5.7.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-kit/log v0.2.1 // indirect diff --git a/go.sum b/go.sum index d59e1f2..a284a39 100644 --- a/go.sum +++ b/go.sum @@ -1379,10 +1379,8 @@ github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa h1:jQCWAUqqlij9Pgj2i/PB79y4KOPYVyFYdROxgaCwdTQ= github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa/go.mod h1:x/1Gn8zydmfq8dk6e9PdstVsDgu9RuyIIJqAaF//0IM= -github.com/cohere-ai/cohere-go v0.2.0 h1:Gljkn8LTtsAPy79ks1AVmZH9Av4kuQuXEgzEJ/1Ea34= -github.com/cohere-ai/cohere-go v0.2.0/go.mod h1:DFcCu5rwro4wAlluIXY9l17NLGiVBGb2bRio46RXBm8= -github.com/cohere-ai/tokenizer v1.1.1 h1:wCtmCj07O82TMrIiA/CORhIlEYsvMMM8ey+sUdEapHc= -github.com/cohere-ai/tokenizer v1.1.1/go.mod h1:9MNFPd9j1fuiEK3ua2HSCUxxcrfGMlSqpa93livg/C0= +github.com/cohere-ai/cohere-go/v2 v2.7.1 h1:w2osOaSXZLGmmLuIAIR3hepaJENZSLrpU9VSh8rPJ2s= +github.com/cohere-ai/cohere-go/v2 v2.7.1/go.mod h1:dlDCT66i8BqZDuuskFvYzsrc+O0M4l5J9Ibckoflvt4= github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM= github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw= github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARubLw= diff --git a/pkg/ai/cohere.go b/pkg/ai/cohere.go index 3ea394a..2c7c959 100644 --- a/pkg/ai/cohere.go +++ b/pkg/ai/cohere.go @@ -17,7 +17,9 @@ import ( "context" "errors" - "github.com/cohere-ai/cohere-go" + api "github.com/cohere-ai/cohere-go/v2" + cohere "github.com/cohere-ai/cohere-go/v2/client" + "github.com/cohere-ai/cohere-go/v2/option" ) const cohereAIClientName = "cohere" @@ -28,45 +30,49 @@ type CohereClient struct { client *cohere.Client model string temperature float32 + maxTokens int } func (c *CohereClient) Configure(config IAIConfig) error { token := config.GetPassword() - client, err := cohere.CreateClient(token) - if err != nil { - return err + opts := []option.RequestOption{ + cohere.WithToken(token), } baseURL := config.GetBaseURL() if baseURL != "" { - client.BaseURL = baseURL + opts = append(opts, cohere.WithBaseURL(baseURL)) } + client := cohere.NewClient(opts...) if client == nil { return errors.New("error creating Cohere client") } + c.client = client c.model = config.GetModel() c.temperature = config.GetTemperature() + c.maxTokens = config.GetMaxTokens() + return nil } -func (c *CohereClient) GetCompletion(_ context.Context, prompt string) (string, error) { +func (c *CohereClient) GetCompletion(ctx context.Context, prompt string) (string, error) { // Create a completion request - resp, err := c.client.Generate(cohere.GenerateOptions{ - Model: c.model, - Prompt: prompt, - MaxTokens: cohere.Uint(2048), - Temperature: cohere.Float64(float64(c.temperature)), - K: cohere.Int(0), - StopSequences: []string{}, - ReturnLikelihoods: "NONE", + response, err := c.client.Chat(ctx, &api.ChatRequest{ + Message: prompt, + Model: &c.model, + K: api.Int(0), + Preamble: api.String(""), + Temperature: api.Float64(float64(c.temperature)), + RawPrompting: api.Bool(false), + MaxTokens: api.Int(c.maxTokens), }) if err != nil { return "", err } - return resp.Generations[0].Text, nil + return response.Text, nil } func (c *CohereClient) GetName() string { diff --git a/renovate.json b/renovate.json index 5acd998..99c0506 100644 --- a/renovate.json +++ b/renovate.json @@ -37,11 +37,6 @@ "enabled": true, "groupName": "golang-group" }, - { - "description": "Exclude retracted cohere-go versions: https://github.com/renovatebot/renovate/issues/13012", - "matchPackageNames": ["github.com/cohere-ai/cohere-go"], - "allowedVersions": "<1" - }, { "matchUpdateTypes": ["minor", "patch"], "matchCurrentVersion": "!/^0/",