feat: support openai organization Id (#1133)

* feat: add organization flag

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>

* feat: add orgId on openai backend

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>

---------

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
JuHyung Son
2024-06-14 16:39:56 +09:00
committed by GitHub
parent c834c09996
commit 4867d39c66
6 changed files with 72 additions and 44 deletions

View File

@@ -131,6 +131,7 @@ var addCmd = &cobra.Command{
TopP: topP, TopP: topP,
TopK: topK, TopK: topK,
MaxTokens: maxTokens, MaxTokens: maxTokens,
OrganizationId: organizationId,
} }
if providerIndex == -1 { if providerIndex == -1 {
@@ -176,4 +177,6 @@ func init() {
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)") addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)")
//add flag for OCI Compartment ID //add flag for OCI Compartment ID
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)") addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
// add flag for openai organization
addCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "OpenAI or AzureOpenAI Organization ID (only for openai and azureopenai backend)")
} }

View File

@@ -32,6 +32,7 @@ var (
topP float32 topP float32
topK int32 topK int32
maxTokens int maxTokens int
organizationId string
) )
var configAI ai.AIConfiguration var configAI ai.AIConfiguration

View File

@@ -26,13 +26,20 @@ var updateCmd = &cobra.Command{
Use: "update", Use: "update",
Short: "Update a backend provider", Short: "Update a backend provider",
Long: "The command to update an AI backend provider", Long: "The command to update an AI backend provider",
Args: cobra.ExactArgs(1), // Args: cobra.ExactArgs(1),
PreRun: func(cmd *cobra.Command, args []string) { PreRun: func(cmd *cobra.Command, args []string) {
backend, _ := cmd.Flags().GetString("backend") backend, _ := cmd.Flags().GetString("backend")
if strings.ToLower(backend) == "azureopenai" { if strings.ToLower(backend) == "azureopenai" {
_ = cmd.MarkFlagRequired("engine") _ = cmd.MarkFlagRequired("engine")
_ = cmd.MarkFlagRequired("baseurl") _ = cmd.MarkFlagRequired("baseurl")
} }
organizationId, _ := cmd.Flags().GetString("organizationId")
if strings.ToLower(backend) != "azureopenai" && strings.ToLower(backend) != "openai" {
if organizationId != "" {
color.Red("Error: organizationId must be empty for backends other than azureopenai or openai.")
os.Exit(1)
}
}
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@@ -43,50 +50,47 @@ var updateCmd = &cobra.Command{
os.Exit(1) os.Exit(1)
} }
inputBackends := strings.Split(args[0], ",") backend, _ := cmd.Flags().GetString("backend")
if len(inputBackends) == 0 {
color.Red("Error: backend must be set.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 { if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.") color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1) os.Exit(1)
} }
for _, b := range inputBackends { foundBackend := false
foundBackend := false for i, provider := range configAI.Providers {
for i, provider := range configAI.Providers { if backend == provider.Name {
if b == provider.Name { foundBackend = true
foundBackend = true if backend != "" {
if backend != "" { configAI.Providers[i].Name = backend
configAI.Providers[i].Name = backend color.Blue("Backend name updated successfully")
color.Blue("Backend name updated successfully")
}
if model != "" {
configAI.Providers[i].Model = model
color.Blue("Model updated successfully")
}
if password != "" {
configAI.Providers[i].Password = password
color.Blue("Password updated successfully")
}
if baseURL != "" {
configAI.Providers[i].BaseURL = baseURL
color.Blue("Base URL updated successfully")
}
if engine != "" {
configAI.Providers[i].Engine = engine
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", b)
} }
if model != "" {
configAI.Providers[i].Model = model
color.Blue("Model updated successfully")
}
if password != "" {
configAI.Providers[i].Password = password
color.Blue("Password updated successfully")
}
if baseURL != "" {
configAI.Providers[i].BaseURL = baseURL
color.Blue("Base URL updated successfully")
}
if engine != "" {
configAI.Providers[i].Engine = engine
}
if organizationId != "" {
configAI.Providers[i].OrganizationId = organizationId
color.Blue("Organization Id updated successfully")
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", backend)
} }
if !foundBackend { }
color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", args[0]) if !foundBackend {
os.Exit(1) color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", args[0])
} os.Exit(1)
} }
viper.Set("ai", configAI) viper.Set("ai", configAI)
@@ -110,4 +114,6 @@ func init() {
updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)") updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// update flag for azure open ai engine/deployment name // update flag for azure open ai engine/deployment name
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name") updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
// update flag for organizationId
updateCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "Update OpenAI or Azure organization Id")
} }

