mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-09-01 07:19:19 +00:00
feat: call bedrock with inference profile (#1449)
* call bedrock with inference profile Signed-off-by: Tony Chen <tony_chen@discovery.com> * add validation and test Signed-off-by: Tony Chen <tony_chen@discovery.com> * update test Signed-off-by: Tony Chen <tony_chen@discovery.com> --------- Signed-off-by: Tony Chen <tony_chen@discovery.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/service/bedrockruntime/bedrockruntimeiface"
|
"github.com/aws/aws-sdk-go/service/bedrockruntime/bedrockruntimeiface"
|
||||||
@@ -293,6 +294,11 @@ func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support
|
|||||||
strings.Contains(modelConfigNameLower, modelLower) || strings.Contains(modelLower, modelConfigNameLower) {
|
strings.Contains(modelConfigNameLower, modelLower) || strings.Contains(modelLower, modelConfigNameLower) {
|
||||||
// Create a copy to avoid returning a pointer to a loop variable
|
// Create a copy to avoid returning a pointer to a loop variable
|
||||||
modelCopy := a.models[i]
|
modelCopy := a.models[i]
|
||||||
|
// for partial match, set the model name to the input string if it is a valid ARN
|
||||||
|
if validateModelArn(modelLower) {
|
||||||
|
modelCopy.Config.ModelName = modelLower
|
||||||
|
}
|
||||||
|
|
||||||
return &modelCopy, nil
|
return &modelCopy, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,6 +306,11 @@ func (a *AmazonBedRockClient) getModelFromString(model string) (*bedrock_support
|
|||||||
return nil, fmt.Errorf("model '%s' not found in supported models", model)
|
return nil, fmt.Errorf("model '%s' not found in supported models", model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bedrockModelARNPattern matches a single-line Amazon Bedrock ARN of the form
//
//	arn:<partition>:bedrock:<region>:<account-id>:[<resource-type>(:|/)]<resource>
//
// It is compiled once at package initialization rather than on every call.
// The pattern is anchored to the whole input (no multi-line flag), so an ARN
// embedded in a larger multi-line string, or one followed by a trailing
// newline, does not match.
var bedrockModelARNPattern = regexp.MustCompile(`^arn:(?P<Partition>[^:\n]*):bedrock:(?P<Region>[^:\n]*):(?P<AccountID>[^:\n]*):(?P<Ignore>(?P<ResourceType>[^:\/\n]*)[:\/])?(?P<Resource>.*)$`)

// validateModelArn reports whether model, taken as a whole, is shaped like an
// Amazon Bedrock ARN. Only the shape is checked: the service segment must be
// "bedrock", while partition, region, account ID and resource are not
// validated against any list of known values.
func validateModelArn(model string) bool {
	return bedrockModelARNPattern.MatchString(model)
}
|
||||||
|
|
||||||
// Configure configures the AmazonBedRockClient with the provided configuration.
|
// Configure configures the AmazonBedRockClient with the provided configuration.
|
||||||
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
||||||
// Initialize models if not already initialized
|
// Initialize models if not already initialized
|
||||||
|
@@ -41,7 +41,15 @@ func TestBedrockModelConfig(t *testing.T) {
|
|||||||
assert.Equal(t, foundModel.Config.MaxTokens, 100)
|
assert.Equal(t, foundModel.Config.MaxTokens, 100)
|
||||||
assert.Equal(t, foundModel.Config.Temperature, float32(0.5))
|
assert.Equal(t, foundModel.Config.Temperature, float32(0.5))
|
||||||
assert.Equal(t, foundModel.Config.TopP, float32(0.9))
|
assert.Equal(t, foundModel.Config.TopP, float32(0.9))
|
||||||
assert.Equal(t, foundModel.Config.ModelName, "anthropic.claude-3-5-sonnet-20240620-v1:0")
|
assert.Equal(t, foundModel.Config.ModelName, "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBedrockInvalidModel(t *testing.T) {
|
||||||
|
client := &AmazonBedRockClient{models: testModels}
|
||||||
|
|
||||||
|
foundModel, err := client.getModelFromString("arn:aws:s3:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||||
|
assert.Nil(t, err, "Error should be nil")
|
||||||
|
assert.Equal(t, foundModel.Config.MaxTokens, 100)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetModelFromString(t *testing.T) {
|
func TestGetModelFromString(t *testing.T) {
|
||||||
|
Reference in New Issue
Block a user