From f603948935f1c4cb171378634714577205de7b08 Mon Sep 17 00:00:00 2001
From: ju187
Date: Thu, 24 Apr 2025 01:15:17 -0700
Subject: [PATCH] feat: using modelName when calling completion (#1469)

* using modelName when calling completion

Signed-off-by: Tony Chen

* sign

Signed-off-by: Tony Chen

---------

Signed-off-by: Tony Chen
---
 pkg/ai/amazonbedrock.go                    |  4 ++--
 pkg/ai/amazonbedrock_test.go               | 25 +++++++++++++++++++++
 pkg/ai/bedrock_support/completions_test.go | 14 ++++++++++++
 3 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go
index 81a070a..ecce40e 100644
--- a/pkg/ai/amazonbedrock.go
+++ b/pkg/ai/amazonbedrock.go
@@ -337,7 +337,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 	// Create a new BedrockRuntime client
 	a.client = bedrockruntime.New(sess)
 	a.model = foundModel
-	a.model.Config.ModelName = foundModel.Name
+	a.model.Config.ModelName = foundModel.Config.ModelName
 	a.temperature = config.GetTemperature()
 	a.topP = config.GetTopP()
 	a.maxTokens = config.GetMaxTokens()
@@ -360,7 +360,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
 	// Build the parameters for the model invocation
 	params := &bedrockruntime.InvokeModelInput{
 		Body:        body,
-		ModelId:     aws.String(a.model.Name),
+		ModelId:     aws.String(a.model.Config.ModelName),
 		ContentType: aws.String("application/json"),
 		Accept:      aws.String("application/json"),
 	}
diff --git a/pkg/ai/amazonbedrock_test.go b/pkg/ai/amazonbedrock_test.go
index f933779..d29cbc6 100644
--- a/pkg/ai/amazonbedrock_test.go
+++ b/pkg/ai/amazonbedrock_test.go
@@ -52,6 +52,31 @@ func TestBedrockInvalidModel(t *testing.T) {
 	assert.Equal(t, foundModel.Config.MaxTokens, 100)
 }
 
+func TestBedrockGetCompletionInferenceProfile(t *testing.T) {
+	modelName := "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0"
+	var inferenceModelModels = []bedrock_support.BedrockModel{
+		{
+			Name:       "anthropic.claude-3-5-sonnet-20240620-v1:0",
+			Completion: &bedrock_support.CohereMessagesCompletion{},
+			Response:   &bedrock_support.CohereMessagesResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+				ModelName:   modelName,
+			},
+		},
+	}
+	client := &AmazonBedRockClient{models: inferenceModelModels}
+
+	config := AIProvider{
+		Model: modelName,
+	}
+	err := client.Configure(&config)
+	assert.Nil(t, err, "Error should be nil")
+	assert.Equal(t, modelName, client.model.Config.ModelName, "Model name should match")
+}
+
 func TestGetModelFromString(t *testing.T) {
 	client := &AmazonBedRockClient{models: testModels}
 
diff --git a/pkg/ai/bedrock_support/completions_test.go b/pkg/ai/bedrock_support/completions_test.go
index d2a56eb..1b052fb 100644
--- a/pkg/ai/bedrock_support/completions_test.go
+++ b/pkg/ai/bedrock_support/completions_test.go
@@ -173,6 +173,20 @@ func TestAmazonCompletion_GetCompletion_UnsupportedModel(t *testing.T) {
 	assert.Contains(t, err.Error(), "model unsupported-model is not supported")
 }
 
+func TestAmazonCompletion_GetCompletion_Inference_Profile(t *testing.T) {
+	completion := &AmazonCompletion{}
+	modelConfig := BedrockModelConfig{
+		MaxTokens:   200,
+		Temperature: 0.5,
+		TopP:        0.7,
+		ModelName:   "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0",
+	}
+	prompt := "Test prompt"
+
+	_, err := completion.GetCompletion(context.Background(), prompt, modelConfig)
+	assert.NoError(t, err)
+}
+
 func Test_isModelSupported(t *testing.T) {
 	assert.True(t, isModelSupported("anthropic.claude-v2"))
 	assert.False(t, isModelSupported("unsupported-model"))
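For readers who want the net effect of the patch outside the diff context, below is a minimal, self-contained Go sketch of the new behavior: the configured ModelName (which may be a full inference-profile ARN rather than a bare foundation-model ID) is passed straight through to InvokeModelInput.ModelId. The bedrockModelConfig struct here is a hypothetical, trimmed-down stand-in for k8sgpt's bedrock_support.BedrockModelConfig, not its real definition, and the request body is a placeholder.

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/bedrockruntime"
)

// bedrockModelConfig is a hypothetical stand-in for
// bedrock_support.BedrockModelConfig, kept local so the example compiles on its own.
type bedrockModelConfig struct {
	MaxTokens   int
	Temperature float32
	TopP        float32
	ModelName   string
}

func main() {
	cfg := bedrockModelConfig{
		MaxTokens:   100,
		Temperature: 0.5,
		TopP:        0.9,
		// An inference-profile ARN rather than a bare model ID.
		ModelName: "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0",
	}

	// Placeholder request payload; the real client serializes the prompt per model family.
	body := []byte(`{"prompt":"Test prompt"}`)

	// Before the patch, ModelId was taken from the model's Name field (the bare
	// model ID); after the patch it comes from Config.ModelName, so a configured
	// inference-profile ARN is forwarded to Bedrock unchanged.
	params := &bedrockruntime.InvokeModelInput{
		Body:        body,
		ModelId:     aws.String(cfg.ModelName),
		ContentType: aws.String("application/json"),
		Accept:      aws.String("application/json"),
	}

	fmt.Println(aws.StringValue(params.ModelId))
}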