Compare commits

...

2 Commits

Author SHA1 Message Date
AlexsJones
a98de9a821 feat: support many-to-one auth provider mapping
This commit enhances the AI provider configuration system to support multiple
configurations per provider while maintaining backward compatibility. Key changes:

- Add GetConfigName() to IAIConfig interface to support named configurations
- Update AIProvider struct to handle multiple configurations via Configs array
- Implement configuration fallback logic in AIProvider methods
- Add support for default configuration selection
- Update mock configuration in tests to implement new interface methods

The changes allow providers to have multiple named configurations while
preserving existing functionality for single-configuration setups. This
enables more flexible provider configuration management and better
integration with various AI backends.

Breaking Changes: None
Backward Compatible: Yes

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>
2025-05-06 11:59:09 +01:00
AlexsJones
a1a405a380 feat: support many:1 auth:provider mapping 2025-05-06 11:51:13 +01:00
3 changed files with 120 additions and 60 deletions

View File

@@ -29,6 +29,7 @@ type AmazonBedRockClient struct {
topP float32
maxTokens int
models []bedrock_support.BedrockModel
configName string // Added to support multiple configurations
}
// AmazonCompletion BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
@@ -353,10 +354,10 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// Get the model input
modelInput := config.GetModel()
// Determine the appropriate region to use
var region string
// Check if the model input is actually an inference profile ARN
if validateInferenceProfileArn(modelInput) {
// Extract the region from the inference profile ARN
@@ -370,11 +371,11 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// Use the provided region or default
region = GetRegionOrDefault(config.GetProviderRegion())
}
// Only create AWS clients if they haven't been injected (for testing)
if a.client == nil || a.mgmtClient == nil {
// Create a new AWS config with the determined region
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
awsconfig.WithRegion(region),
)
if err != nil {
@@ -385,7 +386,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
a.client = bedrockruntime.NewFromConfig(cfg)
a.mgmtClient = bedrock.NewFromConfig(cfg)
}
// Handle model selection based on input type
if validateInferenceProfileArn(modelInput) {
// Get the inference profile details
@@ -399,7 +400,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
if err != nil {
return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
}
// Find the model configuration for the extracted model ID
foundModel, err := a.getModelFromString(modelID)
if err != nil {
@@ -407,7 +408,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
}
a.model = foundModel
// Use the inference profile ARN as the model ID for API calls
a.model.Config.ModelName = modelInput
}
@@ -420,11 +421,12 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
a.model = foundModel
a.model.Config.ModelName = foundModel.Config.ModelName
}
// Set common configuration parameters
a.temperature = config.GetTemperature()
a.topP = config.GetTopP()
a.maxTokens = config.GetMaxTokens()
a.configName = config.GetConfigName() // Store the config name
return nil
}
@@ -438,20 +440,20 @@ func (a *AmazonBedRockClient) getInferenceProfile(ctx context.Context, inference
if len(parts) != 2 {
return nil, fmt.Errorf("invalid inference profile ARN format: %s", inferenceProfileARN)
}
profileID := parts[1]
// Create the input for the GetInferenceProfile API call
input := &bedrock.GetInferenceProfileInput{
InferenceProfileIdentifier: aws.String(profileID),
}
// Call the GetInferenceProfile API
output, err := a.mgmtClient.GetInferenceProfile(ctx, input)
if err != nil {
return nil, fmt.Errorf("failed to get inference profile: %w", err)
}
return output, nil
}
@@ -460,25 +462,25 @@ func (a *AmazonBedRockClient) extractModelFromInferenceProfile(profile *bedrock.
if profile == nil || len(profile.Models) == 0 {
return "", fmt.Errorf("inference profile does not contain any models")
}
// Check if the first model has a non-nil ModelArn
if profile.Models[0].ModelArn == nil {
return "", fmt.Errorf("model information is missing in inference profile")
}
// Get the first model ARN from the profile
modelARN := aws.ToString(profile.Models[0].ModelArn)
if modelARN == "" {
return "", fmt.Errorf("model ARN is empty in inference profile")
}
// Extract the model ID from the ARN
// ARN format: arn:aws:bedrock:region::foundation-model/model-id
parts := strings.Split(modelARN, "/")
if len(parts) != 2 {
return "", fmt.Errorf("invalid model ARN format: %s", modelARN)
}
modelID := parts[1]
return modelID, nil
}
@@ -494,7 +496,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
if err != nil {
return "", err
}
// Build the parameters for the model invocation
params := &bedrockruntime.InvokeModelInput{
Body: body,
@@ -502,7 +504,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
ContentType: aws.String("application/json"),
Accept: aws.String("application/json"),
}
// Invoke the model
resp, err := a.client.InvokeModel(ctx, params)
if err != nil {

View File

@@ -71,22 +71,24 @@ type nopCloser struct{}
// Close is a no-op; nopCloser satisfies a close contract for callers that expect a closable value.
func (nopCloser) Close() {}
// IAIConfig represents the configuration for an AI provider.
// It exposes read-only accessors for connection, model, and sampling settings.
//
// NOTE(review): several method names appear twice below (GetPassword,
// GetTemperature, GetProviderRegion, GetTopP, GetMaxTokens, GetTopK).
// Duplicate method names are illegal in a Go interface, so this is almost
// certainly a diff rendering that interleaves the pre- and post-change
// declarations — confirm against the actual source file before relying on
// the ordering shown here.
type IAIConfig interface {
GetPassword() string
GetModel() string
GetProviderRegion() string
GetTemperature() float32
GetTopP() float32
GetMaxTokens() int
// GetConfigName supports named, many-to-one provider configurations
// (added in this change set).
GetConfigName() string
GetPassword() string
GetBaseURL() string
GetProxyEndpoint() string
GetEndpointName() string
GetEngine() string
GetTemperature() float32
GetProviderRegion() string
GetTopP() float32
GetTopK() int32
GetMaxTokens() int
GetProviderId() string
GetCompartmentId() string
GetOrganizationId() string
GetCustomHeaders() []http.Header
GetTopK() int32
}
func NewClient(provider string) IAI {
@@ -104,24 +106,95 @@ type AIConfiguration struct {
DefaultProvider string `mapstructure:"defaultprovider"`
}
// AIProvider represents a provider configuration. A provider may carry a
// single flat configuration (the top-level fields) or multiple named
// configurations via Configs, with DefaultConfig selecting which named
// configuration the getter methods prefer.
//
// NOTE(review): each field appears twice below (an untagged-json version and
// a json-tagged version). Duplicate field names are illegal in a Go struct;
// this is a diff view interleaving the old and new declarations — the
// json-tagged set plus Configs/DefaultConfig is the post-change shape.
// Confirm against the actual source file.
type AIProvider struct {
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"`
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
CustomHeaders []http.Header `mapstructure:"customHeaders"`
Name string `mapstructure:"name" json:"name"`
Model string `mapstructure:"model" json:"model,omitempty"`
Password string `mapstructure:"password" yaml:"password,omitempty" json:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty" json:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty" json:"proxyEndpoint,omitempty"`
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty" json:"proxyPort,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty" json:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty" json:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty" json:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty" json:"providerregion,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty" json:"providerid,omitempty"`
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty" json:"compartmentid,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty" json:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty" json:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty" json:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty" json:"organizationid,omitempty"`
CustomHeaders []http.Header `mapstructure:"customHeaders" json:"customHeaders,omitempty"`
// Configs holds optional named configurations for this provider
// (many-to-one provider mapping).
Configs []AIProviderConfig `mapstructure:"configs" json:"configs,omitempty"`
// DefaultConfig is the index into Configs used by the getter methods;
// an out-of-range index falls back to the top-level fields.
DefaultConfig int `mapstructure:"defaultConfig" json:"defaultConfig,omitempty"`
}
// AIProviderConfig represents a single configuration for a provider.
// Fields mirror the provider-level settings so that each named configuration
// can supply its own values; ConfigName identifies the configuration.
//
// NOTE(review): the mapstructure keys here use camelCase ("providerRegion",
// "topP", "maxTokens") while AIProvider uses all-lowercase keys.
// mapstructure matches keys case-insensitively by default, so decoding still
// works, but confirm which casing convention is intended for config files.
type AIProviderConfig struct {
Model string `mapstructure:"model" json:"model"`
ProviderRegion string `mapstructure:"providerRegion" json:"providerRegion"`
Temperature float32 `mapstructure:"temperature" json:"temperature"`
TopP float32 `mapstructure:"topP" json:"topP"`
MaxTokens int `mapstructure:"maxTokens" json:"maxTokens"`
ConfigName string `mapstructure:"configName" json:"configName"`
Password string `mapstructure:"password" yaml:"password,omitempty" json:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty" json:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty" json:"proxyEndpoint,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty" json:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty" json:"engine,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty" json:"providerid,omitempty"`
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty" json:"compartmentid,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty" json:"topk,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty" json:"organizationid,omitempty"`
CustomHeaders []http.Header `mapstructure:"customHeaders" json:"customHeaders,omitempty"`
}
// GetConfigName returns the name of the currently selected named
// configuration, or the empty string when DefaultConfig does not index a
// valid entry in Configs.
func (p *AIProvider) GetConfigName() string {
	idx := p.DefaultConfig
	if idx < 0 || idx >= len(p.Configs) {
		return ""
	}
	return p.Configs[idx].ConfigName
}
// GetModel returns the model name, preferring the selected named
// configuration and falling back to the provider-level Model field when no
// valid configuration is selected.
func (p *AIProvider) GetModel() string {
	if idx := p.DefaultConfig; idx >= 0 && idx < len(p.Configs) {
		return p.Configs[idx].Model
	}
	return p.Model
}
// GetProviderRegion returns the provider region, taken from the selected
// named configuration when DefaultConfig is a valid index, otherwise from
// the provider-level field.
func (p *AIProvider) GetProviderRegion() string {
	cfgs := p.Configs
	if p.DefaultConfig >= 0 && p.DefaultConfig < len(cfgs) {
		return cfgs[p.DefaultConfig].ProviderRegion
	}
	return p.ProviderRegion
}
// GetTemperature returns the sampling temperature from the selected named
// configuration when one is valid; otherwise the provider-level value.
func (p *AIProvider) GetTemperature() float32 {
	if 0 <= p.DefaultConfig && p.DefaultConfig < len(p.Configs) {
		return p.Configs[p.DefaultConfig].Temperature
	}
	return p.Temperature
}
// GetTopP returns the nucleus-sampling (top-p) value, preferring the
// selected named configuration over the provider-level field.
func (p *AIProvider) GetTopP() float32 {
	idx := p.DefaultConfig
	if idx < 0 || idx >= len(p.Configs) {
		return p.TopP
	}
	return p.Configs[idx].TopP
}
// GetMaxTokens returns the maximum number of tokens, preferring the
// selected named configuration over the provider-level field.
func (p *AIProvider) GetMaxTokens() int {
	idx := p.DefaultConfig
	if idx < 0 || idx >= len(p.Configs) {
		return p.MaxTokens
	}
	return p.Configs[idx].MaxTokens
}
func (p *AIProvider) GetBaseURL() string {
@@ -136,36 +209,17 @@ func (p *AIProvider) GetEndpointName() string {
return p.EndpointName
}
// GetTopP returns the provider-level top-p value.
// NOTE(review): a Configs-aware GetTopP also appears earlier in this view;
// this one looks like the pre-change version shown by the diff — confirm
// only one survives in the actual file.
func (p *AIProvider) GetTopP() float32 {
return p.TopP
}
// GetTopK returns the provider-level top-k sampling value.
func (p *AIProvider) GetTopK() int32 {
return p.TopK
}
// GetMaxTokens returns the provider-level maximum token count.
// NOTE(review): a Configs-aware GetMaxTokens also appears earlier in this
// view; this one looks like the pre-change version shown by the diff.
func (p *AIProvider) GetMaxTokens() int {
return p.MaxTokens
}
// GetPassword returns the provider credential/API key.
func (p *AIProvider) GetPassword() string {
return p.Password
}
// GetModel returns the provider-level model name.
// NOTE(review): a Configs-aware GetModel also appears earlier in this view;
// this one looks like the pre-change version shown by the diff.
func (p *AIProvider) GetModel() string {
return p.Model
}
// GetEngine returns the provider engine identifier.
func (p *AIProvider) GetEngine() string {
return p.Engine
}
// GetTemperature returns the provider-level sampling temperature.
// NOTE(review): a Configs-aware GetTemperature also appears earlier in this
// view; this one looks like the pre-change version shown by the diff.
func (p *AIProvider) GetTemperature() float32 {
return p.Temperature
}
// GetProviderRegion returns the provider-level region.
// NOTE(review): a Configs-aware GetProviderRegion also appears earlier in
// this view; this one looks like the pre-change version shown by the diff.
func (p *AIProvider) GetProviderRegion() string {
return p.ProviderRegion
}
func (p *AIProvider) GetProviderId() string {
return p.ProviderId

View File

@@ -76,6 +76,10 @@ func (m *mockConfig) GetProviderRegion() string {
return ""
}
// GetConfigName satisfies the IAIConfig interface for the mock; the mock
// has no named configuration, so the name is always empty.
func (m *mockConfig) GetConfigName() string {
	var name string
	return name
}
func TestOpenAIClient_CustomHeaders(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "Value1", r.Header.Get("X-Custom-Header-1"))