View File

@@ -14,9 +14,10 @@ const azureAIClientName = "azureopenai"
type AzureAIClient struct { type AzureAIClient struct {
nopCloser nopCloser
client *openai.Client client *openai.Client
model string model string
temperature float32 temperature float32
organizationId string
} }
func (c *AzureAIClient) Configure(config IAIConfig) error { func (c *AzureAIClient) Configure(config IAIConfig) error {
@@ -25,6 +26,7 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
engine := config.GetEngine() engine := config.GetEngine()
proxyEndpoint := config.GetProxyEndpoint() proxyEndpoint := config.GetProxyEndpoint()
defaultConfig := openai.DefaultAzureConfig(token, baseURL) defaultConfig := openai.DefaultAzureConfig(token, baseURL)
orgId := config.GetOrganizationId()
defaultConfig.AzureModelMapperFunc = func(model string) string { defaultConfig.AzureModelMapperFunc = func(model string) string {
// If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function // If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function
@@ -48,6 +50,10 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
Transport: transport, Transport: transport,
} }
} }
if orgId != "" {
defaultConfig.OrgID = orgId
}
client := openai.NewClientWithConfig(defaultConfig) client := openai.NewClientWithConfig(defaultConfig)
if client == nil { if client == nil {
return errors.New("error creating Azure OpenAI client") return errors.New("error creating Azure OpenAI client")

View File

@@ -78,6 +78,7 @@ type IAIConfig interface {
GetMaxTokens() int GetMaxTokens() int
GetProviderId() string GetProviderId() string
GetCompartmentId() string GetCompartmentId() string
GetOrganizationId() string
} }
func NewClient(provider string) IAI { func NewClient(provider string) IAI {
@@ -111,6 +112,7 @@ type AIProvider struct {
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"` TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"` TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"` MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
} }
func (p *AIProvider) GetBaseURL() string { func (p *AIProvider) GetBaseURL() string {
@@ -164,6 +166,10 @@ func (p *AIProvider) GetCompartmentId() string {
return p.CompartmentId return p.CompartmentId
} }
func (p *AIProvider) GetOrganizationId() string {
return p.OrganizationId
}
var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"} var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
func NeedPassword(backend string) bool { func NeedPassword(backend string) bool {

View File

@@ -27,10 +27,11 @@ const openAIClientName = "openai"
type OpenAIClient struct { type OpenAIClient struct {
nopCloser nopCloser
client *openai.Client client *openai.Client
model string model string
temperature float32 temperature float32
topP float32 topP float32
organizationId string
} }
const ( const (
@@ -43,6 +44,7 @@ const (
func (c *OpenAIClient) Configure(config IAIConfig) error { func (c *OpenAIClient) Configure(config IAIConfig) error {
token := config.GetPassword() token := config.GetPassword()
defaultConfig := openai.DefaultConfig(token) defaultConfig := openai.DefaultConfig(token)
orgId := config.GetOrganizationId()
proxyEndpoint := config.GetProxyEndpoint() proxyEndpoint := config.GetProxyEndpoint()
baseURL := config.GetBaseURL() baseURL := config.GetBaseURL()
@@ -64,6 +66,10 @@ func (c *OpenAIClient) Configure(config IAIConfig) error {
} }
} }
if orgId != "" {
defaultConfig.OrgID = orgId
}
client := openai.NewClientWithConfig(defaultConfig) client := openai.NewClientWithConfig(defaultConfig)
if client == nil { if client == nil {
return errors.New("error creating OpenAI client") return errors.New("error creating OpenAI client")