feat: added Google GenAI client; simplified IAI/clients API surface. (#829)
* refactor: Simplified IAI; made caching and processing consistent.
* feat: Added Google AI API, e.g. for Gemini models.

Signed-off-by: bwplotka <bwplotka@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
Co-authored-by: Thomas Schuetz <38893055+thschue@users.noreply.github.com>
commit e7d41496dd (parent e78ff05419)
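At a glance, the refactor drops each client's own Parse method (every backend carried its own copy of the cache-and-base64 logic) and the language plumbing, leaving Configure, GetCompletion, GetName, and a new Close. Below is a minimal sketch of a caller against the new surface; cfg stands for any IAIConfig implementation and is an assumption here, while ai.NewClient, ai.Backends, and the interface methods are all taken from the diff that follows:

package main

import (
	"context"
	"fmt"

	"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
)

// explain drives the simplified IAI surface: Configure takes no language
// argument anymore, GetCompletion takes no prompt template, and Close is new.
func explain(ctx context.Context, cfg ai.IAIConfig, backend, prompt string) (string, error) {
	client := ai.NewClient(backend) // e.g. the new "google" backend
	if err := client.Configure(cfg); err != nil {
		return "", err
	}
	defer client.Close() // clients may now hold resources worth releasing

	// Prompt templating and caching moved to the caller (pkg/analysis);
	// a client only turns a finished prompt into text.
	return client.GetCompletion(ctx, prompt)
}

func main() { fmt.Println("sketch only; see explain()") }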
@@ -59,6 +59,7 @@ var AnalyzeCmd = &cobra.Command{
 			color.Red("Error: %v", err)
 			os.Exit(1)
 		}
+		defer config.Close()

 		config.RunAnalysis()

@@ -149,7 +149,7 @@ func init() {
 	// add flag for url
 	addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
 	// add flag for endpointName
-	addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, (e.g `endpoint-xxxxxxxxxxxx`)")
+	addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, e.g. `endpoint-xxxxxxxxxxxx` (only for amazonbedrock, amazonsagemaker backends)")
 	// add flag for topP
 	addCmd.Flags().Float32VarP(&topP, "topp", "c", 0.5, "Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability.")
 	// max tokens
@@ -157,7 +157,7 @@ func init() {
 	// add flag for temperature
 	addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
 	// add flag for azure open ai engine/deployment name
-	addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
+	addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
 	//add flag for amazonbedrock region name
-	addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name")
+	addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock backend)")
 }
go.mod (+3)
@@ -31,6 +31,7 @@ require (
 	github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.2.1
 	github.com/aws/aws-sdk-go v1.49.15
 	github.com/cohere-ai/cohere-go v0.2.0
+	github.com/google/generative-ai-go v0.5.0
 	github.com/olekukonko/tablewriter v0.0.5
 	google.golang.org/api v0.155.0
 	sigs.k8s.io/controller-runtime v0.16.3
@@ -39,9 +40,11 @@ require (

 require (
 	cloud.google.com/go v0.110.10 // indirect
+	cloud.google.com/go/ai v0.3.0 // indirect
 	cloud.google.com/go/compute v1.23.3 // indirect
 	cloud.google.com/go/compute/metadata v0.2.3 // indirect
 	cloud.google.com/go/iam v1.1.5 // indirect
+	cloud.google.com/go/longrunning v0.5.4 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.1 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.1 // indirect
 	github.com/AzureAD/microsoft-authentication-library-for-go v1.2.0 // indirect
go.sum (+6)
@@ -49,6 +49,8 @@ cloud.google.com/go/accesscontextmanager v1.3.0/go.mod h1:TgCBehyr5gNMz7ZaH9xubp
 cloud.google.com/go/accesscontextmanager v1.4.0/go.mod h1:/Kjh7BBu/Gh83sv+K60vN9QE5NJcd80sU33vIe2IFPE=
 cloud.google.com/go/accesscontextmanager v1.6.0/go.mod h1:8XCvZWfYw3K/ji0iVnp+6pu7huxoQTLmxAbVjbloTtM=
 cloud.google.com/go/accesscontextmanager v1.7.0/go.mod h1:CEGLewx8dwa33aDAZQujl7Dx+uYhS0eay198wB/VumQ=
+cloud.google.com/go/ai v0.3.0 h1:M617N0brv+XFch2KToZUhv6ggzgFZMUnmDkNQjW2pYg=
+cloud.google.com/go/ai v0.3.0/go.mod h1:dTuQIBA8Kljuas5z1WNot1QZOl476A9TsFqEi6pzJlI=
 cloud.google.com/go/aiplatform v1.22.0/go.mod h1:ig5Nct50bZlzV6NvKaTwmplLLddFx0YReh9WfTO5jKw=
 cloud.google.com/go/aiplatform v1.24.0/go.mod h1:67UUvRBKG6GTayHKV8DBv2RtR1t93YRu5B1P3x99mYY=
 cloud.google.com/go/aiplatform v1.27.0/go.mod h1:Bvxqtl40l0WImSb04d0hXFU7gDOiq9jQmorivIiWcKg=
@@ -351,6 +353,8 @@ cloud.google.com/go/logging v1.7.0/go.mod h1:3xjP2CjkM3ZkO73aj4ASA5wRPGGCRrPIAeN
 cloud.google.com/go/longrunning v0.1.1/go.mod h1:UUFxuDWkv22EuY93jjmDMFT5GPQKeFVJBIF6QlTqdsE=
 cloud.google.com/go/longrunning v0.3.0/go.mod h1:qth9Y41RRSUE69rDcOn6DdK3HfQfsUI0YSmW3iIlLJc=
 cloud.google.com/go/longrunning v0.4.1/go.mod h1:4iWDqhBZ70CvZ6BfETbvam3T8FMvLK+eFj0E6AaRQTo=
+cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg=
+cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI=
 cloud.google.com/go/managedidentities v1.3.0/go.mod h1:UzlW3cBOiPrzucO5qWkNkh0w33KFtBJU281hacNvsdE=
 cloud.google.com/go/managedidentities v1.4.0/go.mod h1:NWSBYbEMgqmbZsLIyKvxrYbtqOsxY1ZrGM+9RgDqInM=
 cloud.google.com/go/managedidentities v1.5.0/go.mod h1:+dWcZ0JlUmpuxpIDfyP5pP5y0bLdRwOS4Lp7gMni/LA=
@@ -935,6 +939,8 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ
 github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
 github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
 github.com/google/flatbuffers v2.0.8+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
+github.com/google/generative-ai-go v0.5.0 h1:PfzPuSGdsmcSyPG7RIoijcKWZ7/x2kvgyNryvmXMUmA=
+github.com/google/generative-ai-go v0.5.0/go.mod h1:8fXQk4w+eyTzFokGGJrBFL0/xwXqm3QNhTqOWyX11zs=
 github.com/google/gnostic v0.7.0 h1:d7EpuFp8vVdML+y0JJJYiKeOLjKTdH/GvVkLOBWqJpw=
 github.com/google/gnostic v0.7.0/go.mod h1:IAcUyMl6vtC95f60EZ8oXyqTsOersP6HbwjeG7EyDPM=
 github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 h1:0VpGH+cDhbDtdcweoyCVsF3fhN8kejK6rFe/2FFX2nU=
@@ -2,15 +2,8 @@ package ai

 import (
 	"context"
-	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"strings"
-
-	"github.com/fatih/color"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"

 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/session"
@@ -19,8 +12,9 @@ import (

 // AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service.
 type AmazonBedRockClient struct {
+	nopCloser
+
 	client      *bedrockruntime.BedrockRuntime
-	language    string
 	model       string
 	temperature float32
 }
@@ -91,8 +85,8 @@ func GetRegionOrDefault(region string) string {
 	return BEDROCK_DEFAULT_REGION
 }

-// Configure configures the AmazonBedRockClient with the provided configuration and language.
-func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error {
+// Configure configures the AmazonBedRockClient with the provided configuration.
+func (a *AmazonBedRockClient) Configure(config IAIConfig) error {

 	// Create a new AWS session
 	providerRegion := GetRegionOrDefault(config.GetProviderRegion())
@@ -107,7 +101,6 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error

 	// Create a new BedrockRuntime client
 	a.client = bedrockruntime.New(sess)
-	a.language = language
 	a.model = GetModelOrDefault(config.GetModel())
 	a.temperature = config.GetTemperature()

@@ -115,7 +108,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error
 }

 // GetCompletion sends a request to the model for generating completion based on the provided prompt.
-func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {

 	// Prepare the input data for the model invocation
 	request := map[string]interface{}{
@@ -152,44 +145,6 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string,
 	return output.Completion, nil
 }

-// Parse generates a completion for the provided prompt using the Amazon Bedrock model.
-func (a *AmazonBedRockClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-	if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-		response, err := cache.Load(cacheKey)
-		if err != nil {
-			return "", err
-		}
-
-		if response != "" {
-			output, err := base64.StdEncoding.DecodeString(response)
-			if err != nil {
-				color.Red("error decoding cached data: %v", err)
-				return "", nil
-			}
-			return string(output), nil
-		}
-	}
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-
-	if err != nil {
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
 // GetName returns the name of the AmazonBedRockClient.
 func (a *AmazonBedRockClient) GetName() string {
 	return "amazonbedrock"
@@ -15,15 +15,8 @@ package ai

 import (
 	"context"
-	"encoding/base64"
-	"fmt"
-	"strings"

 	"encoding/json"
+	"fmt"
-	"github.com/fatih/color"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"

 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/session"
@@ -31,8 +24,9 @@ import (
 )

 type SageMakerAIClient struct {
+	nopCloser
+
 	client      *sagemakerruntime.SageMakerRuntime
-	language    string
 	model       string
 	temperature float32
 	endpoint    string
@@ -63,7 +57,7 @@ type Parameters struct {
 	Temperature float64 `json:"temperature"`
 }

-func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
+func (c *SageMakerAIClient) Configure(config IAIConfig) error {

 	// Create a new AWS session
 	sess := session.Must(session.NewSessionWithOptions(session.Options{
@@ -71,7 +65,6 @@ func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
 		SharedConfigState: session.SharedConfigEnable,
 	}))

-	c.language = language
 	// Create a new SageMaker runtime client
 	c.client = sagemakerruntime.New(sess)
 	c.model = config.GetModel()
@@ -82,18 +75,13 @@ func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
 	return nil
 }

-func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (c *SageMakerAIClient) GetCompletion(_ context.Context, prompt string) (string, error) {
 	// Create a completion request

-	if len(promptTmpl) == 0 {
-		promptTmpl = PromptMap["default"]
-	}
-
 	request := Request{
 		Inputs: [][]Message{
 			{
 				{Role: "system", Content: "DEFAULT_PROMPT"},
-				{Role: "user", Content: fmt.Sprintf(promptTmpl, c.language, prompt)},
+				{Role: "user", Content: prompt},
 			},
 		},

@@ -142,29 +130,6 @@ func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, pr
 	return content, nil
 }

-func (a *SageMakerAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	// parse the text with the AI backend
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		color.Red("error getting completion: %v", err)
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", err
-	}
-
-	return response, nil
-}
-
-func (a *SageMakerAIClient) GetName() string {
+func (c *SageMakerAIClient) GetName() string {
 	return "amazonsagemaker"
 }
@@ -2,27 +2,20 @@ package ai

 import (
 	"context"
-	"encoding/base64"
 	"errors"
-	"fmt"
-	"strings"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
-
-	"github.com/fatih/color"

 	"github.com/sashabaranov/go-openai"
 )

 type AzureAIClient struct {
+	nopCloser
+
 	client      *openai.Client
-	language    string
 	model       string
 	temperature float32
 }

-func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
+func (c *AzureAIClient) Configure(config IAIConfig) error {
 	token := config.GetPassword()
 	baseURL := config.GetBaseURL()
 	engine := config.GetEngine()
@@ -40,21 +33,20 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
 	if client == nil {
 		return errors.New("error creating Azure OpenAI client")
 	}
-	c.language = lang
 	c.client = client
 	c.model = config.GetModel()
 	c.temperature = config.GetTemperature()
 	return nil
 }

-func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
 	// Create a completion request
 	resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
 		Model: c.model,
 		Messages: []openai.ChatCompletionMessage{
 			{
 				Role:    openai.ChatMessageRoleUser,
-				Content: fmt.Sprintf(default_prompt, c.language, prompt),
+				Content: prompt,
 			},
 		},
 		Temperature: c.temperature,
@@ -65,42 +57,6 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, prompt
 	return resp.Choices[0].Message.Content, nil
 }

-func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-	if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-		response, err := cache.Load(cacheKey)
-		if err != nil {
-			return "", err
-		}
-
-		if response != "" {
-			output, err := base64.StdEncoding.DecodeString(response)
-			if err != nil {
-				color.Red("error decoding cached data: %v", err)
-				return "", nil
-			}
-			return string(output), nil
-		}
-	}
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
-func (a *AzureAIClient) GetName() string {
+func (c *AzureAIClient) GetName() string {
 	return "azureopenai"
 }
@@ -15,26 +15,20 @@ package ai

 import (
 	"context"
-	"encoding/base64"
 	"errors"
-	"fmt"
-	"strings"

 	"github.com/cohere-ai/cohere-go"
-	"github.com/fatih/color"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
 )

 type CohereClient struct {
+	nopCloser
+
 	client      *cohere.Client
-	language    string
 	model       string
 	temperature float32
 }

-func (c *CohereClient) Configure(config IAIConfig, language string) error {
+func (c *CohereClient) Configure(config IAIConfig) error {
 	token := config.GetPassword()

 	client, err := cohere.CreateClient(token)
@@ -50,21 +44,17 @@ func (c *CohereClient) Configure(config IAIConfig, language string) error {
 	if client == nil {
 		return errors.New("error creating Cohere client")
 	}
-	c.language = language
 	c.client = client
 	c.model = config.GetModel()
 	c.temperature = config.GetTemperature()
 	return nil
 }

-func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl string) (string, error) {
+func (c *CohereClient) GetCompletion(_ context.Context, prompt string) (string, error) {
 	// Create a completion request
-	if len(promptTmpl) == 0 {
-		promptTmpl = PromptMap["default"]
-	}
 	resp, err := c.client.Generate(cohere.GenerateOptions{
 		Model:             c.model,
-		Prompt:            fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
+		Prompt:            prompt,
 		MaxTokens:         cohere.Uint(2048),
 		Temperature:       cohere.Float64(float64(c.temperature)),
 		K:                 cohere.Int(0),
@@ -77,42 +67,6 @@ func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl str
 	return resp.Generations[0].Text, nil
 }

-func (a *CohereClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-	if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-		response, err := cache.Load(cacheKey)
-		if err != nil {
-			return "", err
-		}
-
-		if response != "" {
-			output, err := base64.StdEncoding.DecodeString(response)
-			if err != nil {
-				color.Red("error decoding cached data: %v", err)
-				return "", nil
-			}
-			return string(output), nil
-		}
-	}
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
-func (a *CohereClient) GetName() string {
+func (c *CohereClient) GetName() string {
 	return "cohere"
 }
pkg/ai/googlegenai.go (new file, 119 lines)
@@ -0,0 +1,119 @@
+/*
+Copyright 2023 The K8sGPT Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package ai
+
+import (
+	"context"
+	"errors"
+	"fmt"
+
+	"github.com/fatih/color"
+	"github.com/google/generative-ai-go/genai"
+	"google.golang.org/api/option"
+)
+
+const googleAIClientName = "google"
+
+type GoogleGenAIClient struct {
+	client *genai.Client
+
+	model       string
+	temperature float32
+	topP        float32
+	maxTokens   int
+}
+
+func (c *GoogleGenAIClient) Configure(config IAIConfig) error {
+	ctx := context.Background()
+
+	// Access your API key as an environment variable (see "Set up your API key" above)
+	token := config.GetPassword()
+	authOption := option.WithAPIKey(token)
+	if token[0] == '{' {
+		authOption = option.WithCredentialsJSON([]byte(token))
+	}
+
+	client, err := genai.NewClient(ctx, authOption)
+	if err != nil {
+		return fmt.Errorf("creating genai Google SDK client: %w", err)
+	}
+
+	c.client = client
+	c.model = config.GetModel()
+	c.temperature = config.GetTemperature()
+	c.topP = config.GetTopP()
+	c.maxTokens = config.GetMaxTokens()
+	return nil
+}
+
+func (c *GoogleGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
+	// Available models are at https://ai.google.dev/models e.g. gemini-pro.
+	model := c.client.GenerativeModel(c.model)
+	model.SetTemperature(c.temperature)
+	model.SetTopP(c.topP)
+	model.SetMaxOutputTokens(int32(c.maxTokens))
+
+	// Google AI SDK is capable of different inputs than just text, for now set explicit text prompt type.
+	// Similarly, we could stream the response. For now k8sgpt does not support streaming.
+	resp, err := model.GenerateContent(ctx, genai.Text(prompt))
+	if err != nil {
+		return "", err
+	}
+
+	if len(resp.Candidates) == 0 {
+		if resp.PromptFeedback.BlockReason == genai.BlockReasonSafety {
+			for _, r := range resp.PromptFeedback.SafetyRatings {
+				if !r.Blocked {
+					continue
+				}
+				return "", fmt.Errorf("completion blocked due to %v with probability %v", r.Category.String(), r.Probability.String())
+			}
+		}
+		return "", errors.New("no completion returned; unknown reason")
+	}
+
+	// Format output.
+	// TODO(bwplotka): Provide richer output in certain cases e.g. suddenly finished
+	// completion based on finish reasons or safety rankings.
+	got := resp.Candidates[0]
+	var output string
+	for _, part := range got.Content.Parts {
+		switch o := part.(type) {
+		case genai.Text:
+			output += string(o)
+			output += "\n"
+		default:
+			color.Yellow("found unsupported AI response part of type %T; ignoring", part)
+		}
+	}
+
+	if got.CitationMetadata != nil && len(got.CitationMetadata.CitationSources) > 0 {
+		output += "Citations:\n"
+		for _, source := range got.CitationMetadata.CitationSources {
+			// TODO(bwplotka): Give details around what exactly words could be attributed to the citation.
+			output += fmt.Sprintf("* %s, %s\n", *source.URI, source.License)
+		}
+	}
+	return output, nil
+}
+
+func (c *GoogleGenAIClient) GetName() string {
+	return googleAIClientName
+}
+
+func (c *GoogleGenAIClient) Close() {
+	if err := c.client.Close(); err != nil {
+		color.Red("googleai client close error: %v", err)
+	}
+}
@@ -15,8 +15,6 @@ package ai

 import (
 	"context"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
 )

 var (
@@ -28,25 +26,38 @@ var (
 		&CohereClient{},
 		&AmazonBedRockClient{},
 		&SageMakerAIClient{},
+		&GoogleGenAIClient{},
 	}
 	Backends = []string{
 		"openai",
 		"localai",
 		"azureopenai",
-		"noopai",
 		"cohere",
 		"amazonbedrock",
 		"amazonsagemaker",
+		googleAIClientName,
+		"noopai",
 	}
 )

+// IAI is an interface all clients (representing backends) share.
 type IAI interface {
-	Configure(config IAIConfig, language string) error
-	GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error)
-	Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error)
+	// Configure sets up client for given configuration. This is expected to be
+	// executed once per client life-time (e.g. analysis CLI command invocation).
+	Configure(config IAIConfig) error
+	// GetCompletion generates text based on prompt.
+	GetCompletion(ctx context.Context, prompt string) (string, error)
+	// GetName returns name of the backend/client.
 	GetName() string
+	// Close cleans all the resources. No other methods should be used on the
+	// objects after this method is invoked.
+	Close()
 }

+type nopCloser struct{}
+
+func (nopCloser) Close() {}
+
 type IAIConfig interface {
 	GetPassword() string
 	GetModel() string
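The embedded nopCloser keeps the new Close requirement cheap for backends that hold nothing; only clients with real resources (like the Google SDK client above) implement their own. A hypothetical minimal backend against the reduced interface; "echo" is not a real k8sgpt backend, just an illustration of the contract:

// echoClient is a hypothetical backend satisfying the new IAI interface.
type echoClient struct {
	nopCloser // provides the no-op Close()

	model string
}

func (c *echoClient) Configure(config IAIConfig) error {
	c.model = config.GetModel()
	return nil
}

func (c *echoClient) GetCompletion(_ context.Context, prompt string) (string, error) {
	// A real client would call its provider here; prompt templating and
	// caching have already happened in pkg/analysis by this point.
	return "echo(" + c.model + "): " + prompt, nil
}

func (c *echoClient) GetName() string { return "echo" }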
@@ -15,58 +15,21 @@ package ai

 import (
 	"context"
-	"encoding/base64"
-	"fmt"
-	"strings"
-
-	"github.com/fatih/color"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
 )

 type NoOpAIClient struct {
-	client   string
-	language string
-	model    string
+	nopCloser
 }

-func (c *NoOpAIClient) Configure(config IAIConfig, language string) error {
-	token := config.GetPassword()
-	c.language = language
-	c.client = fmt.Sprintf("I am a noop client with the token %s ", token)
-	c.model = config.GetModel()
+func (c *NoOpAIClient) Configure(_ IAIConfig) error {
 	return nil
 }

-func (c *NoOpAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
-	// Create a completion request
+func (c *NoOpAIClient) GetCompletion(_ context.Context, prompt string) (string, error) {
 	response := "I am a noop response to the prompt " + prompt
 	return response, nil
 }

-func (a *NoOpAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	// parse the text with the AI backend
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		color.Red("error getting completion: %v", err)
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
-func (a *NoOpAIClient) GetName() string {
+func (c *NoOpAIClient) GetName() string {
 	return "noopai"
 }
@@ -15,22 +15,15 @@ package ai

 import (
 	"context"
-	"encoding/base64"
 	"errors"
-	"fmt"
-	"strings"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"

 	"github.com/sashabaranov/go-openai"
-
-	"github.com/fatih/color"
 )

 type OpenAIClient struct {
+	nopCloser
+
 	client      *openai.Client
-	language    string
 	model       string
 	temperature float32
 }
@@ -43,7 +36,7 @@ const (
 	topP = 1.0
 )

-func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
+func (c *OpenAIClient) Configure(config IAIConfig) error {
 	token := config.GetPassword()
 	defaultConfig := openai.DefaultConfig(token)

@@ -56,24 +49,20 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
 	if client == nil {
 		return errors.New("error creating OpenAI client")
 	}
-	c.language = language
 	c.client = client
 	c.model = config.GetModel()
 	c.temperature = config.GetTemperature()
 	return nil
 }

-func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
 	// Create a completion request
-	if len(promptTmpl) == 0 {
-		promptTmpl = PromptMap["default"]
-	}
 	resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
 		Model: c.model,
 		Messages: []openai.ChatCompletionMessage{
 			{
 				Role:    "user",
-				Content: fmt.Sprintf(promptTmpl, c.language, prompt),
+				Content: prompt,
 			},
 		},
 		Temperature: c.temperature,
@@ -88,42 +77,6 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptT
 	return resp.Choices[0].Message.Content, nil
 }

-func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-	if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-		response, err := cache.Load(cacheKey)
-		if err != nil {
-			return "", err
-		}
-
-		if response != "" {
-			output, err := base64.StdEncoding.DecodeString(response)
-			if err != nil {
-				color.Red("error decoding cached data: %v", err)
-				return "", nil
-			}
-			return string(output), nil
-		}
-	}
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
-func (a *OpenAIClient) GetName() string {
+func (c *OpenAIClient) GetName() string {
 	return "openai"
 }
@@ -15,12 +15,14 @@ package analysis

 import (
 	"context"
+	"encoding/base64"
 	"errors"
 	"fmt"
 	"reflect"
 	"strings"
 	"sync"

+	"github.com/fatih/color"
 	openapi_v2 "github.com/google/gnostic/openapiv2"
 	"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
 	"github.com/k8sgpt-ai/k8sgpt/pkg/analyzer"
@@ -36,6 +38,7 @@ type Analysis struct {
 	Context   context.Context
 	Filters   []string
 	Client    *kubernetes.Client
+	Language  string
 	AIClient  ai.IAI
 	Results   []common.Result
 	Errors    []string
@@ -95,6 +98,7 @@ func NewAnalysis(
 		Context:   context.Background(),
 		Filters:   filters,
 		Client:    client,
+		Language:  language,
 		Namespace: namespace,
 		Cache:     cache,
 		Explain:   explain,
@@ -134,7 +138,7 @@ func NewAnalysis(
 	}

 	aiClient := ai.NewClient(aiProvider.Name)
-	if err := aiClient.Configure(&aiProvider, language); err != nil {
+	if err := aiClient.Configure(&aiProvider); err != nil {
 		return nil, err
 	}
 	a.AIClient = aiClient
@@ -269,14 +273,14 @@ func (a *Analysis) GetAIResults(output string, anonymize bool) error {
 			}
 			texts = append(texts, failure.Text)
 		}
-		// If the resource `Kind` comes from a "integration plugin", maybe a customized prompt template will be involved.
-		var promptTemplate string
+
+		promptTemplate := ai.PromptMap["default"]
+		// If the resource `Kind` comes from an "integration plugin",
+		// maybe a customized prompt template will be involved.
 		if prompt, ok := ai.PromptMap[analysis.Kind]; ok {
 			promptTemplate = prompt
-		} else {
-			promptTemplate = ai.PromptMap["default"]
 		}
-		parsedText, err := a.AIClient.Parse(a.Context, texts, a.Cache, promptTemplate)
+		result, err := a.getAIResultForSanitizedFailures(texts, promptTemplate)
 		if err != nil {
 			// FIXME: can we avoid checking if output is json multiple times?
 			// maybe implement the progress bar better?
@@ -284,23 +288,22 @@ func (a *Analysis) GetAIResults(output string, anonymize bool) error {
 				_ = bar.Exit()
 			}

-			// Check for exhaustion
+			// Check for exhaustion.
 			if strings.Contains(err.Error(), "status code: 429") {
 				return fmt.Errorf("exhausted API quota for AI provider %s: %v", a.AIClient.GetName(), err)
-			} else {
-				return fmt.Errorf("failed while calling AI provider %s: %v", a.AIClient.GetName(), err)
 			}
+			return fmt.Errorf("failed while calling AI provider %s: %v", a.AIClient.GetName(), err)
 		}

 		if anonymize {
 			for _, failure := range analysis.Error {
 				for _, s := range failure.Sensitive {
-					parsedText = strings.ReplaceAll(parsedText, s.Masked, s.Unmasked)
+					result = strings.ReplaceAll(result, s.Masked, s.Unmasked)
 				}
 			}
 		}

-		analysis.Details = parsedText
+		analysis.Details = result
 		if output != "json" {
 			_ = bar.Add(1)
 		}
@@ -308,3 +311,44 @@ func (a *Analysis) GetAIResults(output string, anonymize bool) error {
 	}
 	return nil
 }
+
+func (a *Analysis) getAIResultForSanitizedFailures(texts []string, promptTmpl string) (string, error) {
+	inputKey := strings.Join(texts, " ")
+	// Check for cached data.
+	// TODO(bwplotka): This might depend on model too (or even other client configuration pieces), fix it in later PRs.
+	cacheKey := util.GetCacheKey(a.AIClient.GetName(), a.Language, inputKey)
+
+	if !a.Cache.IsCacheDisabled() && a.Cache.Exists(cacheKey) {
+		response, err := a.Cache.Load(cacheKey)
+		if err != nil {
+			return "", err
+		}
+
+		if response != "" {
+			output, err := base64.StdEncoding.DecodeString(response)
+			if err == nil {
+				return string(output), nil
+			}
+			color.Red("error decoding cached data; ignoring cache item: %v", err)
+		}
+	}
+
+	// Process template.
+	prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey)
+	response, err := a.AIClient.GetCompletion(a.Context, prompt)
+	if err != nil {
+		return "", err
+	}
+
+	if err = a.Cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response))); err != nil {
+		color.Red("error storing value to cache; value won't be cached: %v", err)
+	}
+	return response, nil
+}
+
+func (a *Analysis) Close() {
+	if a.AIClient == nil {
+		return
+	}
+	a.AIClient.Close()
+}
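The helper above centralizes what every client previously duplicated: join the failure texts into one input key, apply the prompt template exactly once via Sprintf (language first, then the joined text), and round-trip the result through the cache base64-encoded. A tiny runnable illustration of that expansion; the template string here is a stand-in, since the real entries in ai.PromptMap are not quoted in this diff:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Stand-in template with the same verb shape as ai.PromptMap entries:
	// first %s is the language, second %s is the joined failure text.
	promptTmpl := "Explain the following Kubernetes error in %s: %s\n"
	prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), "english", "Service has no endpoints")
	fmt.Println(prompt)
	// Prints: Explain the following Kubernetes error in english: Service has no endpoints
}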
@@ -35,10 +35,11 @@ func (h *handler) Analyze(ctx context.Context, i *schemav1.AnalyzeRequest) (
 		false, // Kubernetes Doc disabled in server mode
 	)
 	config.Context = ctx // Replace context for correct timeouts.

 	if err != nil {
 		return &schemav1.AnalyzeResponse{}, err
 	}
+	defer config.Close()

 	config.RunAnalysis()

 	if i.Explain {