feat: support amazonbedrock converse api (#1627)

* feat: add amazon bedrock converse api support

Signed-off-by: CradleKing24 <44717227+CradleKing24@users.noreply.github.com>

* docs(amazonbedrockconverse): add backend amazonbedrockconverse details

Signed-off-by: CradleKing24 <44717227+CradleKing24@users.noreply.github.com>

* fix(amazonbedrockconverse): error statements and comment cleanup

Signed-off-by: CradleKing24 <44717227+CradleKing24@users.noreply.github.com>

* test(amazonbedrockconverse): add unit tests

Signed-off-by: CradleKing24 <44717227+CradleKing24@users.noreply.github.com>

* fix(amazonbedrockconverse): linting, test coverage, converse output review

Signed-off-by: CradleKing24 <44717227+CradleKing24@users.noreply.github.com>

---------

Signed-off-by: CradleKing24 <44717227+CradleKing24@users.noreply.github.com>
This commit is contained in:
CradleKing24
2026-03-24 08:40:53 -05:00
committed by GitHub
parent 2276b12b0f
commit fc6a83d063
8 changed files with 451 additions and 3 deletions

View File

@@ -0,0 +1,161 @@
package ai
import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"os"
"strings"
)
const amazonBedrockConverseClientName = "amazonbedrockconverse"
type bedrockConverseAPI interface {
Converse(ctx context.Context, input *bedrockruntime.ConverseInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.ConverseOutput, error)
}
type AmazonBedrockConverseClient struct {
nopCloser
client bedrockConverseAPI
model string
temperature float32
topP float32
maxTokens int
stopSequences []string
}
func getRegion(region string) string {
if os.Getenv("AWS_DEFAULT_REGION") != "" {
region = os.Getenv("AWS_DEFAULT_REGION")
}
// Return the supplied provider region if not overridden by environment variable
return region
}
func (a *AmazonBedrockConverseClient) getModelFromString(model string) (string, error) {
if model == "" {
return "", errors.New("model name cannot be empty")
}
model = strings.TrimSpace(model)
return model, nil
}
func processError(err error, modelId string) error {
errMsg := err.Error()
if strings.Contains(errMsg, "no such host") {
return fmt.Errorf(`the bedrock service is not available in the selected region.
please double-check the service availability for your region at
https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/`)
} else if strings.Contains(errMsg, "Could not resolve the foundation model") {
return fmt.Errorf(`could not resolve the foundation model from model identifier: \"%s\".
please verify that the requested model exists and is accessible
within the specified region`, modelId)
} else {
return fmt.Errorf("could not invoke model: \"%s\". here is why: %s", modelId, err)
}
}
func (a *AmazonBedrockConverseClient) Configure(config IAIConfig) error {
modelInput := config.GetModel()
var region = getRegion(config.GetProviderRegion())
// Only create AWS clients if they haven't been injected (for testing)
if a.client == nil {
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
awsconfig.WithRegion(region),
)
if err != nil {
if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") {
return fmt.Errorf("aws credentials are invalid or missing. Please check your environment variables or aws config. details: %v", err)
}
return fmt.Errorf("failed to load aws config for region %s: %w", region, err)
}
a.client = bedrockruntime.NewFromConfig(cfg)
}
foundModel, err := a.getModelFromString(modelInput)
if err != nil {
return fmt.Errorf("failed to find model configuration for %s: %w", modelInput, err)
}
a.model = foundModel
// Set common configuration parameters
a.temperature = config.GetTemperature()
a.topP = config.GetTopP()
a.maxTokens = config.GetMaxTokens()
a.stopSequences = config.GetStopSequences()
return nil
}
func extractTextFromConverseOutput(output types.ConverseOutput, modelId string) (string, error) {
if output == nil {
return "", fmt.Errorf("empty response from model: %s", modelId)
}
msg, ok := output.(*types.ConverseOutputMemberMessage)
if !ok {
return "", fmt.Errorf("unexpected response type from model: %s", modelId)
}
if len(msg.Value.Content) == 0 {
return "", fmt.Errorf("no content returned from model: %s", modelId)
}
var builder strings.Builder
for _, block := range msg.Value.Content {
if textBlock, ok := block.(*types.ContentBlockMemberText); ok && textBlock != nil {
builder.WriteString(textBlock.Value)
}
}
if builder.Len() == 0 {
return "", fmt.Errorf("no text content returned from model: %s", modelId)
}
return builder.String(), nil
}
func (a *AmazonBedrockConverseClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
var content = types.ContentBlockMemberText{
Value: prompt,
}
var message = types.Message{
Content: []types.ContentBlock{&content},
Role: "user",
}
var converseInput = bedrockruntime.ConverseInput{
ModelId: aws.String(a.model),
Messages: []types.Message{message},
InferenceConfig: &types.InferenceConfiguration{
Temperature: aws.Float32(a.temperature),
TopP: aws.Float32(a.topP),
MaxTokens: aws.Int32(int32(a.maxTokens)),
StopSequences: a.stopSequences,
},
}
response, err := a.client.Converse(ctx, &converseInput)
if err != nil {
return "", processError(err, a.model)
}
text, err := extractTextFromConverseOutput(response.Output, a.model)
if err != nil {
return "", err
}
return text, nil
}
func (a *AmazonBedrockConverseClient) GetName() string {
return amazonBedrockConverseClientName
}

