feat: call bedrock with inference profile (#1449)

* call bedrock with inference profile

Signed-off-by: Tony Chen <tony_chen@discovery.com>

* add validation and test

Signed-off-by: Tony Chen <tony_chen@discovery.com>

* update test

Signed-off-by: Tony Chen <tony_chen@discovery.com>

---------

Signed-off-by: Tony Chen <tony_chen@discovery.com>
This commit is contained in:
ju187 2025-04-10 23:11:38 -07:00 committed by GitHub
parent 766b51cd3e
commit 91d423b147
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 1 deletions

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"os"
"regexp"
"strings"
"github.com/aws/aws-sdk-go/service/bedrockruntime/bedrockruntimeiface"
@ -293,6 +294,11 @@ func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support
strings.Contains(modelConfigNameLower, modelLower) || strings.Contains(modelLower, modelConfigNameLower) {
// Create a copy to avoid returning a pointer to a loop variable
modelCopy := a.models[i]
// for partial match, set the model name to the input string if it is a valid ARN
if validateModelArn(modelLower) {
modelCopy.Config.ModelName = modelLower
}
return &modelCopy, nil
}
}
@ -300,6 +306,11 @@ func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support
return nil, fmt.Errorf("model '%s' not found in supported models", model)
}
func validateModelArn(model string) bool {
var re = regexp.MustCompile(`(?m)^arn:(?P<Partition>[^:\n]*):bedrock:(?P<Region>[^:\n]*):(?P<AccountID>[^:\n]*):(?P<Ignore>(?P<ResourceType>[^:\/\n]*)[:\/])?(?P<Resource>.*)$`)
return re.MatchString(model)
}
// Configure configures the AmazonBedRockClient with the provided configuration.
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// Initialize models if not already initialized

View File

@ -41,7 +41,15 @@ func TestBedrockModelConfig(t *testing.T) {
assert.Equal(t, foundModel.Config.MaxTokens, 100)
assert.Equal(t, foundModel.Config.Temperature, float32(0.5))
assert.Equal(t, foundModel.Config.TopP, float32(0.9))
assert.Equal(t, foundModel.Config.ModelName, "anthropic.claude-3-5-sonnet-20240620-v1:0")
assert.Equal(t, foundModel.Config.ModelName, "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0")
}
func TestBedrockInvalidModel(t *testing.T) {
client := &AmazonBedRockClient{models: testModels}
foundModel, err := client.getModelFromString("arn:aws:s3:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0")
assert.Nil(t, err, "Error should be nil")
assert.Equal(t, foundModel.Config.MaxTokens, 100)
}
func TestGetModelFromString(t *testing.T) {