mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-09-12 21:34:47 +00:00
feat: support openai organization Id (#1133)
* feat: add organization flag Signed-off-by: JuHyung-Son <sonju0427@gmail.com> * feat: add orgId on openai backend Signed-off-by: JuHyung-Son <sonju0427@gmail.com> --------- Signed-off-by: JuHyung-Son <sonju0427@gmail.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
@@ -131,6 +131,7 @@ var addCmd = &cobra.Command{
|
||||
TopP: topP,
|
||||
TopK: topK,
|
||||
MaxTokens: maxTokens,
|
||||
OrganizationId: organizationId,
|
||||
}
|
||||
|
||||
if providerIndex == -1 {
|
||||
@@ -176,4 +177,6 @@ func init() {
|
||||
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)")
|
||||
//add flag for OCI Compartment ID
|
||||
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
|
||||
// add flag for openai organization
|
||||
addCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "OpenAI or AzureOpenAI Organization ID (only for openai and azureopenai backend)")
|
||||
}
|
||||
|
@@ -32,6 +32,7 @@ var (
|
||||
topP float32
|
||||
topK int32
|
||||
maxTokens int
|
||||
organizationId string
|
||||
)
|
||||
|
||||
var configAI ai.AIConfiguration
|
||||
|
@@ -26,13 +26,20 @@ var updateCmd = &cobra.Command{
|
||||
Use: "update",
|
||||
Short: "Update a backend provider",
|
||||
Long: "The command to update an AI backend provider",
|
||||
Args: cobra.ExactArgs(1),
|
||||
// Args: cobra.ExactArgs(1),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
backend, _ := cmd.Flags().GetString("backend")
|
||||
if strings.ToLower(backend) == "azureopenai" {
|
||||
_ = cmd.MarkFlagRequired("engine")
|
||||
_ = cmd.MarkFlagRequired("baseurl")
|
||||
}
|
||||
organizationId, _ := cmd.Flags().GetString("organizationId")
|
||||
if strings.ToLower(backend) != "azureopenai" && strings.ToLower(backend) != "openai" {
|
||||
if organizationId != "" {
|
||||
color.Red("Error: organizationId must be empty for backends other than azureopenai or openai.")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
|
||||
@@ -43,21 +50,16 @@ var updateCmd = &cobra.Command{
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
inputBackends := strings.Split(args[0], ",")
|
||||
backend, _ := cmd.Flags().GetString("backend")
|
||||
|
||||
if len(inputBackends) == 0 {
|
||||
color.Red("Error: backend must be set.")
|
||||
os.Exit(1)
|
||||
}
|
||||
if temperature > 1.0 || temperature < 0.0 {
|
||||
color.Red("Error: temperature ranges from 0 to 1.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
for _, b := range inputBackends {
|
||||
foundBackend := false
|
||||
for i, provider := range configAI.Providers {
|
||||
if b == provider.Name {
|
||||
if backend == provider.Name {
|
||||
foundBackend = true
|
||||
if backend != "" {
|
||||
configAI.Providers[i].Name = backend
|
||||
@@ -78,8 +80,12 @@ var updateCmd = &cobra.Command{
|
||||
if engine != "" {
|
||||
configAI.Providers[i].Engine = engine
|
||||
}
|
||||
if organizationId != "" {
|
||||
configAI.Providers[i].OrganizationId = organizationId
|
||||
color.Blue("Organization Id updated successfully")
|
||||
}
|
||||
configAI.Providers[i].Temperature = temperature
|
||||
color.Green("%s updated in the AI backend provider list", b)
|
||||
color.Green("%s updated in the AI backend provider list", backend)
|
||||
}
|
||||
}
|
||||
if !foundBackend {
|
||||
@@ -87,8 +93,6 @@ var updateCmd = &cobra.Command{
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
viper.Set("ai", configAI)
|
||||
if err := viper.WriteConfig(); err != nil {
|
||||
color.Red("Error writing config file: %s", err.Error())
|
||||
@@ -110,4 +114,6 @@ func init() {
|
||||
updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
|
||||
// update flag for azure open ai engine/deployment name
|
||||
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
|
||||
// update flag for organizationId
|
||||
updateCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "Update OpenAI or Azure organization Id")
|
||||
}
|
||||
|
@@ -17,6 +17,7 @@ type AzureAIClient struct {
|
||||
client *openai.Client
|
||||
model string
|
||||
temperature float32
|
||||
organizationId string
|
||||
}
|
||||
|
||||
func (c *AzureAIClient) Configure(config IAIConfig) error {
|
||||
@@ -25,6 +26,7 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
|
||||
engine := config.GetEngine()
|
||||
proxyEndpoint := config.GetProxyEndpoint()
|
||||
defaultConfig := openai.DefaultAzureConfig(token, baseURL)
|
||||
orgId := config.GetOrganizationId()
|
||||
|
||||
defaultConfig.AzureModelMapperFunc = func(model string) string {
|
||||
// If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function
|
||||
@@ -48,6 +50,10 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
if orgId != "" {
|
||||
defaultConfig.OrgID = orgId
|
||||
}
|
||||
|
||||
client := openai.NewClientWithConfig(defaultConfig)
|
||||
if client == nil {
|
||||
return errors.New("error creating Azure OpenAI client")
|
||||
|
@@ -78,6 +78,7 @@ type IAIConfig interface {
|
||||
GetMaxTokens() int
|
||||
GetProviderId() string
|
||||
GetCompartmentId() string
|
||||
GetOrganizationId() string
|
||||
}
|
||||
|
||||
func NewClient(provider string) IAI {
|
||||
@@ -111,6 +112,7 @@ type AIProvider struct {
|
||||
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
|
||||
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
|
||||
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
|
||||
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetBaseURL() string {
|
||||
@@ -164,6 +166,10 @@ func (p *AIProvider) GetCompartmentId() string {
|
||||
return p.CompartmentId
|
||||
}
|
||||
|
||||
func (p *AIProvider) GetOrganizationId() string {
|
||||
return p.OrganizationId
|
||||
}
|
||||
|
||||
var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
|
||||
|
||||
func NeedPassword(backend string) bool {
|
||||
|
@@ -31,6 +31,7 @@ type OpenAIClient struct {
|
||||
model string
|
||||
temperature float32
|
||||
topP float32
|
||||
organizationId string
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -43,6 +44,7 @@ const (
|
||||
func (c *OpenAIClient) Configure(config IAIConfig) error {
|
||||
token := config.GetPassword()
|
||||
defaultConfig := openai.DefaultConfig(token)
|
||||
orgId := config.GetOrganizationId()
|
||||
proxyEndpoint := config.GetProxyEndpoint()
|
||||
|
||||
baseURL := config.GetBaseURL()
|
||||
@@ -64,6 +66,10 @@ func (c *OpenAIClient) Configure(config IAIConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if orgId != "" {
|
||||
defaultConfig.OrgID = orgId
|
||||
}
|
||||
|
||||
client := openai.NewClientWithConfig(defaultConfig)
|
||||
if client == nil {
|
||||
return errors.New("error creating OpenAI client")
|
||||
|
Reference in New Issue
Block a user