feat: using modelName will calling completion (#1469)

* using modelName will calling completion

Signed-off-by: Tony Chen <tony_chen@discovery.com>

* sign

Signed-off-by: Tony Chen <tony_chen@discovery.com>

---------

Signed-off-by: Tony Chen <tony_chen@discovery.com>
This commit is contained in:
ju187 2025-04-24 01:15:17 -07:00 committed by GitHub
parent 67f5855695
commit f603948935
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 2 deletions

View File

@ -337,7 +337,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// Create a new BedrockRuntime client
a.client = bedrockruntime.New(sess)
a.model = foundModel
a.model.Config.ModelName = foundModel.Name
a.model.Config.ModelName = foundModel.Config.ModelName
a.temperature = config.GetTemperature()
a.topP = config.GetTopP()
a.maxTokens = config.GetMaxTokens()
@ -360,7 +360,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
// Build the parameters for the model invocation
params := &bedrockruntime.InvokeModelInput{
Body: body,
ModelId: aws.String(a.model.Name),
ModelId: aws.String(a.model.Config.ModelName),
ContentType: aws.String("application/json"),
Accept: aws.String("application/json"),
}

View File

@ -52,6 +52,31 @@ func TestBedrockInvalidModel(t *testing.T) {
assert.Equal(t, foundModel.Config.MaxTokens, 100)
}
func TestBedrockGetCompletionInferenceProfile(t *testing.T) {
modelName := "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0"
var inferenceModelModels = []bedrock_support.BedrockModel{
{
Name: "anthropic.claude-3-5-sonnet-20240620-v1:0",
Completion: &bedrock_support.CohereMessagesCompletion{},
Response: &bedrock_support.CohereMessagesResponse{},
Config: bedrock_support.BedrockModelConfig{
MaxTokens: 100,
Temperature: 0.5,
TopP: 0.9,
ModelName: modelName,
},
},
}
client := &AmazonBedRockClient{models: inferenceModelModels}
config := AIProvider{
Model: modelName,
}
err := client.Configure(&config)
assert.Nil(t, err, "Error should be nil")
assert.Equal(t, modelName, client.model.Config.ModelName, "Model name should match")
}
func TestGetModelFromString(t *testing.T) {
client := &AmazonBedRockClient{models: testModels}

View File

@ -173,6 +173,20 @@ func TestAmazonCompletion_GetCompletion_UnsupportedModel(t *testing.T) {
assert.Contains(t, err.Error(), "model unsupported-model is not supported")
}
func TestAmazonCompletion_GetCompletion_Inference_Profile(t *testing.T) {
completion := &AmazonCompletion{}
modelConfig := BedrockModelConfig{
MaxTokens: 200,
Temperature: 0.5,
TopP: 0.7,
ModelName: "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0",
}
prompt := "Test prompt"
_, err := completion.GetCompletion(context.Background(), prompt, modelConfig)
assert.NoError(t, err)
}
func Test_isModelSupported(t *testing.T) {
assert.True(t, isModelSupported("anthropic.claude-v2"))
assert.False(t, isModelSupported("unsupported-model"))