Mirror of https://github.com/k8sgpt-ai/k8sgpt.git (synced 2025-05-10 00:57:53 +00:00)
feat: added Google GenAI client; simplified IAI/clients API surface. (#829)

* refactor: Simplified IAI; made caching and processing consistent.

  Signed-off-by: bwplotka <bwplotka@gmail.com>

* feat: Added Google AI API, e.g. for Gemini models.

  Signed-off-by: bwplotka <bwplotka@gmail.com>

---------

Signed-off-by: bwplotka <bwplotka@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
Co-authored-by: Thomas Schuetz <38893055+thschue@users.noreply.github.com>

Parent: e78ff05419
Commit: e7d41496dd
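At a glance, the interface change this commit makes (signatures assembled verbatim from the hunks below; per-backend prompt templating and caching move up into the analysis layer):

    // Before: each backend was handed the language, a prompt template, and the
    // cache, and implemented its own Parse() around GetCompletion().
    type IAI interface {
        Configure(config IAIConfig, language string) error
        GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error)
        Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error)
        GetName() string
    }

    // After: a backend only turns a fully rendered prompt into a completion;
    // templating and caching are handled once, in pkg/analysis.
    type IAI interface {
        Configure(config IAIConfig) error
        GetCompletion(ctx context.Context, prompt string) (string, error)
        GetName() string
        Close()
    }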
@@ -59,6 +59,7 @@ var AnalyzeCmd = &cobra.Command{
             color.Red("Error: %v", err)
             os.Exit(1)
         }
+        defer config.Close()

         config.RunAnalysis()

@@ -149,7 +149,7 @@ func init() {
     // add flag for url
     addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
     // add flag for endpointName
-    addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, (e.g `endpoint-xxxxxxxxxxxx`)")
+    addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, e.g. `endpoint-xxxxxxxxxxxx` (only for amazonbedrock, amazonsagemaker backends)")
    // add flag for topP
    addCmd.Flags().Float32VarP(&topP, "topp", "c", 0.5, "Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability.")
    // max tokens
@@ -157,7 +157,7 @@ func init() {
    // add flag for temperature
    addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
    // add flag for azure open ai engine/deployment name
-    addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
+    addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
    //add flag for amazonbedrock region name
-    addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name")
+    addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock backend)")
 }
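As a side note on the flag hunks above: these use cobra's StringVarP/Float32VarP helpers. A minimal, self-contained sketch of the registration pattern (the demo command and variable names are invented for illustration; only the helper signature mirrors the diff):

    package main

    import (
        "fmt"

        "github.com/spf13/cobra"
    )

    var engine string

    func main() {
        cmd := &cobra.Command{
            Use: "demo",
            Run: func(cmd *cobra.Command, args []string) {
                fmt.Println("engine:", engine)
            },
        }
        // StringVarP(dest, long-name, short-name, default, help) — the same
        // helper the diff above uses for --engine / -e.
        cmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
        _ = cmd.Execute()
    }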
go.mod (3 changes)
@@ -31,6 +31,7 @@ require (
    github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.2.1
    github.com/aws/aws-sdk-go v1.49.15
    github.com/cohere-ai/cohere-go v0.2.0
+    github.com/google/generative-ai-go v0.5.0
    github.com/olekukonko/tablewriter v0.0.5
    google.golang.org/api v0.155.0
    sigs.k8s.io/controller-runtime v0.16.3
@@ -39,9 +40,11 @@ require (

 require (
    cloud.google.com/go v0.110.10 // indirect
+    cloud.google.com/go/ai v0.3.0 // indirect
    cloud.google.com/go/compute v1.23.3 // indirect
    cloud.google.com/go/compute/metadata v0.2.3 // indirect
    cloud.google.com/go/iam v1.1.5 // indirect
+    cloud.google.com/go/longrunning v0.5.4 // indirect
    github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.1 // indirect
    github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.1 // indirect
    github.com/AzureAD/microsoft-authentication-library-for-go v1.2.0 // indirect
go.sum (6 changes)
@@ -49,6 +49,8 @@ cloud.google.com/go/accesscontextmanager v1.3.0/go.mod h1:TgCBehyr5gNMz7ZaH9xubp
 cloud.google.com/go/accesscontextmanager v1.4.0/go.mod h1:/Kjh7BBu/Gh83sv+K60vN9QE5NJcd80sU33vIe2IFPE=
 cloud.google.com/go/accesscontextmanager v1.6.0/go.mod h1:8XCvZWfYw3K/ji0iVnp+6pu7huxoQTLmxAbVjbloTtM=
 cloud.google.com/go/accesscontextmanager v1.7.0/go.mod h1:CEGLewx8dwa33aDAZQujl7Dx+uYhS0eay198wB/VumQ=
+cloud.google.com/go/ai v0.3.0 h1:M617N0brv+XFch2KToZUhv6ggzgFZMUnmDkNQjW2pYg=
+cloud.google.com/go/ai v0.3.0/go.mod h1:dTuQIBA8Kljuas5z1WNot1QZOl476A9TsFqEi6pzJlI=
 cloud.google.com/go/aiplatform v1.22.0/go.mod h1:ig5Nct50bZlzV6NvKaTwmplLLddFx0YReh9WfTO5jKw=
 cloud.google.com/go/aiplatform v1.24.0/go.mod h1:67UUvRBKG6GTayHKV8DBv2RtR1t93YRu5B1P3x99mYY=
 cloud.google.com/go/aiplatform v1.27.0/go.mod h1:Bvxqtl40l0WImSb04d0hXFU7gDOiq9jQmorivIiWcKg=
@@ -351,6 +353,8 @@ cloud.google.com/go/logging v1.7.0/go.mod h1:3xjP2CjkM3ZkO73aj4ASA5wRPGGCRrPIAeN
 cloud.google.com/go/longrunning v0.1.1/go.mod h1:UUFxuDWkv22EuY93jjmDMFT5GPQKeFVJBIF6QlTqdsE=
 cloud.google.com/go/longrunning v0.3.0/go.mod h1:qth9Y41RRSUE69rDcOn6DdK3HfQfsUI0YSmW3iIlLJc=
 cloud.google.com/go/longrunning v0.4.1/go.mod h1:4iWDqhBZ70CvZ6BfETbvam3T8FMvLK+eFj0E6AaRQTo=
+cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg=
+cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI=
 cloud.google.com/go/managedidentities v1.3.0/go.mod h1:UzlW3cBOiPrzucO5qWkNkh0w33KFtBJU281hacNvsdE=
 cloud.google.com/go/managedidentities v1.4.0/go.mod h1:NWSBYbEMgqmbZsLIyKvxrYbtqOsxY1ZrGM+9RgDqInM=
 cloud.google.com/go/managedidentities v1.5.0/go.mod h1:+dWcZ0JlUmpuxpIDfyP5pP5y0bLdRwOS4Lp7gMni/LA=
@@ -935,6 +939,8 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ
 github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
 github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
 github.com/google/flatbuffers v2.0.8+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
+github.com/google/generative-ai-go v0.5.0 h1:PfzPuSGdsmcSyPG7RIoijcKWZ7/x2kvgyNryvmXMUmA=
+github.com/google/generative-ai-go v0.5.0/go.mod h1:8fXQk4w+eyTzFokGGJrBFL0/xwXqm3QNhTqOWyX11zs=
 github.com/google/gnostic v0.7.0 h1:d7EpuFp8vVdML+y0JJJYiKeOLjKTdH/GvVkLOBWqJpw=
 github.com/google/gnostic v0.7.0/go.mod h1:IAcUyMl6vtC95f60EZ8oXyqTsOersP6HbwjeG7EyDPM=
 github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 h1:0VpGH+cDhbDtdcweoyCVsF3fhN8kejK6rFe/2FFX2nU=
@@ -2,15 +2,8 @@ package ai

 import (
    "context"
-    "encoding/base64"
    "encoding/json"
    "fmt"
-    "strings"
-
-    "github.com/fatih/color"
-
-    "github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-    "github.com/k8sgpt-ai/k8sgpt/pkg/util"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/aws/session"
@@ -19,8 +12,9 @@

 // AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service.
 type AmazonBedRockClient struct {
+    nopCloser
+
    client      *bedrockruntime.BedrockRuntime
-    language    string
    model       string
    temperature float32
 }
@@ -91,8 +85,8 @@ func GetRegionOrDefault(region string) string {
    return BEDROCK_DEFAULT_REGION
 }

-// Configure configures the AmazonBedRockClient with the provided configuration and language.
-func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error {
+// Configure configures the AmazonBedRockClient with the provided configuration.
+func (a *AmazonBedRockClient) Configure(config IAIConfig) error {

    // Create a new AWS session
    providerRegion := GetRegionOrDefault(config.GetProviderRegion())
@@ -107,7 +101,6 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error

    // Create a new BedrockRuntime client
    a.client = bedrockruntime.New(sess)
-    a.language = language
    a.model = GetModelOrDefault(config.GetModel())
    a.temperature = config.GetTemperature()

@@ -115,7 +108,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error
 }

 // GetCompletion sends a request to the model for generating completion based on the provided prompt.
-func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {

    // Prepare the input data for the model invocation
    request := map[string]interface{}{
@@ -152,44 +145,6 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string,
    return output.Completion, nil
 }

-// Parse generates a completion for the provided prompt using the Amazon Bedrock model.
-func (a *AmazonBedRockClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-    inputKey := strings.Join(prompt, " ")
-    // Check for cached data
-    cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-    if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-        response, err := cache.Load(cacheKey)
-        if err != nil {
-            return "", err
-        }
-
-        if response != "" {
-            output, err := base64.StdEncoding.DecodeString(response)
-            if err != nil {
-                color.Red("error decoding cached data: %v", err)
-                return "", nil
-            }
-            return string(output), nil
-        }
-    }
-
-    response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-
-    if err != nil {
-        return "", err
-    }
-
-    err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-    if err != nil {
-        color.Red("error storing value to cache: %v", err)
-        return "", nil
-    }
-
-    return response, nil
-}
-
 // GetName returns the name of the AmazonBedRockClient.
 func (a *AmazonBedRockClient) GetName() string {
    return "amazonbedrock"
@@ -15,15 +15,8 @@ package ai

 import (
    "context"
-    "encoding/base64"
-    "strings"
-
    "encoding/json"
-
-    "github.com/fatih/color"
-    "github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-    "github.com/k8sgpt-ai/k8sgpt/pkg/util"
    "fmt"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/aws/session"
@@ -31,8 +24,9 @@ import (
 )

 type SageMakerAIClient struct {
+    nopCloser
+
    client      *sagemakerruntime.SageMakerRuntime
-    language    string
    model       string
    temperature float32
    endpoint    string
@@ -63,7 +57,7 @@ type Parameters struct {
    Temperature float64 `json:"temperature"`
 }

-func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
+func (c *SageMakerAIClient) Configure(config IAIConfig) error {

    // Create a new AWS session
    sess := session.Must(session.NewSessionWithOptions(session.Options{
@@ -71,7 +65,6 @@ func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
        SharedConfigState: session.SharedConfigEnable,
    }))

-    c.language = language
    // Create a new SageMaker runtime client
    c.client = sagemakerruntime.New(sess)
    c.model = config.GetModel()
@@ -82,18 +75,13 @@ func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
    return nil
 }

-func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
-    // Create a completion request
+func (c *SageMakerAIClient) GetCompletion(_ context.Context, prompt string) (string, error) {

-    if len(promptTmpl) == 0 {
-        promptTmpl = PromptMap["default"]
-    }
-
    request := Request{
        Inputs: [][]Message{
            {
                {Role: "system", Content: "DEFAULT_PROMPT"},
-                {Role: "user", Content: fmt.Sprintf(promptTmpl, c.language, prompt)},
+                {Role: "user", Content: prompt},
            },
        },

@@ -142,29 +130,6 @@ func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, pr
    return content, nil
 }

-func (a *SageMakerAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-    // parse the text with the AI backend
-    inputKey := strings.Join(prompt, " ")
-    // Check for cached data
-    sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
-    cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)
-
-    response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-    if err != nil {
-        color.Red("error getting completion: %v", err)
-        return "", err
-    }
-
-    err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-    if err != nil {
-        color.Red("error storing value to cache: %v", err)
-        return "", err
-    }
-
-    return response, nil
-}
-
-func (a *SageMakerAIClient) GetName() string {
+func (c *SageMakerAIClient) GetName() string {
    return "amazonsagemaker"
 }
@@ -2,27 +2,20 @@ package ai

 import (
    "context"
-    "encoding/base64"
    "errors"
-    "fmt"
-    "strings"
-
-    "github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-    "github.com/k8sgpt-ai/k8sgpt/pkg/util"

    "github.com/fatih/color"

    "github.com/sashabaranov/go-openai"
 )

 type AzureAIClient struct {
+    nopCloser
+
    client      *openai.Client
-    language    string
    model       string
    temperature float32
 }

-func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
+func (c *AzureAIClient) Configure(config IAIConfig) error {
    token := config.GetPassword()
    baseURL := config.GetBaseURL()
    engine := config.GetEngine()
@@ -40,21 +33,20 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
    if client == nil {
        return errors.New("error creating Azure OpenAI client")
    }
-    c.language = lang
    c.client = client
    c.model = config.GetModel()
    c.temperature = config.GetTemperature()
    return nil
 }

-func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
    // Create a completion request
    resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
        Model: c.model,
        Messages: []openai.ChatCompletionMessage{
            {
                Role: openai.ChatMessageRoleUser,
-                Content: fmt.Sprintf(default_prompt, c.language, prompt),
+                Content: prompt,
            },
        },
        Temperature: c.temperature,
@@ -65,42 +57,6 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, prompt
    return resp.Choices[0].Message.Content, nil
 }

-func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-    inputKey := strings.Join(prompt, " ")
-    // Check for cached data
-    cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-    if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-        response, err := cache.Load(cacheKey)
-        if err != nil {
-            return "", err
-        }
-
-        if response != "" {
-            output, err := base64.StdEncoding.DecodeString(response)
-            if err != nil {
-                color.Red("error decoding cached data: %v", err)
-                return "", nil
-            }
-            return string(output), nil
-        }
-    }
-
-    response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-    if err != nil {
-        return "", err
-    }
-
-    err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-    if err != nil {
-        color.Red("error storing value to cache: %v", err)
-        return "", nil
-    }
-
-    return response, nil
-}
-
-func (a *AzureAIClient) GetName() string {
+func (c *AzureAIClient) GetName() string {
    return "azureopenai"
 }
@@ -15,26 +15,20 @@ package ai

 import (
    "context"
-    "encoding/base64"
    "errors"
-    "fmt"
-    "strings"

    "github.com/cohere-ai/cohere-go"
    "github.com/fatih/color"
-
-    "github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-    "github.com/k8sgpt-ai/k8sgpt/pkg/util"
 )

 type CohereClient struct {
+    nopCloser
+
    client      *cohere.Client
-    language    string
    model       string
    temperature float32
 }

-func (c *CohereClient) Configure(config IAIConfig, language string) error {
+func (c *CohereClient) Configure(config IAIConfig) error {
    token := config.GetPassword()

    client, err := cohere.CreateClient(token)
@@ -50,21 +44,17 @@ func (c *CohereClient) Configure(config IAIConfig, language string) error {
    if client == nil {
        return errors.New("error creating Cohere client")
    }
-    c.language = language
    c.client = client
    c.model = config.GetModel()
    c.temperature = config.GetTemperature()
    return nil
 }

-func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl string) (string, error) {
+func (c *CohereClient) GetCompletion(_ context.Context, prompt string) (string, error) {
    // Create a completion request
-    if len(promptTmpl) == 0 {
-        promptTmpl = PromptMap["default"]
-    }
    resp, err := c.client.Generate(cohere.GenerateOptions{
        Model: c.model,
-        Prompt: fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
+        Prompt: prompt,
        MaxTokens: cohere.Uint(2048),
        Temperature: cohere.Float64(float64(c.temperature)),
        K: cohere.Int(0),
@@ -77,42 +67,6 @@ func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl str
    return resp.Generations[0].Text, nil
 }

-func (a *CohereClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-    inputKey := strings.Join(prompt, " ")
-    // Check for cached data
-    cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-    if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-        response, err := cache.Load(cacheKey)
-        if err != nil {
-            return "", err
-        }
-
-        if response != "" {
-            output, err := base64.StdEncoding.DecodeString(response)
-            if err != nil {
-                color.Red("error decoding cached data: %v", err)
-                return "", nil
-            }
-            return string(output), nil
-        }
-    }
-
-    response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-    if err != nil {
-        return "", err
-    }
-
-    err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-    if err != nil {
-        color.Red("error storing value to cache: %v", err)
-        return "", nil
-    }
-
-    return response, nil
-}
-
-func (a *CohereClient) GetName() string {
+func (c *CohereClient) GetName() string {
    return "cohere"
 }
pkg/ai/googlegenai.go (new file, 119 lines)
@@ -0,0 +1,119 @@
+/*
+Copyright 2023 The K8sGPT Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package ai
+
+import (
+    "context"
+    "errors"
+    "fmt"
+
+    "github.com/fatih/color"
+    "github.com/google/generative-ai-go/genai"
+    "google.golang.org/api/option"
+)
+
+const googleAIClientName = "google"
+
+type GoogleGenAIClient struct {
+    client *genai.Client
+
+    model       string
+    temperature float32
+    topP        float32
+    maxTokens   int
+}
+
+func (c *GoogleGenAIClient) Configure(config IAIConfig) error {
+    ctx := context.Background()
+
+    // Access your API key as an environment variable (see "Set up your API key" above)
+    token := config.GetPassword()
+    authOption := option.WithAPIKey(token)
+    if token[0] == '{' {
+        authOption = option.WithCredentialsJSON([]byte(token))
+    }
+
+    client, err := genai.NewClient(ctx, authOption)
+    if err != nil {
+        return fmt.Errorf("creating genai Google SDK client: %w", err)
+    }
+
+    c.client = client
+    c.model = config.GetModel()
+    c.temperature = config.GetTemperature()
+    c.topP = config.GetTopP()
+    c.maxTokens = config.GetMaxTokens()
+    return nil
+}
+
+func (c *GoogleGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
+    // Available models are at https://ai.google.dev/models, e.g. gemini-pro.
+    model := c.client.GenerativeModel(c.model)
+    model.SetTemperature(c.temperature)
+    model.SetTopP(c.topP)
+    model.SetMaxOutputTokens(int32(c.maxTokens))
+
+    // Google AI SDK is capable of different inputs than just text, for now set explicit text prompt type.
+    // Similarly, we could stream the response. For now k8sgpt does not support streaming.
+    resp, err := model.GenerateContent(ctx, genai.Text(prompt))
+    if err != nil {
+        return "", err
+    }
+
+    if len(resp.Candidates) == 0 {
+        if resp.PromptFeedback.BlockReason == genai.BlockReasonSafety {
+            for _, r := range resp.PromptFeedback.SafetyRatings {
+                if !r.Blocked {
+                    continue
+                }
+                return "", fmt.Errorf("completion blocked due to %v with probability %v", r.Category.String(), r.Probability.String())
+            }
+        }
+        return "", errors.New("no completion returned; unknown reason")
+    }
+
+    // Format output.
+    // TODO(bwplotka): Provide richer output in certain cases e.g. suddenly finished
+    // completion based on finish reasons or safety rankings.
+    got := resp.Candidates[0]
+    var output string
+    for _, part := range got.Content.Parts {
+        switch o := part.(type) {
+        case genai.Text:
+            output += string(o)
+            output += "\n"
+        default:
+            color.Yellow("found unsupported AI response part of type %T; ignoring", part)
+        }
+    }
+
+    if got.CitationMetadata != nil && len(got.CitationMetadata.CitationSources) > 0 {
+        output += "Citations:\n"
+        for _, source := range got.CitationMetadata.CitationSources {
+            // TODO(bwplotka): Give details around what exactly words could be attributed to the citation.
+            output += fmt.Sprintf("* %s, %s\n", *source.URI, source.License)
+        }
+    }
+    return output, nil
+}
+
+func (c *GoogleGenAIClient) GetName() string {
+    return googleAIClientName
+}
+
+func (c *GoogleGenAIClient) Close() {
+    if err := c.client.Close(); err != nil {
+        color.Red("googleai client close error: %v", err)
+    }
+}
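For orientation, a self-contained sketch of the genai call pattern the new client relies on. The model name, prompt text, and environment variable are illustrative assumptions; the SDK calls themselves (NewClient, GenerativeModel, GenerateContent, the genai.Text part type) mirror the file above:

    package main

    import (
        "context"
        "fmt"
        "log"
        "os"

        "github.com/google/generative-ai-go/genai"
        "google.golang.org/api/option"
    )

    func main() {
        ctx := context.Background()
        // Auth mirrors Configure above: an API key (or JSON credentials when the
        // secret looks like a JSON blob).
        client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GOOGLE_API_KEY")))
        if err != nil {
            log.Fatalf("creating genai client: %v", err)
        }
        defer client.Close()

        // Model name is an assumption; see https://ai.google.dev/models.
        model := client.GenerativeModel("gemini-pro")
        model.SetTemperature(0.7)

        resp, err := model.GenerateContent(ctx, genai.Text("Explain a CrashLoopBackOff in one sentence."))
        if err != nil {
            log.Fatalf("generate: %v", err)
        }
        // Same extraction loop as GetCompletion above: keep only text parts.
        for _, cand := range resp.Candidates {
            for _, part := range cand.Content.Parts {
                if t, ok := part.(genai.Text); ok {
                    fmt.Println(string(t))
                }
            }
        }
    }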
@@ -15,8 +15,6 @@ package ai

 import (
    "context"
-
-    "github.com/k8sgpt-ai/k8sgpt/pkg/cache"
 )

 var (
@@ -28,25 +26,38 @@ var (
        &CohereClient{},
        &AmazonBedRockClient{},
        &SageMakerAIClient{},
+        &GoogleGenAIClient{},
    }
    Backends = []string{
        "openai",
        "localai",
        "azureopenai",
-        "noopai",
        "cohere",
        "amazonbedrock",
        "amazonsagemaker",
+        googleAIClientName,
+        "noopai",
    }
 )

+// IAI is an interface all clients (representing backends) share.
 type IAI interface {
-    Configure(config IAIConfig, language string) error
-    GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error)
-    Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error)
+    // Configure sets up client for given configuration. This is expected to be
+    // executed once per client life-time (e.g. analysis CLI command invocation).
+    Configure(config IAIConfig) error
+    // GetCompletion generates text based on prompt.
+    GetCompletion(ctx context.Context, prompt string) (string, error)
+    // GetName returns name of the backend/client.
    GetName() string
+    // Close cleans all the resources. No other methods should be used on the
+    // objects after this method is invoked.
+    Close()
 }

+type nopCloser struct{}
+
+func (nopCloser) Close() {}
+
 type IAIConfig interface {
    GetPassword() string
    GetModel() string
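A note on nopCloser: embedding it is how most backends satisfy the new Close() requirement without holding resources (GoogleGenAIClient is the one backend that implements Close() itself). A hypothetical minimal backend against the new interface — EchoClient is invented for illustration, not part of this commit:

    package ai

    import "context"

    // EchoClient: the smallest possible IAI implementation under the new
    // interface. The embedded nopCloser supplies Close().
    type EchoClient struct {
        nopCloser
        model string
    }

    func (c *EchoClient) Configure(config IAIConfig) error {
        c.model = config.GetModel()
        return nil
    }

    func (c *EchoClient) GetCompletion(_ context.Context, prompt string) (string, error) {
        return "echo(" + c.model + "): " + prompt, nil
    }

    func (c *EchoClient) GetName() string { return "echo" }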
@@ -15,58 +15,21 @@ package ai

 import (
    "context"
-    "encoding/base64"
-    "fmt"
-    "strings"
-
-    "github.com/fatih/color"
-    "github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-    "github.com/k8sgpt-ai/k8sgpt/pkg/util"
 )

 type NoOpAIClient struct {
-    client   string
-    language string
-    model    string
+    nopCloser
 }

-func (c *NoOpAIClient) Configure(config IAIConfig, language string) error {
-    token := config.GetPassword()
-    c.language = language
-    c.client = fmt.Sprintf("I am a noop client with the token %s ", token)
-    c.model = config.GetModel()
+func (c *NoOpAIClient) Configure(_ IAIConfig) error {
    return nil
 }

-func (c *NoOpAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
-    // Create a completion request
+func (c *NoOpAIClient) GetCompletion(_ context.Context, prompt string) (string, error) {
    response := "I am a noop response to the prompt " + prompt
    return response, nil
 }

-func (a *NoOpAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-    // parse the text with the AI backend
-    inputKey := strings.Join(prompt, " ")
-    // Check for cached data
-    sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
-    cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)
-
-    response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-    if err != nil {
-        color.Red("error getting completion: %v", err)
-        return "", err
-    }
-
-    err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-    if err != nil {
-        color.Red("error storing value to cache: %v", err)
-        return "", nil
-    }
-
-    return response, nil
-}
-
-func (a *NoOpAIClient) GetName() string {
+func (c *NoOpAIClient) GetName() string {
    return "noopai"
 }
@@ -15,22 +15,15 @@ package ai

 import (
    "context"
-    "encoding/base64"
    "errors"
-    "fmt"
-    "strings"
-
-    "github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-    "github.com/k8sgpt-ai/k8sgpt/pkg/util"

    "github.com/sashabaranov/go-openai"
-
    "github.com/fatih/color"
 )

 type OpenAIClient struct {
+    nopCloser
+
    client      *openai.Client
-    language    string
    model       string
    temperature float32
 }
@@ -43,7 +36,7 @@ const (
    topP = 1.0
 )

-func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
+func (c *OpenAIClient) Configure(config IAIConfig) error {
    token := config.GetPassword()
    defaultConfig := openai.DefaultConfig(token)

@@ -56,24 +49,20 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
    if client == nil {
        return errors.New("error creating OpenAI client")
    }
-    c.language = language
    c.client = client
    c.model = config.GetModel()
    c.temperature = config.GetTemperature()
    return nil
 }

-func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
    // Create a completion request
-    if len(promptTmpl) == 0 {
-        promptTmpl = PromptMap["default"]
-    }
    resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
        Model: c.model,
        Messages: []openai.ChatCompletionMessage{
            {
                Role: "user",
-                Content: fmt.Sprintf(promptTmpl, c.language, prompt),
+                Content: prompt,
            },
        },
        Temperature: c.temperature,
@@ -88,42 +77,6 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptT
    return resp.Choices[0].Message.Content, nil
 }

-func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-    inputKey := strings.Join(prompt, " ")
-    // Check for cached data
-    cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-    if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-        response, err := cache.Load(cacheKey)
-        if err != nil {
-            return "", err
-        }
-
-        if response != "" {
-            output, err := base64.StdEncoding.DecodeString(response)
-            if err != nil {
-                color.Red("error decoding cached data: %v", err)
-                return "", nil
-            }
-            return string(output), nil
-        }
-    }
-
-    response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-    if err != nil {
-        return "", err
-    }
-
-    err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-    if err != nil {
-        color.Red("error storing value to cache: %v", err)
-        return "", nil
-    }
-
-    return response, nil
-}
-
-func (a *OpenAIClient) GetName() string {
+func (c *OpenAIClient) GetName() string {
    return "openai"
 }
@@ -15,12 +15,14 @@ package analysis

 import (
    "context"
+    "encoding/base64"
    "errors"
    "fmt"
    "reflect"
    "strings"
    "sync"

+    "github.com/fatih/color"
    openapi_v2 "github.com/google/gnostic/openapiv2"
    "github.com/k8sgpt-ai/k8sgpt/pkg/ai"
    "github.com/k8sgpt-ai/k8sgpt/pkg/analyzer"
@@ -36,6 +38,7 @@ type Analysis struct {
    Context   context.Context
    Filters   []string
    Client    *kubernetes.Client
+    Language  string
    AIClient  ai.IAI
    Results   []common.Result
    Errors    []string
@@ -95,6 +98,7 @@ func NewAnalysis(
        Context:   context.Background(),
        Filters:   filters,
        Client:    client,
+        Language:  language,
        Namespace: namespace,
        Cache:     cache,
        Explain:   explain,
@@ -134,7 +138,7 @@ func NewAnalysis(
    }

    aiClient := ai.NewClient(aiProvider.Name)
-    if err := aiClient.Configure(&aiProvider, language); err != nil {
+    if err := aiClient.Configure(&aiProvider); err != nil {
        return nil, err
    }
    a.AIClient = aiClient
@@ -269,14 +273,14 @@ func (a *Analysis) GetAIResults(output string, anonymize bool) error {
            }
            texts = append(texts, failure.Text)
        }
-        // If the resource `Kind` comes from a "integration plugin", maybe a customized prompt template will be involved.
-        var promptTemplate string
+        promptTemplate := ai.PromptMap["default"]
+        // If the resource `Kind` comes from an "integration plugin",
+        // maybe a customized prompt template will be involved.
        if prompt, ok := ai.PromptMap[analysis.Kind]; ok {
            promptTemplate = prompt
-        } else {
-            promptTemplate = ai.PromptMap["default"]
        }
-        parsedText, err := a.AIClient.Parse(a.Context, texts, a.Cache, promptTemplate)
+        result, err := a.getAIResultForSanitizedFailures(texts, promptTemplate)
        if err != nil {
            // FIXME: can we avoid checking if output is json multiple times?
            // maybe implement the progress bar better?
@@ -284,23 +288,22 @@ func (a *Analysis) GetAIResults(output string, anonymize bool) error {
                _ = bar.Exit()
            }

-            // Check for exhaustion
+            // Check for exhaustion.
            if strings.Contains(err.Error(), "status code: 429") {
                return fmt.Errorf("exhausted API quota for AI provider %s: %v", a.AIClient.GetName(), err)
-            } else {
-                return fmt.Errorf("failed while calling AI provider %s: %v", a.AIClient.GetName(), err)
            }
+            return fmt.Errorf("failed while calling AI provider %s: %v", a.AIClient.GetName(), err)
        }

        if anonymize {
            for _, failure := range analysis.Error {
                for _, s := range failure.Sensitive {
-                    parsedText = strings.ReplaceAll(parsedText, s.Masked, s.Unmasked)
+                    result = strings.ReplaceAll(result, s.Masked, s.Unmasked)
                }
            }
        }

-        analysis.Details = parsedText
+        analysis.Details = result
        if output != "json" {
            _ = bar.Add(1)
        }
@@ -308,3 +311,44 @@ func (a *Analysis) GetAIResults(output string, anonymize bool) error {
    }
    return nil
 }
+
+func (a *Analysis) getAIResultForSanitizedFailures(texts []string, promptTmpl string) (string, error) {
+    inputKey := strings.Join(texts, " ")
+    // Check for cached data.
+    // TODO(bwplotka): This might depend on model too (or even other client configuration pieces), fix it in later PRs.
+    cacheKey := util.GetCacheKey(a.AIClient.GetName(), a.Language, inputKey)
+
+    if !a.Cache.IsCacheDisabled() && a.Cache.Exists(cacheKey) {
+        response, err := a.Cache.Load(cacheKey)
+        if err != nil {
+            return "", err
+        }
+
+        if response != "" {
+            output, err := base64.StdEncoding.DecodeString(response)
+            if err == nil {
+                return string(output), nil
+            }
+            color.Red("error decoding cached data; ignoring cache item: %v", err)
+        }
+    }
+
+    // Process template.
+    prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey)
+    response, err := a.AIClient.GetCompletion(a.Context, prompt)
+    if err != nil {
+        return "", err
+    }
+
+    if err = a.Cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response))); err != nil {
+        color.Red("error storing value to cache; value won't be cached: %v", err)
+    }
+    return response, nil
+}
+
+func (a *Analysis) Close() {
+    if a.AIClient == nil {
+        return
+    }
+    a.AIClient.Close()
+}
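Since prompt templating now happens only here rather than in each backend, a worked example of the Sprintf step may help. The template text below is a stand-in with the same two-verb shape (language first, then the joined failure text) that the code above assumes; the real templates live in ai.PromptMap:

    package main

    import (
        "fmt"
        "strings"
    )

    func main() {
        // Stand-in template, not the actual k8sgpt default prompt.
        tmpl := "Explain the following Kubernetes error in %s: %s"
        lang := "english"
        inputKey := "Service has no endpoints, expected label app=myapp"

        prompt := fmt.Sprintf(strings.TrimSpace(tmpl), lang, inputKey)
        fmt.Println(prompt)
        // Output:
        // Explain the following Kubernetes error in english: Service has no endpoints, expected label app=myapp
    }

The base64-encoded cache entry is then keyed on backend name, language, and this joined input, so a repeated analysis can skip the provider call entirely.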
@@ -35,10 +35,11 @@ func (h *handler) Analyze(ctx context.Context, i *schemav1.AnalyzeRequest) (
        false, // Kubernetes Doc disabled in server mode
    )
+    config.Context = ctx // Replace context for correct timeouts.
    if err != nil {
        return &schemav1.AnalyzeResponse{}, err
    }
+    defer config.Close()

    config.RunAnalysis()

    if i.Explain {