feat: support many:1 auth:provider mapping

AlexsJones committed 2025-05-06 11:51:13 +01:00
parent 6a81d2c140, commit a1a405a380
2 changed files with 111 additions and 65 deletions
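The change lets a single provider ("auth") entry carry several named model configurations, selected by index, instead of exactly one. A minimal sketch of the new data shape, using local stand-ins for the types added below (field names mirror the diff; the values are invented):

package main

import "fmt"

// Local stand-ins for the AIProviderConfig / AIProvider fields added in this commit.
type AIProviderConfig struct {
	ConfigName     string
	Model          string
	ProviderRegion string
	Temperature    float32
	TopP           float32
	MaxTokens      int
}

type AIProvider struct {
	Name          string
	Configs       []AIProviderConfig // many configurations...
	DefaultConfig int                // ...one of which is active
}

func main() {
	// One "amazonbedrock" auth entry mapped to two configurations.
	p := AIProvider{
		Name: "amazonbedrock",
		Configs: []AIProviderConfig{
			{ConfigName: "us-default", Model: "anthropic.claude-v2", ProviderRegion: "us-east-1", MaxTokens: 2048},
			{ConfigName: "eu-frankfurt", Model: "anthropic.claude-instant-v1", ProviderRegion: "eu-central-1", MaxTokens: 1024},
		},
		DefaultConfig: 0,
	}
	active := p.Configs[p.DefaultConfig]
	fmt.Printf("using config %q: model=%s region=%s\n", active.ConfigName, active.Model, active.ProviderRegion)
}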

View File

@@ -29,6 +29,7 @@ type AmazonBedRockClient struct {
topP float32
maxTokens int
models []bedrock_support.BedrockModel
configName string // Added to support multiple configurations
}
// AmazonCompletion BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
@@ -353,10 +354,10 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// Get the model input
modelInput := config.GetModel()
// Determine the appropriate region to use
var region string
// Check if the model input is actually an inference profile ARN
if validateInferenceProfileArn(modelInput) {
// Extract the region from the inference profile ARN
@@ -370,11 +371,11 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// Use the provided region or default
region = GetRegionOrDefault(config.GetProviderRegion())
}
// Only create AWS clients if they haven't been injected (for testing)
if a.client == nil || a.mgmtClient == nil {
// Create a new AWS config with the determined region
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
awsconfig.WithRegion(region),
)
if err != nil {
@@ -385,7 +386,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
a.client = bedrockruntime.NewFromConfig(cfg)
a.mgmtClient = bedrock.NewFromConfig(cfg)
}
// Handle model selection based on input type
if validateInferenceProfileArn(modelInput) {
// Get the inference profile details
@@ -399,7 +400,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
if err != nil {
return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
}
// Find the model configuration for the extracted model ID
foundModel, err := a.getModelFromString(modelID)
if err != nil {
@@ -407,7 +408,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
}
a.model = foundModel
// Use the inference profile ARN as the model ID for API calls
a.model.Config.ModelName = modelInput
}
@@ -420,11 +421,12 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
a.model = foundModel
a.model.Config.ModelName = foundModel.Config.ModelName
}
// Set common configuration parameters
a.temperature = config.GetTemperature()
a.topP = config.GetTopP()
a.maxTokens = config.GetMaxTokens()
a.configName = config.GetConfigName() // Store the config name
return nil
}
@@ -438,20 +440,20 @@ func (a *AmazonBedRockClient) getInferenceProfile(ctx context.Context, inference
if len(parts) != 2 {
return nil, fmt.Errorf("invalid inference profile ARN format: %s", inferenceProfileARN)
}
profileID := parts[1]
// Create the input for the GetInferenceProfile API call
input := &bedrock.GetInferenceProfileInput{
InferenceProfileIdentifier: aws.String(profileID),
}
// Call the GetInferenceProfile API
output, err := a.mgmtClient.GetInferenceProfile(ctx, input)
if err != nil {
return nil, fmt.Errorf("failed to get inference profile: %w", err)
}
return output, nil
}
@@ -460,25 +462,25 @@ func (a *AmazonBedRockClient) extractModelFromInferenceProfile(profile *bedrock.
if profile == nil || len(profile.Models) == 0 {
return "", fmt.Errorf("inference profile does not contain any models")
}
// Check if the first model has a non-nil ModelArn
if profile.Models[0].ModelArn == nil {
return "", fmt.Errorf("model information is missing in inference profile")
}
// Get the first model ARN from the profile
modelARN := aws.ToString(profile.Models[0].ModelArn)
if modelARN == "" {
return "", fmt.Errorf("model ARN is empty in inference profile")
}
// Extract the model ID from the ARN
// ARN format: arn:aws:bedrock:region::foundation-model/model-id
parts := strings.Split(modelARN, "/")
if len(parts) != 2 {
return "", fmt.Errorf("invalid model ARN format: %s", modelARN)
}
modelID := parts[1]
return modelID, nil
}
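As the comment above notes, a foundation-model ARN has the shape arn:aws:bedrock:region::foundation-model/model-id, so the model ID is everything after the single "/". A standalone illustration of that split (the ARN value is only an example):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Example ARN following the documented format.
	modelARN := "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-v2"
	parts := strings.Split(modelARN, "/")
	if len(parts) != 2 {
		fmt.Println("unexpected ARN format")
		return
	}
	fmt.Println(parts[1]) // anthropic.claude-v2
}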
@@ -494,7 +496,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
if err != nil {
return "", err
}
// Build the parameters for the model invocation
params := &bedrockruntime.InvokeModelInput{
Body: body,
@@ -502,7 +504,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
ContentType: aws.String("application/json"),
Accept: aws.String("application/json"),
}
// Invoke the model
resp, err := a.client.InvokeModel(ctx, params)
if err != nil {

View File

@@ -71,22 +71,14 @@ type nopCloser struct{}
func (nopCloser) Close() {}
// IAIConfig represents the configuration for an AI provider
type IAIConfig interface {
GetPassword() string
GetModel() string
GetBaseURL() string
GetProxyEndpoint() string
GetEndpointName() string
GetEngine() string
GetTemperature() float32
GetProviderRegion() string
GetTopP() float32
GetTopK() int32
GetMaxTokens() int
GetProviderId() string
GetCompartmentId() string
GetOrganizationId() string
GetCustomHeaders() []http.Header
GetConfigName() string // Added to support multiple configurations
}
func NewClient(provider string) IAI {
@@ -104,24 +96,95 @@ type AIConfiguration struct {
DefaultProvider string `mapstructure:"defaultprovider"`
}
// AIProvider represents a provider configuration
type AIProvider struct {
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"`
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
CustomHeaders []http.Header `mapstructure:"customHeaders"`
Name string `mapstructure:"name" json:"name"`
Model string `mapstructure:"model" json:"model,omitempty"`
Password string `mapstructure:"password" yaml:"password,omitempty" json:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty" json:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty" json:"proxyEndpoint,omitempty"`
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty" json:"proxyPort,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty" json:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty" json:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty" json:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty" json:"providerregion,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty" json:"providerid,omitempty"`
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty" json:"compartmentid,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty" json:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty" json:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty" json:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty" json:"organizationid,omitempty"`
CustomHeaders []http.Header `mapstructure:"customHeaders" json:"customHeaders,omitempty"`
Configs []AIProviderConfig `mapstructure:"configs" json:"configs,omitempty"`
DefaultConfig int `mapstructure:"defaultConfig" json:"defaultConfig,omitempty"`
}
// AIProviderConfig represents a single configuration for a provider
type AIProviderConfig struct {
Model string `mapstructure:"model" json:"model"`
ProviderRegion string `mapstructure:"providerRegion" json:"providerRegion"`
Temperature float32 `mapstructure:"temperature" json:"temperature"`
TopP float32 `mapstructure:"topP" json:"topP"`
MaxTokens int `mapstructure:"maxTokens" json:"maxTokens"`
ConfigName string `mapstructure:"configName" json:"configName"`
Password string `mapstructure:"password" yaml:"password,omitempty" json:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty" json:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty" json:"proxyEndpoint,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty" json:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty" json:"engine,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty" json:"providerid,omitempty"`
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty" json:"compartmentid,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty" json:"topk,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty" json:"organizationid,omitempty"`
CustomHeaders []http.Header `mapstructure:"customHeaders" json:"customHeaders,omitempty"`
}
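The mapstructure tags on the new Configs and AIProviderConfig fields suggest they are populated by the existing configuration machinery (mapstructure-style decoding, e.g. via viper). A minimal sketch using the github.com/mitchellh/mapstructure library directly, with local stand-in structs, just to show how a nested configs list maps onto the fields (the input values are invented):

package main

import (
	"fmt"

	"github.com/mitchellh/mapstructure"
)

// Stand-ins carrying only the tags relevant to the new nesting.
type AIProviderConfig struct {
	ConfigName string `mapstructure:"configName"`
	Model      string `mapstructure:"model"`
}

type AIProvider struct {
	Name          string             `mapstructure:"name"`
	Configs       []AIProviderConfig `mapstructure:"configs"`
	DefaultConfig int                `mapstructure:"defaultConfig"`
}

func main() {
	// Shape of the data as it might appear after reading the config file.
	raw := map[string]interface{}{
		"name":          "amazonbedrock",
		"defaultConfig": 1,
		"configs": []map[string]interface{}{
			{"configName": "primary", "model": "anthropic.claude-v2"},
			{"configName": "cheap", "model": "anthropic.claude-instant-v1"},
		},
	}

	var p AIProvider
	if err := mapstructure.Decode(raw, &p); err != nil {
		panic(err)
	}
	fmt.Println(p.Configs[p.DefaultConfig].ConfigName) // cheap
}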
// GetConfigName returns the configuration name
func (p *AIProvider) GetConfigName() string {
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
return p.Configs[p.DefaultConfig].ConfigName
}
return ""
}
// GetModel returns the model name
func (p *AIProvider) GetModel() string {
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
return p.Configs[p.DefaultConfig].Model
}
return p.Model
}
// GetProviderRegion returns the provider region
func (p *AIProvider) GetProviderRegion() string {
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
return p.Configs[p.DefaultConfig].ProviderRegion
}
return p.ProviderRegion
}
// GetTemperature returns the temperature
func (p *AIProvider) GetTemperature() float32 {
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
return p.Configs[p.DefaultConfig].Temperature
}
return p.Temperature
}
// GetTopP returns the top P value
func (p *AIProvider) GetTopP() float32 {
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
return p.Configs[p.DefaultConfig].TopP
}
return p.TopP
}
// GetMaxTokens returns the maximum number of tokens
func (p *AIProvider) GetMaxTokens() int {
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
return p.Configs[p.DefaultConfig].MaxTokens
}
return p.MaxTokens
}
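Each getter above follows the same pattern: if Configs holds entries and DefaultConfig is a valid index, the value comes from the selected configuration; otherwise the original flat field is used, so existing single-configuration setups keep working. A small stand-alone illustration of that fallback (local mirror of the logic, invented values):

package main

import "fmt"

type AIProviderConfig struct{ Model string }

type AIProvider struct {
	Model         string // legacy flat field
	Configs       []AIProviderConfig
	DefaultConfig int
}

// Mirrors the selection pattern used by GetModel and the other getters above.
func (p *AIProvider) GetModel() string {
	if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
		return p.Configs[p.DefaultConfig].Model
	}
	return p.Model
}

func main() {
	legacy := AIProvider{Model: "anthropic.claude-v2"}
	multi := AIProvider{
		Model:         "anthropic.claude-v2",
		Configs:       []AIProviderConfig{{Model: "anthropic.claude-instant-v1"}},
		DefaultConfig: 0,
	}
	fmt.Println(legacy.GetModel()) // anthropic.claude-v2 (flat field)
	fmt.Println(multi.GetModel())  // anthropic.claude-instant-v1 (selected config)
}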
func (p *AIProvider) GetBaseURL() string {
@@ -136,36 +199,17 @@ func (p *AIProvider) GetEndpointName() string {
return p.EndpointName
}
func (p *AIProvider) GetTopP() float32 {
return p.TopP
}
func (p *AIProvider) GetTopK() int32 {
return p.TopK
}
func (p *AIProvider) GetMaxTokens() int {
return p.MaxTokens
}
func (p *AIProvider) GetPassword() string {
return p.Password
}
func (p *AIProvider) GetModel() string {
return p.Model
}
func (p *AIProvider) GetEngine() string {
return p.Engine
}
func (p *AIProvider) GetTemperature() float32 {
return p.Temperature
}
func (p *AIProvider) GetProviderRegion() string {
return p.ProviderRegion
}
func (p *AIProvider) GetProviderId() string {
return p.ProviderId