chore: model name (#1535)

* feat: added cache purge

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>

* feat: improved AWS creds errors

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>

* chore: removed model name

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>

* chore: updated tests

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>

---------

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>
This commit is contained in:
Alex Jones
2025-06-20 16:30:50 +01:00
committed by GitHub
parent 74fbde0053
commit 0f700f0cd3
4 changed files with 21 additions and 24 deletions

View File

@@ -380,6 +380,9 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
awsconfig.WithRegion(region), awsconfig.WithRegion(region),
) )
if err != nil { if err != nil {
if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") {
return fmt.Errorf("AWS credentials are invalid or missing. Please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config. Details: %v", err)
}
return fmt.Errorf("failed to load AWS config for region %s: %w", region, err) return fmt.Errorf("failed to load AWS config for region %s: %w", region, err)
} }
@@ -523,6 +526,9 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
// Invoke the model // Invoke the model
resp, err := a.client.InvokeModel(ctx, params) resp, err := a.client.InvokeModel(ctx, params)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") {
return "", fmt.Errorf("AWS credentials are invalid or missing. Please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config. Details: %v", err)
}
return "", err return "", err
} }

View File

@@ -88,9 +88,8 @@ func IsModelSupported(modelName string, supportedModels []string) bool {
// Note: The caller should check model support before calling GetCompletion. // Note: The caller should check model support before calling GetCompletion.
func (a *AmazonCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) { func (a *AmazonCompletion) GetCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {
// Defensive: if the model is not supported, return an error if a == nil || modelConfig.ModelName == "" {
if a == nil || modelConfig.ModelName == "unsupported-model" { return nil, fmt.Errorf("no model name provided to Bedrock completion")
return nil, fmt.Errorf("model %s is not supported", modelConfig.ModelName)
} }
if strings.Contains(modelConfig.ModelName, "nova") { if strings.Contains(modelConfig.ModelName, "nova") {
return a.GetNovaCompletion(ctx, prompt, modelConfig) return a.GetNovaCompletion(ctx, prompt, modelConfig)
@@ -113,7 +112,6 @@ func (a *AmazonCompletion) GetDefaultCompletion(ctx context.Context, prompt stri
return []byte{}, err return []byte{}, err
} }
return body, nil return body, nil
} }
func (a *AmazonCompletion) GetNovaCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) { func (a *AmazonCompletion) GetNovaCompletion(ctx context.Context, prompt string, modelConfig BedrockModelConfig) ([]byte, error) {

View File

@@ -158,21 +158,6 @@ func TestAmazonCompletion_GetCompletion_Default(t *testing.T) {
assert.Equal(t, 0.7, textConfig["topP"]) assert.Equal(t, 0.7, textConfig["topP"])
} }
func TestAmazonCompletion_GetCompletion_UnsupportedModel(t *testing.T) {
completion := &AmazonCompletion{}
modelConfig := BedrockModelConfig{
MaxTokens: 200,
Temperature: 0.5,
TopP: 0.7,
ModelName: "unsupported-model",
}
prompt := "Test prompt"
_, err := completion.GetCompletion(context.Background(), prompt, modelConfig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "model unsupported-model is not supported")
}
func TestAmazonCompletion_GetCompletion_Inference_Profile(t *testing.T) { func TestAmazonCompletion_GetCompletion_Inference_Profile(t *testing.T) {
completion := &AmazonCompletion{} completion := &AmazonCompletion{}
modelConfig := BedrockModelConfig{ modelConfig := BedrockModelConfig{

18
pkg/cache/s3_based.go vendored
View File

@@ -3,8 +3,9 @@ package cache
import ( import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"log" "errors"
"net/http" "net/http"
"strings"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
@@ -27,16 +28,19 @@ type S3CacheConfiguration struct {
func (s *S3Cache) Configure(cacheInfo CacheProvider) error { func (s *S3Cache) Configure(cacheInfo CacheProvider) error {
if cacheInfo.S3.BucketName == "" { if cacheInfo.S3.BucketName == "" {
log.Fatal("Bucket name not configured") return errors.New("bucket name not configured")
} }
s.bucketName = cacheInfo.S3.BucketName s.bucketName = cacheInfo.S3.BucketName
sess := session.Must(session.NewSessionWithOptions(session.Options{ sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable, SharedConfigState: session.SharedConfigEnable,
Config: aws.Config{ Config: aws.Config{
Region: aws.String(cacheInfo.S3.Region), Region: aws.String(cacheInfo.S3.Region),
}, },
})) })
if err != nil {
return errors.New("failed to create AWS session; please check your AWS credentials and configuration: " + err.Error())
}
if cacheInfo.S3.Endpoint != "" { if cacheInfo.S3.Endpoint != "" {
sess.Config.Endpoint = &cacheInfo.S3.Endpoint sess.Config.Endpoint = &cacheInfo.S3.Endpoint
sess.Config.S3ForcePathStyle = aws.Bool(true) sess.Config.S3ForcePathStyle = aws.Bool(true)
@@ -50,10 +54,14 @@ func (s *S3Cache) Configure(cacheInfo CacheProvider) error {
s3Client := s3.New(sess) s3Client := s3.New(sess)
// Check if the bucket exists, if not create it // Check if the bucket exists, if not create it
_, err := s3Client.HeadBucket(&s3.HeadBucketInput{ _, err = s3Client.HeadBucket(&s3.HeadBucketInput{
Bucket: aws.String(cacheInfo.S3.BucketName), Bucket: aws.String(cacheInfo.S3.BucketName),
}) })
if err != nil { if err != nil {
// Check for AWS credentials error
if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") {
return errors.New("aws credentials are invalid or missing; please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config")
}
_, err = s3Client.CreateBucket(&s3.CreateBucketInput{ _, err = s3Client.CreateBucket(&s3.CreateBucketInput{
Bucket: aws.String(cacheInfo.S3.BucketName), Bucket: aws.String(cacheInfo.S3.BucketName),
}) })