From 0f700f0cd39bf5881d6c05240b842f4df7a6c016 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Fri, 20 Jun 2025 16:30:50 +0100 Subject: [PATCH] chore: model name (#1535) * feat: added cache purge Signed-off-by: AlexsJones * feat: improved AWS creds errors Signed-off-by: AlexsJones * chore: removed model name Signed-off-by: AlexsJones * chore: updated tests Signed-off-by: AlexsJones --------- Signed-off-by: AlexsJones --- pkg/ai/amazonbedrock.go | 6 ++++++ pkg/ai/bedrock_support/completions.go | 6 ++---- pkg/ai/bedrock_support/completions_test.go | 15 --------------- pkg/cache/s3_based.go | 18 +++++++++++++----- 4 files changed, 21 insertions(+), 24 deletions(-) diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go index 6b7a11c..a0cb729 100644 --- a/pkg/ai/amazonbedrock.go +++ b/pkg/ai/amazonbedrock.go @@ -380,6 +380,9 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error { awsconfig.WithRegion(region), ) if err != nil { + if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") { + return fmt.Errorf("AWS credentials are invalid or missing. Please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config. Details: %v", err) + } return fmt.Errorf("failed to load AWS config for region %s: %w", region, err) } @@ -523,6 +526,9 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) // Invoke the model resp, err := a.client.InvokeModel(ctx, params) if err != nil { + if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") { + return "", fmt.Errorf("AWS credentials are invalid or missing. Please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config. Details: %v", err) + } return "", err } diff --git a/pkg/ai/bedrock_support/completions.go b/pkg/ai/bedrock_support/completions.go index 4692ffe..92c5079 100644 --- a/pkg/ai/bedrock_support/completions.go +++ b/pkg/ai/bedrock_support/completions.go @@ -88,9 +88,8 @@ func IsModelSupported(modelName string, supportedModels []string) bool { // Note: The caller should check model support before calling GetCompletion. func (a *AmazonCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) { - // Defensive: if the model is not supported, return an error - if a == nil || modelConfig.ModelName == "unsupported-model" { - return nil, fmt.Errorf("model %s is not supported", modelConfig.ModelName) + if a == nil || modelConfig.ModelName == "" { + return nil, fmt.Errorf("no model name provided to Bedrock completion") } if strings.Contains(modelConfig.ModelName, "nova") { return a.GetNovaCompletion(ctx, prompt, modelConfig) @@ -113,7 +112,6 @@ func (a *AmazonCompletion) GetDefaultCompletion(ctx context.Context, prompt stri return []byte{}, err } return body, nil - } func (a *AmazonCompletion) GetNovaCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) { diff --git a/pkg/ai/bedrock_support/completions_test.go b/pkg/ai/bedrock_support/completions_test.go index 0b03469..165040d 100644 --- a/pkg/ai/bedrock_support/completions_test.go +++ b/pkg/ai/bedrock_support/completions_test.go @@ -158,21 +158,6 @@ func TestAmazonCompletion_GetCompletion_Default(t *testing.T) { assert.Equal(t, 0.7, textConfig["topP"]) } -func TestAmazonCompletion_GetCompletion_UnsupportedModel(t *testing.T) { - completion := &AmazonCompletion{} - modelConfig := BedrockModelConfig{ - MaxTokens: 200, - Temperature: 0.5, - TopP: 0.7, - ModelName: "unsupported-model", - } - prompt := "Test prompt" - - _, err := completion.GetCompletion(context.Background(), prompt, modelConfig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "model unsupported-model is not supported") -} - func TestAmazonCompletion_GetCompletion_Inference_Profile(t *testing.T) { completion := &AmazonCompletion{} modelConfig := BedrockModelConfig{ diff --git a/pkg/cache/s3_based.go b/pkg/cache/s3_based.go index f3a8b56..3932389 100644 --- a/pkg/cache/s3_based.go +++ b/pkg/cache/s3_based.go @@ -3,8 +3,9 @@ package cache import ( "bytes" "crypto/tls" - "log" + "errors" "net/http" + "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" @@ -27,16 +28,19 @@ type S3CacheConfiguration struct { func (s *S3Cache) Configure(cacheInfo CacheProvider) error { if cacheInfo.S3.BucketName == "" { - log.Fatal("Bucket name not configured") + return errors.New("bucket name not configured") } s.bucketName = cacheInfo.S3.BucketName - sess := session.Must(session.NewSessionWithOptions(session.Options{ + sess, err := session.NewSessionWithOptions(session.Options{ SharedConfigState: session.SharedConfigEnable, Config: aws.Config{ Region: aws.String(cacheInfo.S3.Region), }, - })) + }) + if err != nil { + return errors.New("failed to create AWS session; please check your AWS credentials and configuration: " + err.Error()) + } if cacheInfo.S3.Endpoint != "" { sess.Config.Endpoint = &cacheInfo.S3.Endpoint sess.Config.S3ForcePathStyle = aws.Bool(true) @@ -50,10 +54,14 @@ func (s *S3Cache) Configure(cacheInfo CacheProvider) error { s3Client := s3.New(sess) // Check if the bucket exists, if not create it - _, err := s3Client.HeadBucket(&s3.HeadBucketInput{ + _, err = s3Client.HeadBucket(&s3.HeadBucketInput{ Bucket: aws.String(cacheInfo.S3.BucketName), }) if err != nil { + // Check for AWS credentials error + if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") { + return errors.New("aws credentials are invalid or missing; please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config") + } _, err = s3Client.CreateBucket(&s3.CreateBucketInput{ Bucket: aws.String(cacheInfo.S3.BucketName), })