View File

@@ -0,0 +1,250 @@
package ai
import (
"context"
"errors"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/stretchr/testify/assert"
"testing"
)
// ---- Mock Wrapper ----
type mockConverseClient struct {
converseFunc func(ctx context.Context, input *bedrockruntime.ConverseInput) (*bedrockruntime.ConverseOutput, error)
}
func (m *mockConverseClient) Converse(ctx context.Context, input *bedrockruntime.ConverseInput, _ ...func(*bedrockruntime.Options)) (*bedrockruntime.ConverseOutput, error) {
return m.converseFunc(ctx, input)
}
// ---- Tests ----
func TestGetCompletion_Success(t *testing.T) {
mock := &mockConverseClient{
converseFunc: func(ctx context.Context, input *bedrockruntime.ConverseInput) (*bedrockruntime.ConverseOutput, error) {
return &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Content: []types.ContentBlock{
&types.ContentBlockMemberText{
Value: "mock response",
},
},
},
},
}, nil
},
}
client := &AmazonBedrockConverseClient{
client: mock,
model: "test-model",
}
result, err := client.GetCompletion(context.Background(), "hello")
assert.NoError(t, err)
assert.Equal(t, "mock response", result)
}
func TestGetCompletion_Error(t *testing.T) {
mock := &mockConverseClient{
converseFunc: func(ctx context.Context, input *bedrockruntime.ConverseInput) (*bedrockruntime.ConverseOutput, error) {
return nil, errors.New("some error")
},
}
client := &AmazonBedrockConverseClient{
client: mock,
model: "test-model",
}
_, err := client.GetCompletion(context.Background(), "hello")
assert.Error(t, err)
}
func TestConfigure_WithInjectedClient(t *testing.T) {
mock := &mockConverseClient{}
cfg := &AIProvider{
Model: "test-model",
ProviderRegion: "us-west-2",
Temperature: 0.5,
TopP: 0.9,
MaxTokens: 100,
StopSequences: []string{"stop"},
}
client := &AmazonBedrockConverseClient{
client: mock,
}
err := client.Configure(cfg)
assert.NoError(t, err)
assert.Equal(t, "test-model", client.model)
assert.Equal(t, float32(0.5), client.temperature)
assert.Equal(t, float32(0.9), client.topP)
assert.Equal(t, 100, client.maxTokens)
assert.Equal(t, []string{"stop"}, client.stopSequences)
}
func TestConfigure_InvalidModel(t *testing.T) {
mock := &mockConverseClient{}
cfg := &AIProvider{
Model: "",
}
client := &AmazonBedrockConverseClient{
client: mock,
}
err := client.Configure(cfg)
assert.Error(t, err)
assert.Contains(t, err.Error(), "model name cannot be empty")
}
func TestGetRegion(t *testing.T) {
t.Run("uses provided region when env not set", func(t *testing.T) {
t.Setenv("AWS_DEFAULT_REGION", "")
result := getRegion("us-west-2")
assert.Equal(t, "us-west-2", result)
})
t.Run("env overrides provided region", func(t *testing.T) {
t.Setenv("AWS_DEFAULT_REGION", "us-east-1")
result := getRegion("us-west-2")
assert.Equal(t, "us-east-1", result)
})
}
func TestProcessError(t *testing.T) {
tests := []struct {
name string
err error
modelId string
contains string
}{
{
name: "no such host",
err: errors.New("dial tcp: no such host"),
modelId: "test-model",
contains: "bedrock service is not available",
},
{
name: "model not found",
err: errors.New("Could not resolve the foundation model"),
modelId: "test-model",
contains: "could not resolve the foundation model",
},
{
name: "generic error",
err: errors.New("something else"),
modelId: "test-model",
contains: "could not invoke model",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := processError(tt.err, tt.modelId)
assert.Contains(t, result.Error(), tt.contains)
})
}
}
func TestExtractTextFromConverseOutput(t *testing.T) {
tests := []struct {
name string
output types.ConverseOutput
expectError bool
expected string
}{
{
name: "nil output",
output: nil,
expectError: true,
},
{
name: "empty content",
output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Content: []types.ContentBlock{},
},
},
expectError: true,
},
{
name: "single text block",
output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Content: []types.ContentBlock{
&types.ContentBlockMemberText{Value: "hello"},
},
},
},
expected: "hello",
},
{
name: "multiple text blocks",
output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Content: []types.ContentBlock{
&types.ContentBlockMemberText{Value: "hello "},
&types.ContentBlockMemberText{Value: "world"},
},
},
},
expected: "hello world",
},
{
name: "mixed content blocks",
output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Content: []types.ContentBlock{
&types.ContentBlockMemberText{Value: "hello"},
// simulate non-text block
&types.ContentBlockMemberImage{},
&types.ContentBlockMemberText{Value: " world"},
},
},
},
expected: "hello world",
},
{
name: "no text blocks",
output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Content: []types.ContentBlock{
&types.ContentBlockMemberImage{},
},
},
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extractTextFromConverseOutput(tt.output, "test-model")
if tt.expectError {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetName(t *testing.T) {
client := &AmazonBedrockConverseClient{}
assert.Equal(t, "amazonbedrockconverse", client.GetName())
}

View File

@@ -27,6 +27,7 @@ var (
&NoOpAIClient{},
&CohereClient{},
&AmazonBedRockClient{},
&AmazonBedrockConverseClient{},
&SageMakerAIClient{},
&GoogleGenAIClient{},
&HuggingfaceClient{},
@@ -43,6 +44,7 @@ var (
azureAIClientName,
cohereAIClientName,
amazonbedrockAIClientName,
amazonBedrockConverseClientName,
amazonsagemakerAIClientName,
googleAIClientName,
noopAIClientName,
@@ -85,6 +87,7 @@ type IAIConfig interface {
GetTopP() float32
GetTopK() int32
GetMaxTokens() int
GetStopSequences() []string
GetProviderId() string
GetCompartmentId() string
GetOrganizationId() string
@@ -122,6 +125,7 @@ type AIProvider struct {
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
StopSequences []string `mapstructure:"stopsequences" yaml:"stopsequences,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
CustomHeaders []http.Header `mapstructure:"customHeaders"`
}
@@ -150,6 +154,10 @@ func (p *AIProvider) GetMaxTokens() int {
return p.MaxTokens
}
func (p *AIProvider) GetStopSequences() []string {
return p.StopSequences
}
func (p *AIProvider) GetPassword() string {
return p.Password
}
@@ -185,7 +193,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header {
return p.CustomHeaders
}
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "customrest"}
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "amazonbedrockconverse", "googlevertexai", "oci", "customrest"}
func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {

View File

@@ -61,6 +61,10 @@ func (m *mockConfig) GetMaxTokens() int {
return 0
}
func (m *mockConfig) GetStopSequences() []string {
return []string{"", "", "", ""}
}
func (m *mockConfig) GetEndpointName() string {
return ""
}