From a1a405a3801a60b657e6e3dca3687fb0545c61ff Mon Sep 17 00:00:00 2001 From: AlexsJones Date: Tue, 6 May 2025 11:51:13 +0100 Subject: [PATCH] feat: support many:1 auth:provider mapping --- pkg/ai/amazonbedrock.go | 38 +++++------ pkg/ai/iai.go | 138 ++++++++++++++++++++++++++-------------- 2 files changed, 111 insertions(+), 65 deletions(-) diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go index 995de4b3..4823564a 100644 --- a/pkg/ai/amazonbedrock.go +++ b/pkg/ai/amazonbedrock.go @@ -29,6 +29,7 @@ type AmazonBedRockClient struct { topP float32 maxTokens int models []bedrock_support.BedrockModel + configName string // Added to support multiple configurations } // AmazonCompletion BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt) @@ -353,10 +354,10 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error { // Get the model input modelInput := config.GetModel() - + // Determine the appropriate region to use var region string - + // Check if the model input is actually an inference profile ARN if validateInferenceProfileArn(modelInput) { // Extract the region from the inference profile ARN @@ -370,11 +371,11 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error { // Use the provided region or default region = GetRegionOrDefault(config.GetProviderRegion()) } - + // Only create AWS clients if they haven't been injected (for testing) if a.client == nil || a.mgmtClient == nil { // Create a new AWS config with the determined region - cfg, err := awsconfig.LoadDefaultConfig(context.Background(), + cfg, err := awsconfig.LoadDefaultConfig(context.Background(), awsconfig.WithRegion(region), ) if err != nil { @@ -385,7 +386,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error { a.client = bedrockruntime.NewFromConfig(cfg) a.mgmtClient = bedrock.NewFromConfig(cfg) } - + // Handle model selection based on input type if validateInferenceProfileArn(modelInput) { // Get the inference profile details @@ -399,7 +400,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error { if err != nil { return fmt.Errorf("failed to extract model ID from inference profile: %v", err) } - + // Find the model configuration for the extracted model ID foundModel, err := a.getModelFromString(modelID) if err != nil { @@ -407,7 +408,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error { return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err) } a.model = foundModel - + // Use the inference profile ARN as the model ID for API calls a.model.Config.ModelName = modelInput } @@ -420,11 +421,12 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error { a.model = foundModel a.model.Config.ModelName = foundModel.Config.ModelName } - + // Set common configuration parameters a.temperature = config.GetTemperature() a.topP = config.GetTopP() a.maxTokens = config.GetMaxTokens() + a.configName = config.GetConfigName() // Store the config name return nil } @@ -438,20 +440,20 @@ func (a *AmazonBedRockClient) getInferenceProfile(ctx context.Context, inference if len(parts) != 2 { return nil, fmt.Errorf("invalid inference profile ARN format: %s", inferenceProfileARN) } - + profileID := parts[1] - + // Create the input for the GetInferenceProfile API call input := &bedrock.GetInferenceProfileInput{ InferenceProfileIdentifier: aws.String(profileID), } - + // Call the GetInferenceProfile API output, err := a.mgmtClient.GetInferenceProfile(ctx, input) if err != nil { return nil, fmt.Errorf("failed to get inference profile: %w", err) } - + return output, nil } @@ -460,25 +462,25 @@ func (a *AmazonBedRockClient) extractModelFromInferenceProfile(profile *bedrock. if profile == nil || len(profile.Models) == 0 { return "", fmt.Errorf("inference profile does not contain any models") } - + // Check if the first model has a non-nil ModelArn if profile.Models[0].ModelArn == nil { return "", fmt.Errorf("model information is missing in inference profile") } - + // Get the first model ARN from the profile modelARN := aws.ToString(profile.Models[0].ModelArn) if modelARN == "" { return "", fmt.Errorf("model ARN is empty in inference profile") } - + // Extract the model ID from the ARN // ARN format: arn:aws:bedrock:region::foundation-model/model-id parts := strings.Split(modelARN, "/") if len(parts) != 2 { return "", fmt.Errorf("invalid model ARN format: %s", modelARN) } - + modelID := parts[1] return modelID, nil } @@ -494,7 +496,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) if err != nil { return "", err } - + // Build the parameters for the model invocation params := &bedrockruntime.InvokeModelInput{ Body: body, @@ -502,7 +504,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) ContentType: aws.String("application/json"), Accept: aws.String("application/json"), } - + // Invoke the model resp, err := a.client.InvokeModel(ctx, params) if err != nil { diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index e1e7169c..15e1b65d 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -71,22 +71,14 @@ type nopCloser struct{} func (nopCloser) Close() {} +// IAIConfig represents the configuration for an AI provider type IAIConfig interface { - GetPassword() string GetModel() string - GetBaseURL() string - GetProxyEndpoint() string - GetEndpointName() string - GetEngine() string - GetTemperature() float32 GetProviderRegion() string + GetTemperature() float32 GetTopP() float32 - GetTopK() int32 GetMaxTokens() int - GetProviderId() string - GetCompartmentId() string - GetOrganizationId() string - GetCustomHeaders() []http.Header + GetConfigName() string // Added to support multiple configurations } func NewClient(provider string) IAI { @@ -104,24 +96,95 @@ type AIConfiguration struct { DefaultProvider string `mapstructure:"defaultprovider"` } +// AIProvider represents a provider configuration type AIProvider struct { - Name string `mapstructure:"name"` - Model string `mapstructure:"model"` - Password string `mapstructure:"password" yaml:"password,omitempty"` - BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"` - ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"` - ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"` - EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"` - Engine string `mapstructure:"engine" yaml:"engine,omitempty"` - Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"` - ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"` - ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"` - CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"` - TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"` - TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"` - MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"` - OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"` - CustomHeaders []http.Header `mapstructure:"customHeaders"` + Name string `mapstructure:"name" json:"name"` + Model string `mapstructure:"model" json:"model,omitempty"` + Password string `mapstructure:"password" yaml:"password,omitempty" json:"password,omitempty"` + BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty" json:"baseurl,omitempty"` + ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty" json:"proxyEndpoint,omitempty"` + ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty" json:"proxyPort,omitempty"` + EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty" json:"endpointname,omitempty"` + Engine string `mapstructure:"engine" yaml:"engine,omitempty" json:"engine,omitempty"` + Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty" json:"temperature,omitempty"` + ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty" json:"providerregion,omitempty"` + ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty" json:"providerid,omitempty"` + CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty" json:"compartmentid,omitempty"` + TopP float32 `mapstructure:"topp" yaml:"topp,omitempty" json:"topp,omitempty"` + TopK int32 `mapstructure:"topk" yaml:"topk,omitempty" json:"topk,omitempty"` + MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty" json:"maxtokens,omitempty"` + OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty" json:"organizationid,omitempty"` + CustomHeaders []http.Header `mapstructure:"customHeaders" json:"customHeaders,omitempty"` + Configs []AIProviderConfig `mapstructure:"configs" json:"configs,omitempty"` + DefaultConfig int `mapstructure:"defaultConfig" json:"defaultConfig,omitempty"` +} + +// AIProviderConfig represents a single configuration for a provider +type AIProviderConfig struct { + Model string `mapstructure:"model" json:"model"` + ProviderRegion string `mapstructure:"providerRegion" json:"providerRegion"` + Temperature float32 `mapstructure:"temperature" json:"temperature"` + TopP float32 `mapstructure:"topP" json:"topP"` + MaxTokens int `mapstructure:"maxTokens" json:"maxTokens"` + ConfigName string `mapstructure:"configName" json:"configName"` + Password string `mapstructure:"password" yaml:"password,omitempty" json:"password,omitempty"` + BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty" json:"baseurl,omitempty"` + ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty" json:"proxyEndpoint,omitempty"` + EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty" json:"endpointname,omitempty"` + Engine string `mapstructure:"engine" yaml:"engine,omitempty" json:"engine,omitempty"` + ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty" json:"providerid,omitempty"` + CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty" json:"compartmentid,omitempty"` + TopK int32 `mapstructure:"topk" yaml:"topk,omitempty" json:"topk,omitempty"` + OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty" json:"organizationid,omitempty"` + CustomHeaders []http.Header `mapstructure:"customHeaders" json:"customHeaders,omitempty"` +} + +// GetConfigName returns the configuration name +func (p *AIProvider) GetConfigName() string { + if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) { + return p.Configs[p.DefaultConfig].ConfigName + } + return "" +} + +// GetModel returns the model name +func (p *AIProvider) GetModel() string { + if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) { + return p.Configs[p.DefaultConfig].Model + } + return p.Model +} + +// GetProviderRegion returns the provider region +func (p *AIProvider) GetProviderRegion() string { + if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) { + return p.Configs[p.DefaultConfig].ProviderRegion + } + return p.ProviderRegion +} + +// GetTemperature returns the temperature +func (p *AIProvider) GetTemperature() float32 { + if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) { + return p.Configs[p.DefaultConfig].Temperature + } + return p.Temperature +} + +// GetTopP returns the top P value +func (p *AIProvider) GetTopP() float32 { + if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) { + return p.Configs[p.DefaultConfig].TopP + } + return p.TopP +} + +// GetMaxTokens returns the maximum number of tokens +func (p *AIProvider) GetMaxTokens() int { + if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) { + return p.Configs[p.DefaultConfig].MaxTokens + } + return p.MaxTokens } func (p *AIProvider) GetBaseURL() string { @@ -136,36 +199,17 @@ func (p *AIProvider) GetEndpointName() string { return p.EndpointName } -func (p *AIProvider) GetTopP() float32 { - return p.TopP -} - func (p *AIProvider) GetTopK() int32 { return p.TopK } -func (p *AIProvider) GetMaxTokens() int { - return p.MaxTokens -} - func (p *AIProvider) GetPassword() string { return p.Password } -func (p *AIProvider) GetModel() string { - return p.Model -} - func (p *AIProvider) GetEngine() string { return p.Engine } -func (p *AIProvider) GetTemperature() float32 { - return p.Temperature -} - -func (p *AIProvider) GetProviderRegion() string { - return p.ProviderRegion -} func (p *AIProvider) GetProviderId() string { return p.ProviderId