mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-04-27 11:11:31 +00:00
feat: call bedrock with inference profile (#1449)
* call bedrock with inference profile Signed-off-by: Tony Chen <tony_chen@discovery.com> * add validation and test Signed-off-by: Tony Chen <tony_chen@discovery.com> * update test Signed-off-by: Tony Chen <tony_chen@discovery.com> --------- Signed-off-by: Tony Chen <tony_chen@discovery.com>
This commit is contained in:
parent
766b51cd3e
commit
91d423b147
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/bedrockruntime/bedrockruntimeiface"
|
||||
@ -293,6 +294,11 @@ func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support
|
||||
strings.Contains(modelConfigNameLower, modelLower) || strings.Contains(modelLower, modelConfigNameLower) {
|
||||
// Create a copy to avoid returning a pointer to a loop variable
|
||||
modelCopy := a.models[i]
|
||||
// for partial match, set the model name to the input string if it is a valid ARN
|
||||
if validateModelArn(modelLower) {
|
||||
modelCopy.Config.ModelName = modelLower
|
||||
}
|
||||
|
||||
return &modelCopy, nil
|
||||
}
|
||||
}
|
||||
@ -300,6 +306,11 @@ func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support
|
||||
return nil, fmt.Errorf("model '%s' not found in supported models", model)
|
||||
}
|
||||
|
||||
func validateModelArn(model string) bool {
|
||||
var re = regexp.MustCompile(`(?m)^arn:(?P<Partition>[^:\n]*):bedrock:(?P<Region>[^:\n]*):(?P<AccountID>[^:\n]*):(?P<Ignore>(?P<ResourceType>[^:\/\n]*)[:\/])?(?P<Resource>.*)$`)
|
||||
return re.MatchString(model)
|
||||
}
|
||||
|
||||
// Configure configures the AmazonBedRockClient with the provided configuration.
|
||||
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
||||
// Initialize models if not already initialized
|
||||
|
@ -41,7 +41,15 @@ func TestBedrockModelConfig(t *testing.T) {
|
||||
assert.Equal(t, foundModel.Config.MaxTokens, 100)
|
||||
assert.Equal(t, foundModel.Config.Temperature, float32(0.5))
|
||||
assert.Equal(t, foundModel.Config.TopP, float32(0.9))
|
||||
assert.Equal(t, foundModel.Config.ModelName, "anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
assert.Equal(t, foundModel.Config.ModelName, "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
}
|
||||
|
||||
func TestBedrockInvalidModel(t *testing.T) {
|
||||
client := &AmazonBedRockClient{models: testModels}
|
||||
|
||||
foundModel, err := client.getModelFromString("arn:aws:s3:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
assert.Nil(t, err, "Error should be nil")
|
||||
assert.Equal(t, foundModel.Config.MaxTokens, 100)
|
||||
}
|
||||
|
||||
func TestGetModelFromString(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user