mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2026-01-20 08:53:45 +00:00
feat: support many:1 auth:provider mapping
This commit is contained in:
@@ -29,6 +29,7 @@ type AmazonBedRockClient struct {
|
||||
topP float32
|
||||
maxTokens int
|
||||
models []bedrock_support.BedrockModel
|
||||
configName string // Added to support multiple configurations
|
||||
}
|
||||
|
||||
// AmazonCompletion BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
|
||||
@@ -353,10 +354,10 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
||||
|
||||
// Get the model input
|
||||
modelInput := config.GetModel()
|
||||
|
||||
|
||||
// Determine the appropriate region to use
|
||||
var region string
|
||||
|
||||
|
||||
// Check if the model input is actually an inference profile ARN
|
||||
if validateInferenceProfileArn(modelInput) {
|
||||
// Extract the region from the inference profile ARN
|
||||
@@ -370,11 +371,11 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
||||
// Use the provided region or default
|
||||
region = GetRegionOrDefault(config.GetProviderRegion())
|
||||
}
|
||||
|
||||
|
||||
// Only create AWS clients if they haven't been injected (for testing)
|
||||
if a.client == nil || a.mgmtClient == nil {
|
||||
// Create a new AWS config with the determined region
|
||||
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
|
||||
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
|
||||
awsconfig.WithRegion(region),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -385,7 +386,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
||||
a.client = bedrockruntime.NewFromConfig(cfg)
|
||||
a.mgmtClient = bedrock.NewFromConfig(cfg)
|
||||
}
|
||||
|
||||
|
||||
// Handle model selection based on input type
|
||||
if validateInferenceProfileArn(modelInput) {
|
||||
// Get the inference profile details
|
||||
@@ -399,7 +400,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Find the model configuration for the extracted model ID
|
||||
foundModel, err := a.getModelFromString(modelID)
|
||||
if err != nil {
|
||||
@@ -407,7 +408,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
||||
return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
|
||||
}
|
||||
a.model = foundModel
|
||||
|
||||
|
||||
// Use the inference profile ARN as the model ID for API calls
|
||||
a.model.Config.ModelName = modelInput
|
||||
}
|
||||
@@ -420,11 +421,12 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
|
||||
a.model = foundModel
|
||||
a.model.Config.ModelName = foundModel.Config.ModelName
|
||||
}
|
||||
|
||||
|
||||
// Set common configuration parameters
|
||||
a.temperature = config.GetTemperature()
|
||||
a.topP = config.GetTopP()
|
||||
a.maxTokens = config.GetMaxTokens()
|
||||
a.configName = config.GetConfigName() // Store the config name
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -438,20 +440,20 @@ func (a *AmazonBedRockClient) getInferenceProfile(ctx context.Context, inference
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid inference profile ARN format: %s", inferenceProfileARN)
|
||||
}
|
||||
|
||||
|
||||
profileID := parts[1]
|
||||
|
||||
|
||||
// Create the input for the GetInferenceProfile API call
|
||||
input := &bedrock.GetInferenceProfileInput{
|
||||
InferenceProfileIdentifier: aws.String(profileID),
|
||||
}
|
||||
|
||||
|
||||
// Call the GetInferenceProfile API
|
||||
output, err := a.mgmtClient.GetInferenceProfile(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get inference profile: %w", err)
|
||||
}
|
||||
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
@@ -460,25 +462,25 @@ func (a *AmazonBedRockClient) extractModelFromInferenceProfile(profile *bedrock.
|
||||
if profile == nil || len(profile.Models) == 0 {
|
||||
return "", fmt.Errorf("inference profile does not contain any models")
|
||||
}
|
||||
|
||||
|
||||
// Check if the first model has a non-nil ModelArn
|
||||
if profile.Models[0].ModelArn == nil {
|
||||
return "", fmt.Errorf("model information is missing in inference profile")
|
||||
}
|
||||
|
||||
|
||||
// Get the first model ARN from the profile
|
||||
modelARN := aws.ToString(profile.Models[0].ModelArn)
|
||||
if modelARN == "" {
|
||||
return "", fmt.Errorf("model ARN is empty in inference profile")
|
||||
}
|
||||
|
||||
|
||||
// Extract the model ID from the ARN
|
||||
// ARN format: arn:aws:bedrock:region::foundation-model/model-id
|
||||
parts := strings.Split(modelARN, "/")
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("invalid model ARN format: %s", modelARN)
|
||||
}
|
||||
|
||||
|
||||
modelID := parts[1]
|
||||
return modelID, nil
|
||||
}
|
||||
@@ -494,7 +496,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
|
||||
// Build the parameters for the model invocation
|
||||
params := &bedrockruntime.InvokeModelInput{
|
||||
Body: body,
|
||||
@@ -502,7 +504,7 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
|
||||
ContentType: aws.String("application/json"),
|
||||
Accept: aws.String("application/json"),
|
||||
}
|
||||
|
||||
|
||||
// Invoke the model
|
||||
resp, err := a.client.InvokeModel(ctx, params)
|
||||
if err != nil {
|
||||
|
||||
138
pkg/ai/iai.go
138
pkg/ai/iai.go
@@ -71,22 +71,14 @@ type nopCloser struct{}
|
||||
|
||||
func (nopCloser) Close() {}
|
||||
|
||||
// IAIConfig represents the configuration for an AI provider
|
||||
type IAIConfig interface {
|
||||
GetPassword() string
|
||||
GetModel() string
|
||||
GetBaseURL() string
|
||||
GetProxyEndpoint() string
|
||||
GetEndpointName() string
|
||||
GetEngine() string
|
||||
GetTemperature() float32
|
||||
GetProviderRegion() string
|
||||
GetTemperature() float32
|
||||
GetTopP() float32
|
||||
GetTopK() int32
|
||||
GetMaxTokens() int
|
||||
GetProviderId() string
|
||||
GetCompartmentId() string
|
||||
GetOrganizationId() string
|
||||
GetCustomHeaders() []http.Header
|
||||
GetConfigName() string // Added to support multiple configurations
|
||||
}
|
||||
|
||||
func NewClient(provider string) IAI {
|
||||
@@ -104,24 +96,95 @@ type AIConfiguration struct {
|
||||
DefaultProvider string `mapstructure:"defaultprovider"`
|
||||
}
|
||||
|
||||
// AIProvider represents a provider configuration
|
||||
type AIProvider struct {
|
||||
Name string `mapstructure:"name"`
|
||||
Model string `mapstructure:"model"`
|
||||
Password string `mapstructure:"password" yaml:"password,omitempty"`
|
||||
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
|
||||
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"`
|
||||
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"`
|
||||
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
|
||||
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
|
||||
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
|
||||
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
|
||||
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
|
||||
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"`
|
||||
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
|
||||
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
|
||||
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
|
||||
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
|
||||
CustomHeaders []http.Header `mapstructure:"customHeaders"`
|
||||
Name string `mapstructure:"name" json:"name"`
|
||||
Model string `mapstructure:"model" json:"model,omitempty"`
|
||||
Password string `mapstructure:"password" yaml:"password,omitempty" json:"password,omitempty"`
|
||||
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty" json:"baseurl,omitempty"`
|
||||
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty" json:"proxyEndpoint,omitempty"`
|
||||
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty" json:"proxyPort,omitempty"`
|
||||
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty" json:"endpointname,omitempty"`
|
||||
Engine string `mapstructure:"engine" yaml:"engine,omitempty" json:"engine,omitempty"`
|
||||
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty" json:"temperature,omitempty"`
|
||||
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty" json:"providerregion,omitempty"`
|
||||
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty" json:"providerid,omitempty"`
|
||||
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty" json:"compartmentid,omitempty"`
|
||||
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty" json:"topp,omitempty"`
|
||||
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty" json:"topk,omitempty"`
|
||||
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty" json:"maxtokens,omitempty"`
|
||||
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty" json:"organizationid,omitempty"`
|
||||
CustomHeaders []http.Header `mapstructure:"customHeaders" json:"customHeaders,omitempty"`
|
||||
Configs []AIProviderConfig `mapstructure:"configs" json:"configs,omitempty"`
|
||||
DefaultConfig int `mapstructure:"defaultConfig" json:"defaultConfig,omitempty"`
|
||||
}
|
||||
|
||||
// AIProviderConfig represents a single configuration for a provider
|
||||
type AIProviderConfig struct {
|
||||
Model string `mapstructure:"model" json:"model"`
|
||||
ProviderRegion string `mapstructure:"providerRegion" json:"providerRegion"`
|
||||
Temperature float32 `mapstructure:"temperature" json:"temperature"`
|
||||
TopP float32 `mapstructure:"topP" json:"topP"`
|
||||
MaxTokens int `mapstructure:"maxTokens" json:"maxTokens"`
|
||||
ConfigName string `mapstructure:"configName" json:"configName"`
|
||||
Password string `mapstructure:"password" yaml:"password,omitempty" json:"password,omitempty"`
|
||||
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty" json:"baseurl,omitempty"`
|
||||
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty" json:"proxyEndpoint,omitempty"`
|
||||
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty" json:"endpointname,omitempty"`
|
||||
Engine string `mapstructure:"engine" yaml:"engine,omitempty" json:"engine,omitempty"`
|
||||
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty" json:"providerid,omitempty"`
|
||||
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty" json:"compartmentid,omitempty"`
|
||||
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty" json:"topk,omitempty"`
|
||||
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty" json:"organizationid,omitempty"`
|
||||
CustomHeaders []http.Header `mapstructure:"customHeaders" json:"customHeaders,omitempty"`
|
||||
}
|
||||
|
||||
// GetConfigName returns the configuration name
|
||||
func (p *AIProvider) GetConfigName() string {
|
||||
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
|
||||
return p.Configs[p.DefaultConfig].ConfigName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetModel returns the model name
|
||||
func (p *AIProvider) GetModel() string {
|
||||
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
|
||||
return p.Configs[p.DefaultConfig].Model
|
||||
}
|
||||
return p.Model
|
||||
}
|
||||
|
||||
// GetProviderRegion returns the provider region
|
||||
func (p *AIProvider) GetProviderRegion() string {
|
||||
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
|
||||
return p.Configs[p.DefaultConfig].ProviderRegion
|
||||
}
|
||||
return p.ProviderRegion
|
||||
}
|
||||
|
||||
// GetTemperature returns the temperature
|
||||
func (p *AIProvider) GetTemperature() float32 {
|
||||
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
|
||||
return p.Configs[p.DefaultConfig].Temperature
|
||||
}
|
||||
return p.Temperature
|
||||
}
|
||||
|
||||
// GetTopP returns the top P value
|
||||
func (p *AIProvider) GetTopP() float32 {
|
||||
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
|
||||
return p.Configs[p.DefaultConfig].TopP
|
||||
}
|
||||
return p.TopP
|
||||
}
|
||||
|
||||
// GetMaxTokens returns the maximum number of tokens
|
||||
func (p *AIProvider) GetMaxTokens() int {
|
||||
if len(p.Configs) > 0 && p.DefaultConfig >= 0 && p.DefaultConfig < len(p.Configs) {
|
||||
return p.Configs[p.DefaultConfig].MaxTokens
|
||||
}
|
||||
return p.MaxTokens
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetBaseURL() string {
|
||||
@@ -136,36 +199,17 @@ func (p *AIProvider) GetEndpointName() string {
|
||||
return p.EndpointName
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetTopP() float32 {
|
||||
return p.TopP
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetTopK() int32 {
|
||||
return p.TopK
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetMaxTokens() int {
|
||||
return p.MaxTokens
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetPassword() string {
|
||||
return p.Password
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetModel() string {
|
||||
return p.Model
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetEngine() string {
|
||||
return p.Engine
|
||||
}
|
||||
func (p *AIProvider) GetTemperature() float32 {
|
||||
return p.Temperature
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetProviderRegion() string {
|
||||
return p.ProviderRegion
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetProviderId() string {
|
||||
return p.ProviderId
|
||||
|
||||
Reference in New Issue
Block a user