feat: added Google GenAI client; simplified IAI/clients API surface. (#829)

* refactor: Simplified IAI; made caching and processing consisent.


Signed-off-by: bwplotka <bwplotka@gmail.com>

* feat: Added Google AI API e.g. for Gemini models.

Signed-off-by: bwplotka <bwplotka@gmail.com>

---------

Signed-off-by: bwplotka <bwplotka@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
Co-authored-by: Thomas Schuetz <38893055+thschue@users.noreply.github.com>
This commit is contained in:
Bartlomiej Plotka
2024-01-05 06:53:36 +01:00
committed by GitHub
parent e78ff05419
commit e7d41496dd
14 changed files with 240 additions and 309 deletions

View File

@@ -2,15 +2,8 @@ package ai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"github.com/fatih/color"
"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
"github.com/k8sgpt-ai/k8sgpt/pkg/util"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
@@ -19,8 +12,9 @@ import (
// AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service.
type AmazonBedRockClient struct {
nopCloser
client *bedrockruntime.BedrockRuntime
language string
model string
temperature float32
}
@@ -91,8 +85,8 @@ func GetRegionOrDefault(region string) string {
return BEDROCK_DEFAULT_REGION
}
// Configure configures the AmazonBedRockClient with the provided configuration and language.
func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error {
// Configure configures the AmazonBedRockClient with the provided configuration.
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// Create a new AWS session
providerRegion := GetRegionOrDefault(config.GetProviderRegion())
@@ -107,7 +101,6 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error
// Create a new BedrockRuntime client
a.client = bedrockruntime.New(sess)
a.language = language
a.model = GetModelOrDefault(config.GetModel())
a.temperature = config.GetTemperature()
@@ -115,7 +108,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error
}
// GetCompletion sends a request to the model for generating completion based on the provided prompt.
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
// Prepare the input data for the model invocation
request := map[string]interface{}{
@@ -152,44 +145,6 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string,
return output.Completion, nil
}
// Parse generates a completion for the provided prompt using the Amazon Bedrock model.
func (a *AmazonBedRockClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
inputKey := strings.Join(prompt, " ")
// Check for cached data
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
response, err := cache.Load(cacheKey)
if err != nil {
return "", err
}
if response != "" {
output, err := base64.StdEncoding.DecodeString(response)
if err != nil {
color.Red("error decoding cached data: %v", err)
return "", nil
}
return string(output), nil
}
}
response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil {
return "", err
}
err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
if err != nil {
color.Red("error storing value to cache: %v", err)
return "", nil
}
return response, nil
}
// GetName returns the name of the AmazonBedRockClient.
func (a *AmazonBedRockClient) GetName() string {
return "amazonbedrock"