diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go
index 21c12666..f7685e1f 100644
--- a/pkg/ai/amazonbedrock.go
+++ b/pkg/ai/amazonbedrock.go
@@ -458,26 +458,25 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 		// Get the inference profile details
 		profile, err := a.getInferenceProfile(context.Background(), modelInput)
 		if err != nil {
-			// Instead of using a fallback model, throw an error
 			return fmt.Errorf("failed to get inference profile: %v", err)
-		} else {
-			// Extract the model ID from the inference profile
-			modelID, err := a.extractModelFromInferenceProfile(profile)
-			if err != nil {
-				return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
-			}
-
-			// Find the model configuration for the extracted model ID
-			foundModel, err := a.getModelFromString(modelID)
-			if err != nil {
-				// Instead of using a fallback model, throw an error
-				return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
-			}
-			a.model = foundModel
-
-			// Use the inference profile ARN as the model ID for API calls
-			a.model.Config.ModelName = modelInput
 		}
+		// Extract the model ID from the inference profile
+		modelID, err := a.extractModelFromInferenceProfile(profile)
+		if err != nil {
+			return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
+		}
+		// Find the model configuration for the extracted model ID
+		foundModel, err := a.getModelFromString(modelID)
+		if err != nil {
+			// No model configuration matches the extracted ID; return an
+			// error instead of falling back to a generic configuration
+			return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
+		}
+		// Use the found model config for completion/response, but set ModelName to the profile ARN
+		a.model = foundModel
+		a.model.Config.ModelName = modelInput
+		// From here on the client operates on an inference profile
+		// (a dedicated marker field could be added if needed)
 	} else {
 		// Regular model ID provided
 		foundModel, err := a.getModelFromString(modelInput)
@@ -562,7 +561,8 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
 		supportedModels[i] = m.Name
 	}
 
-	if !bedrock_support.IsModelSupported(a.model.Config.ModelName, supportedModels) {
+	// Allow valid inference profile ARNs as supported models
+	if !bedrock_support.IsModelSupported(a.model.Config.ModelName, supportedModels) && !validateInferenceProfileArn(a.model.Config.ModelName) {
 		return "", fmt.Errorf("model '%s' is not supported.\nSupported models:\n%s", a.model.Config.ModelName, func() string {
 			s := ""
 			for _, m := range supportedModels {
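
The `validateInferenceProfileArn` helper the second hunk relies on is not shown in these hunks. A minimal sketch of what such a check might look like, assuming it only verifies the ARN shape locally (the `inference-profile` and `application-inference-profile` resource types follow Bedrock's documented ARN formats; the regex itself is an assumption, not the PR's actual implementation):

```go
package ai

import "regexp"

// Hypothetical sketch: matches Bedrock inference profile ARNs of the form
//   arn:aws:bedrock:<region>:<account-id>:inference-profile/<profile-id>
// or, for application inference profiles,
//   arn:aws:bedrock:<region>:<account-id>:application-inference-profile/<profile-id>
// Profile IDs may contain dots, colons, and hyphens, e.g.
// "us.anthropic.claude-3-5-sonnet-20240620-v1:0".
var inferenceProfileArnRe = regexp.MustCompile(
	`^arn:aws:bedrock:[a-z0-9-]+:\d{12}:(application-)?inference-profile/[a-zA-Z0-9._:-]+$`,
)

func validateInferenceProfileArn(arn string) bool {
	return inferenceProfileArnRe.MatchString(arn)
}
```

A local pattern check along these lines keeps GetCompletion from rejecting a profile ARN that Configure already resolved, without making another Bedrock API call on the completion path.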