mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-05-12 18:14:22 +00:00
feat: oci genai (#1102)
Signed-off-by: Anders Swanson <anders.swanson@oracle.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
parent
eda52312ae
commit
047afd46d6
@ -127,6 +127,7 @@ var addCmd = &cobra.Command{
|
||||
Temperature: temperature,
|
||||
ProviderRegion: providerRegion,
|
||||
ProviderId: providerId,
|
||||
CompartmentId: compartmentId,
|
||||
TopP: topP,
|
||||
TopK: topK,
|
||||
MaxTokens: maxTokens,
|
||||
@ -173,4 +174,6 @@ func init() {
|
||||
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, googlevertexai backend)")
|
||||
//add flag for vertexAI Project ID
|
||||
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)")
|
||||
//add flag for OCI Compartment ID
|
||||
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
|
||||
}
|
||||
|
@ -28,6 +28,7 @@ var (
|
||||
temperature float32
|
||||
providerRegion string
|
||||
providerId string
|
||||
compartmentId string
|
||||
topP float32
|
||||
topK int32
|
||||
maxTokens int
|
||||
|
3
go.mod
3
go.mod
@ -38,6 +38,7 @@ require (
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1
|
||||
github.com/hupe1980/go-huggingface v0.0.15
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/oracle/oci-go-sdk/v65 v65.65.1
|
||||
github.com/prometheus/prometheus v0.49.1
|
||||
github.com/pterm/pterm v0.12.79
|
||||
google.golang.org/api v0.172.0
|
||||
@ -70,6 +71,7 @@ require (
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-kit/log v0.2.1 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
||||
github.com/gofrs/flock v0.8.1 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 // indirect
|
||||
@ -87,6 +89,7 @@ require (
|
||||
github.com/prometheus/common/sigv4 v0.1.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/sony/gobreaker v0.5.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
|
6
go.sum
6
go.sum
@ -1545,6 +1545,8 @@ github.com/gobuffalo/packr/v2 v2.8.3/go.mod h1:0SahksCVcx4IMnigTjiFuyldmTrdTctXs
|
||||
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
|
||||
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
|
||||
github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
|
||||
github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
|
||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
@ -1934,6 +1936,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.0-rc5 h1:Ygwkfw9bpDvs+c9E34SdgGOj41dX/cbdlwvlWt0pnFI=
|
||||
github.com/opencontainers/image-spec v1.1.0-rc5/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.65.1 h1:sv7uD844tJGa2Vc+2KaByoXQ0FllZDGV/2+9MdxN6nA=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.65.1/go.mod h1:IBEV9l1qBzUpo7zgGaRUhbB05BVfcDGYRFBCPlTcPp0=
|
||||
github.com/ovh/go-ovh v1.4.3 h1:Gs3V823zwTFpzgGLZNI6ILS4rmxZgJwJCz54Er9LwD0=
|
||||
github.com/ovh/go-ovh v1.4.3/go.mod h1:AkPXVtgwB6xlKblMjRKJJmjRp+ogrE7fz2lVgcQY8SY=
|
||||
github.com/owenrumney/squealer v1.2.1 h1:4ryMMT59aaz8VMsqsD+FDkarADJz0F1dcq2fd0DRR+c=
|
||||
@ -2050,6 +2054,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/skeema/knownhosts v1.2.1 h1:SHWdIUa82uGZz+F+47k8SY4QhhI291cXCpopT1lK2AQ=
|
||||
github.com/skeema/knownhosts v1.2.1/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo=
|
||||
github.com/sony/gobreaker v0.5.0 h1:dRCvqm0P490vZPmy7ppEk2qCnCieBooFJ+YoXGYB+yg=
|
||||
github.com/sony/gobreaker v0.5.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
|
@ -29,6 +29,7 @@ var (
|
||||
&GoogleGenAIClient{},
|
||||
&HuggingfaceClient{},
|
||||
&GoogleVertexAIClient{},
|
||||
&OCIGenAIClient{},
|
||||
}
|
||||
Backends = []string{
|
||||
openAIClientName,
|
||||
@ -41,6 +42,7 @@ var (
|
||||
noopAIClientName,
|
||||
huggingfaceAIClientName,
|
||||
googleVertexAIClientName,
|
||||
ociClientName,
|
||||
}
|
||||
)
|
||||
|
||||
@ -75,6 +77,7 @@ type IAIConfig interface {
|
||||
GetTopK() int32
|
||||
GetMaxTokens() int
|
||||
GetProviderId() string
|
||||
GetCompartmentId() string
|
||||
}
|
||||
|
||||
func NewClient(provider string) IAI {
|
||||
@ -104,6 +107,7 @@ type AIProvider struct {
|
||||
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
|
||||
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
|
||||
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
|
||||
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"`
|
||||
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
|
||||
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
|
||||
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
|
||||
@ -156,7 +160,11 @@ func (p *AIProvider) GetProviderId() string {
|
||||
return p.ProviderId
|
||||
}
|
||||
|
||||
var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai"}
|
||||
func (p *AIProvider) GetCompartmentId() string {
|
||||
return p.CompartmentId
|
||||
}
|
||||
|
||||
var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
|
||||
|
||||
func NeedPassword(backend string) bool {
|
||||
for _, b := range passwordlessProviders {
|
||||
|
97
pkg/ai/ocigenai.go
Normal file
97
pkg/ai/ocigenai.go
Normal file
@ -0,0 +1,97 @@
|
||||
/*
|
||||
Copyright 2024 The K8sGPT Authors.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/oracle/oci-go-sdk/v65/common"
|
||||
"github.com/oracle/oci-go-sdk/v65/generativeaiinference"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const ociClientName = "oci"
|
||||
|
||||
type OCIGenAIClient struct {
|
||||
nopCloser
|
||||
|
||||
client *generativeaiinference.GenerativeAiInferenceClient
|
||||
model string
|
||||
compartmentId string
|
||||
temperature float32
|
||||
topP float32
|
||||
maxTokens int
|
||||
}
|
||||
|
||||
func (c *OCIGenAIClient) GetName() string {
|
||||
return ociClientName
|
||||
}
|
||||
|
||||
func (c *OCIGenAIClient) Configure(config IAIConfig) error {
|
||||
config.GetEndpointName()
|
||||
c.model = config.GetModel()
|
||||
c.temperature = config.GetTemperature()
|
||||
c.topP = config.GetTopP()
|
||||
c.maxTokens = config.GetMaxTokens()
|
||||
c.compartmentId = config.GetCompartmentId()
|
||||
provider := common.DefaultConfigProvider()
|
||||
client, err := generativeaiinference.NewGenerativeAiInferenceClientWithConfigurationProvider(provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.client = &client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OCIGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
|
||||
generateTextRequest := c.newGenerateTextRequest(prompt)
|
||||
generateTextResponse, err := c.client.GenerateText(ctx, generateTextRequest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return extractGeneratedText(generateTextResponse.InferenceResponse)
|
||||
}
|
||||
|
||||
func (c *OCIGenAIClient) newGenerateTextRequest(prompt string) generativeaiinference.GenerateTextRequest {
|
||||
temperatureF64 := float64(c.temperature)
|
||||
topPF64 := float64(c.topP)
|
||||
return generativeaiinference.GenerateTextRequest{
|
||||
GenerateTextDetails: generativeaiinference.GenerateTextDetails{
|
||||
CompartmentId: &c.compartmentId,
|
||||
ServingMode: generativeaiinference.OnDemandServingMode{
|
||||
ModelId: &c.model,
|
||||
},
|
||||
InferenceRequest: generativeaiinference.CohereLlmInferenceRequest{
|
||||
Prompt: &prompt,
|
||||
MaxTokens: &c.maxTokens,
|
||||
Temperature: &temperatureF64,
|
||||
TopP: &topPF64,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func extractGeneratedText(llmInferenceResponse generativeaiinference.LlmInferenceResponse) (string, error) {
|
||||
response, ok := llmInferenceResponse.(generativeaiinference.CohereLlmInferenceResponse)
|
||||
if !ok {
|
||||
return "", errors.New("failed to extract generated text from backed response")
|
||||
}
|
||||
sb := strings.Builder{}
|
||||
for _, text := range response.GeneratedTexts {
|
||||
if text.Text != nil {
|
||||
sb.WriteString(*text.Text)
|
||||
}
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user