fix: broken inference (#1575)

Signed-off-by: Alex <alexsimonjones@gmail.com>
Author: Alex Jones
Date: 2025-09-03 20:08:44 +01:00
Committed by: GitHub
Parent: 8bbffed643
Commit: 291e42dc4b

@@ -458,26 +458,25 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 		// Get the inference profile details
 		profile, err := a.getInferenceProfile(context.Background(), modelInput)
 		if err != nil {
 			// Instead of using a fallback model, throw an error
 			return fmt.Errorf("failed to get inference profile: %v", err)
-		} else {
-			// Extract the model ID from the inference profile
-			modelID, err := a.extractModelFromInferenceProfile(profile)
-			if err != nil {
-				return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
-			}
-			// Find the model configuration for the extracted model ID
-			foundModel, err := a.getModelFromString(modelID)
-			if err != nil {
-				// Instead of using a fallback model, throw an error
-				return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
-			}
-			a.model = foundModel
-			// Use the inference profile ARN as the model ID for API calls
-			a.model.Config.ModelName = modelInput
 		}
+		// Extract the model ID from the inference profile
+		modelID, err := a.extractModelFromInferenceProfile(profile)
+		if err != nil {
+			return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
+		}
+		// Find the model configuration for the extracted model ID
+		foundModel, err := a.getModelFromString(modelID)
+		if err != nil {
+			// Instead of failing, use a generic config for completion/response
+			// But still warn user
+			return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
+		}
+		// Use the found model config for completion/response, but set ModelName to the profile ARN
+		a.model = foundModel
+		a.model.Config.ModelName = modelInput
+		// Mark that we're using an inference profile
+		// (could add a field if needed)
 	} else {
 		// Regular model ID provided
 		foundModel, err := a.getModelFromString(modelInput)
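
Both hunks hinge on validateInferenceProfileArn, whose body is not part of this diff. As a hedged sketch only (the ARN pattern below is an assumption, not the repository's actual implementation), a validator for Bedrock inference-profile ARNs might look like:

package main

import (
	"fmt"
	"regexp"
)

// Matches ARNs such as:
//   arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0
// The exact pattern is an assumption; the real validateInferenceProfileArn may differ.
var inferenceProfileArnRe = regexp.MustCompile(
	`^arn:aws:bedrock:[a-z0-9-]+:\d{12}:(application-)?inference-profile/[A-Za-z0-9._:-]+$`)

// validateInferenceProfileArn reports whether modelInput looks like an
// inference-profile ARN rather than a plain model ID.
func validateInferenceProfileArn(modelInput string) bool {
	return inferenceProfileArnRe.MatchString(modelInput)
}

func main() {
	fmt.Println(validateInferenceProfileArn(
		"arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0")) // true
	fmt.Println(validateInferenceProfileArn("anthropic.claude-v2")) // false
}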
@@ -562,7 +561,8 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
 		supportedModels[i] = m.Name
 	}
-	if !bedrock_support.IsModelSupported(a.model.Config.ModelName, supportedModels) {
+	// Allow valid inference profile ARNs as supported models
+	if !bedrock_support.IsModelSupported(a.model.Config.ModelName, supportedModels) && !validateInferenceProfileArn(a.model.Config.ModelName) {
 		return "", fmt.Errorf("model '%s' is not supported.\nSupported models:\n%s", a.model.Config.ModelName, func() string {
 			s := ""
 			for _, m := range supportedModels {
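
By De Morgan's law, the rewritten guard in the second hunk accepts a model name when it is either in supportedModels or a valid inference-profile ARN. A minimal standalone sketch of that acceptance rule, reusing validateInferenceProfileArn from the sketch above, with isModelSupported as a hypothetical stand-in for bedrock_support.IsModelSupported (assumed here to be plain slice membership):

// isAllowedModelName mirrors the rewritten guard in GetCompletion:
// a name passes when it is a known model OR an inference-profile ARN.
func isAllowedModelName(name string, supportedModels []string) bool {
	return isModelSupported(name, supportedModels) || validateInferenceProfileArn(name)
}

// isModelSupported is a hypothetical stand-in for bedrock_support.IsModelSupported.
func isModelSupported(name string, supportedModels []string) bool {
	for _, m := range supportedModels {
		if m == name {
			return true
		}
	}
	return false
}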