Mirror of https://github.com/k8sgpt-ai/k8sgpt.git (synced 2025-05-07 07:36:46 +00:00)
feat: using modelName will calling completion (#1469)
* using modelName will calling completion

Signed-off-by: Tony Chen <tony_chen@discovery.com>

* sign

Signed-off-by: Tony Chen <tony_chen@discovery.com>

---------

Signed-off-by: Tony Chen <tony_chen@discovery.com>
parent: 67f5855695
commit: f603948935
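In short, the commit stops overwriting the configured model name with the matched foundation model's Name, so an inference-profile ARN supplied in the provider config is what Bedrock actually receives as ModelId. Below is a minimal, standalone sketch (not taken from the repository) of the InvokeModel call shape that GetCompletion builds after this change, using the AWS SDK for Go v1 bedrockruntime client the file already imports; the region, account ID, ARN, and JSON request body are hypothetical placeholders.

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/bedrockruntime"
)

func main() {
	// Hypothetical placeholders: region, account ID, and ARN are illustrative only.
	region := "us-east-1"
	modelName := "arn:aws:bedrock:us-east-1:123456789012:inference-profile/anthropic.claude-3-5-sonnet-20240620-v1:0"

	sess := session.Must(session.NewSession(aws.NewConfig().WithRegion(region)))
	client := bedrockruntime.New(sess)

	// The configured ModelName (either a plain foundation-model ID or an
	// inference-profile ARN) is forwarded as ModelId, mirroring what
	// GetCompletion does after this commit.
	params := &bedrockruntime.InvokeModelInput{
		Body:        []byte(`{"anthropic_version":"bedrock-2023-05-31","max_tokens":100,"messages":[{"role":"user","content":"Hello"}]}`),
		ModelId:     aws.String(modelName),
		ContentType: aws.String("application/json"),
		Accept:      aws.String("application/json"),
	}

	out, err := client.InvokeModel(params)
	if err != nil {
		fmt.Println("invoke failed:", err)
		return
	}
	fmt.Println(string(out.Body))
}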
pkg/ai
@@ -337,7 +337,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 	// Create a new BedrockRuntime client
 	a.client = bedrockruntime.New(sess)
 	a.model = foundModel
-	a.model.Config.ModelName = foundModel.Name
+	a.model.Config.ModelName = foundModel.Config.ModelName
 	a.temperature = config.GetTemperature()
 	a.topP = config.GetTopP()
 	a.maxTokens = config.GetMaxTokens()
@@ -360,7 +360,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
 	// Build the parameters for the model invocation
 	params := &bedrockruntime.InvokeModelInput{
 		Body:        body,
-		ModelId:     aws.String(a.model.Name),
+		ModelId:     aws.String(a.model.Config.ModelName),
 		ContentType: aws.String("application/json"),
 		Accept:      aws.String("application/json"),
 	}
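Taken together, the two hunks above mean the value configured as the model name, including an inference-profile ARN, flows unchanged from Configure into InvokeModelInput.ModelId instead of being replaced by the matched foundation model's Name.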
@@ -52,6 +52,31 @@ func TestBedrockInvalidModel(t *testing.T) {
 	assert.Equal(t, foundModel.Config.MaxTokens, 100)
 }
 
+func TestBedrockGetCompletionInferenceProfile(t *testing.T) {
+	modelName := "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0"
+	var inferenceModelModels = []bedrock_support.BedrockModel{
+		{
+			Name:       "anthropic.claude-3-5-sonnet-20240620-v1:0",
+			Completion: &bedrock_support.CohereMessagesCompletion{},
+			Response:   &bedrock_support.CohereMessagesResponse{},
+			Config: bedrock_support.BedrockModelConfig{
+				MaxTokens:   100,
+				Temperature: 0.5,
+				TopP:        0.9,
+				ModelName:   modelName,
+			},
+		},
+	}
+	client := &AmazonBedRockClient{models: inferenceModelModels}
+
+	config := AIProvider{
+		Model: modelName,
+	}
+	err := client.Configure(&config)
+	assert.Nil(t, err, "Error should be nil")
+	assert.Equal(t, modelName, client.model.Config.ModelName, "Model name should match")
+}
+
 func TestGetModelFromString(t *testing.T) {
 	client := &AmazonBedRockClient{models: testModels}
 
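The new TestBedrockGetCompletionInferenceProfile test configures the client with an inference-profile ARN as the model and asserts that this ARN, not the matched model's base Name, is what ends up in client.model.Config.ModelName.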
@@ -173,6 +173,20 @@ func TestAmazonCompletion_GetCompletion_UnsupportedModel(t *testing.T) {
 	assert.Contains(t, err.Error(), "model unsupported-model is not supported")
 }
 
+func TestAmazonCompletion_GetCompletion_Inference_Profile(t *testing.T) {
+	completion := &AmazonCompletion{}
+	modelConfig := BedrockModelConfig{
+		MaxTokens:   200,
+		Temperature: 0.5,
+		TopP:        0.7,
+		ModelName:   "arn:aws:bedrock:us-east-1:*:inference-policy/anthropic.claude-3-5-sonnet-20240620-v1:0",
+	}
+	prompt := "Test prompt"
+
+	_, err := completion.GetCompletion(context.Background(), prompt, modelConfig)
+	assert.NoError(t, err)
+}
+
 func Test_isModelSupported(t *testing.T) {
 	assert.True(t, isModelSupported("anthropic.claude-v2"))
 	assert.False(t, isModelSupported("unsupported-model"))
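The second new test exercises GetCompletion on an AmazonCompletion with a BedrockModelConfig whose ModelName is an inference-profile ARN, asserting only that no error is returned when building the completion for that configuration.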