mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-09-25 15:00:34 +00:00
chore: model name (#1535)
* feat: added cache purge Signed-off-by: AlexsJones <alexsimonjones@gmail.com> * feat: improved AWS creds errors Signed-off-by: AlexsJones <alexsimonjones@gmail.com> * chore: removed model name Signed-off-by: AlexsJones <alexsimonjones@gmail.com> * chore: updated tests Signed-off-by: AlexsJones <alexsimonjones@gmail.com> --------- Signed-off-by: AlexsJones <alexsimonjones@gmail.com>
This commit is contained in:
@@ -380,6 +380,9 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
|||||||
awsconfig.WithRegion(region),
|
awsconfig.WithRegion(region),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") {
|
||||||
|
return fmt.Errorf("AWS credentials are invalid or missing. Please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config. Details: %v", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("failed to load AWS config for region %s: %w", region, err)
|
return fmt.Errorf("failed to load AWS config for region %s: %w", region, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -523,6 +526,9 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
|
|||||||
// Invoke the model
|
// Invoke the model
|
||||||
resp, err := a.client.InvokeModel(ctx, params)
|
resp, err := a.client.InvokeModel(ctx, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") {
|
||||||
|
return "", fmt.Errorf("AWS credentials are invalid or missing. Please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config. Details: %v", err)
|
||||||
|
}
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -88,9 +88,8 @@ func IsModelSupported(modelName string, supportedModels []string) bool {
|
|||||||
|
|
||||||
// Note: The caller should check model support before calling GetCompletion.
|
// Note: The caller should check model support before calling GetCompletion.
|
||||||
func (a *AmazonCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
|
func (a *AmazonCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
|
||||||
// Defensive: if the model is not supported, return an error
|
if a == nil || modelConfig.ModelName == "" {
|
||||||
if a == nil || modelConfig.ModelName == "unsupported-model" {
|
return nil, fmt.Errorf("no model name provided to Bedrock completion")
|
||||||
return nil, fmt.Errorf("model %s is not supported", modelConfig.ModelName)
|
|
||||||
}
|
}
|
||||||
if strings.Contains(modelConfig.ModelName, "nova") {
|
if strings.Contains(modelConfig.ModelName, "nova") {
|
||||||
return a.GetNovaCompletion(ctx, prompt, modelConfig)
|
return a.GetNovaCompletion(ctx, prompt, modelConfig)
|
||||||
@@ -113,7 +112,6 @@ func (a *AmazonCompletion) GetDefaultCompletion(ctx context.Context, prompt stri
|
|||||||
return []byte{}, err
|
return []byte{}, err
|
||||||
}
|
}
|
||||||
return body, nil
|
return body, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AmazonCompletion) GetNovaCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
|
func (a *AmazonCompletion) GetNovaCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
|
||||||
|
@@ -158,21 +158,6 @@ func TestAmazonCompletion_GetCompletion_Default(t *testing.T) {
|
|||||||
assert.Equal(t, 0.7, textConfig["topP"])
|
assert.Equal(t, 0.7, textConfig["topP"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAmazonCompletion_GetCompletion_UnsupportedModel(t *testing.T) {
|
|
||||||
completion := &AmazonCompletion{}
|
|
||||||
modelConfig := BedrockModelConfig{
|
|
||||||
MaxTokens: 200,
|
|
||||||
Temperature: 0.5,
|
|
||||||
TopP: 0.7,
|
|
||||||
ModelName: "unsupported-model",
|
|
||||||
}
|
|
||||||
prompt := "Test prompt"
|
|
||||||
|
|
||||||
_, err := completion.GetCompletion(context.Background(), prompt, modelConfig)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), "model unsupported-model is not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAmazonCompletion_GetCompletion_Inference_Profile(t *testing.T) {
|
func TestAmazonCompletion_GetCompletion_Inference_Profile(t *testing.T) {
|
||||||
completion := &AmazonCompletion{}
|
completion := &AmazonCompletion{}
|
||||||
modelConfig := BedrockModelConfig{
|
modelConfig := BedrockModelConfig{
|
||||||
|
18
pkg/cache/s3_based.go
vendored
18
pkg/cache/s3_based.go
vendored
@@ -3,8 +3,9 @@ package cache
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"log"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
"github.com/aws/aws-sdk-go/aws/session"
|
"github.com/aws/aws-sdk-go/aws/session"
|
||||||
@@ -27,16 +28,19 @@ type S3CacheConfiguration struct {
|
|||||||
|
|
||||||
func (s *S3Cache) Configure(cacheInfo CacheProvider) error {
|
func (s *S3Cache) Configure(cacheInfo CacheProvider) error {
|
||||||
if cacheInfo.S3.BucketName == "" {
|
if cacheInfo.S3.BucketName == "" {
|
||||||
log.Fatal("Bucket name not configured")
|
return errors.New("bucket name not configured")
|
||||||
}
|
}
|
||||||
s.bucketName = cacheInfo.S3.BucketName
|
s.bucketName = cacheInfo.S3.BucketName
|
||||||
|
|
||||||
sess := session.Must(session.NewSessionWithOptions(session.Options{
|
sess, err := session.NewSessionWithOptions(session.Options{
|
||||||
SharedConfigState: session.SharedConfigEnable,
|
SharedConfigState: session.SharedConfigEnable,
|
||||||
Config: aws.Config{
|
Config: aws.Config{
|
||||||
Region: aws.String(cacheInfo.S3.Region),
|
Region: aws.String(cacheInfo.S3.Region),
|
||||||
},
|
},
|
||||||
}))
|
})
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("failed to create AWS session; please check your AWS credentials and configuration: " + err.Error())
|
||||||
|
}
|
||||||
if cacheInfo.S3.Endpoint != "" {
|
if cacheInfo.S3.Endpoint != "" {
|
||||||
sess.Config.Endpoint = &cacheInfo.S3.Endpoint
|
sess.Config.Endpoint = &cacheInfo.S3.Endpoint
|
||||||
sess.Config.S3ForcePathStyle = aws.Bool(true)
|
sess.Config.S3ForcePathStyle = aws.Bool(true)
|
||||||
@@ -50,10 +54,14 @@ func (s *S3Cache) Configure(cacheInfo CacheProvider) error {
|
|||||||
s3Client := s3.New(sess)
|
s3Client := s3.New(sess)
|
||||||
|
|
||||||
// Check if the bucket exists, if not create it
|
// Check if the bucket exists, if not create it
|
||||||
_, err := s3Client.HeadBucket(&s3.HeadBucketInput{
|
_, err = s3Client.HeadBucket(&s3.HeadBucketInput{
|
||||||
Bucket: aws.String(cacheInfo.S3.BucketName),
|
Bucket: aws.String(cacheInfo.S3.BucketName),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Check for AWS credentials error
|
||||||
|
if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") {
|
||||||
|
return errors.New("aws credentials are invalid or missing; please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config")
|
||||||
|
}
|
||||||
_, err = s3Client.CreateBucket(&s3.CreateBucketInput{
|
_, err = s3Client.CreateBucket(&s3.CreateBucketInput{
|
||||||
Bucket: aws.String(cacheInfo.S3.BucketName),
|
Bucket: aws.String(cacheInfo.S3.BucketName),
|
||||||
})
|
})
|
||||||
|
Reference in New Issue
Block a user