feat: support openai organization Id (#1133)

* feat: add organization flag

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>

* feat: add orgId on openai backend

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>

---------

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
JuHyung Son
2024-06-14 16:39:56 +09:00
committed by GitHub
parent c834c09996
commit 4867d39c66
6 changed files with 72 additions and 44 deletions

View File

@@ -131,6 +131,7 @@ var addCmd = &cobra.Command{
TopP: topP,
TopK: topK,
MaxTokens: maxTokens,
OrganizationId: organizationId,
}
if providerIndex == -1 {
@@ -176,4 +177,6 @@ func init() {
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)")
//add flag for OCI Compartment ID
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
// add flag for openai organization
addCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "OpenAI or AzureOpenAI Organization ID (only for openai and azureopenai backend)")
}

View File

@@ -32,6 +32,7 @@ var (
topP float32
topK int32
maxTokens int
organizationId string
)
var configAI ai.AIConfiguration

View File

@@ -26,13 +26,20 @@ var updateCmd = &cobra.Command{
Use: "update",
Short: "Update a backend provider",
Long: "The command to update an AI backend provider",
Args: cobra.ExactArgs(1),
// Args: cobra.ExactArgs(1),
PreRun: func(cmd *cobra.Command, args []string) {
backend, _ := cmd.Flags().GetString("backend")
if strings.ToLower(backend) == "azureopenai" {
_ = cmd.MarkFlagRequired("engine")
_ = cmd.MarkFlagRequired("baseurl")
}
organizationId, _ := cmd.Flags().GetString("organizationId")
if strings.ToLower(backend) != "azureopenai" && strings.ToLower(backend) != "openai" {
if organizationId != "" {
color.Red("Error: organizationId must be empty for backends other than azureopenai or openai.")
os.Exit(1)
}
}
},
Run: func(cmd *cobra.Command, args []string) {
@@ -43,21 +50,16 @@ var updateCmd = &cobra.Command{
os.Exit(1)
}
inputBackends := strings.Split(args[0], ",")
backend, _ := cmd.Flags().GetString("backend")
if len(inputBackends) == 0 {
color.Red("Error: backend must be set.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}
for _, b := range inputBackends {
foundBackend := false
for i, provider := range configAI.Providers {
if b == provider.Name {
if backend == provider.Name {
foundBackend = true
if backend != "" {
configAI.Providers[i].Name = backend
@@ -78,8 +80,12 @@ var updateCmd = &cobra.Command{
if engine != "" {
configAI.Providers[i].Engine = engine
}
if organizationId != "" {
configAI.Providers[i].OrganizationId = organizationId
color.Blue("Organization Id updated successfully")
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", b)
color.Green("%s updated in the AI backend provider list", backend)
}
}
if !foundBackend {
@@ -87,8 +93,6 @@ var updateCmd = &cobra.Command{
os.Exit(1)
}
}
viper.Set("ai", configAI)
if err := viper.WriteConfig(); err != nil {
color.Red("Error writing config file: %s", err.Error())
@@ -110,4 +114,6 @@ func init() {
updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// update flag for azure open ai engine/deployment name
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
// update flag for organizationId
updateCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "Update OpenAI or Azure organization Id")
}

View File

@@ -17,6 +17,7 @@ type AzureAIClient struct {
client *openai.Client
model string
temperature float32
organizationId string
}
func (c *AzureAIClient) Configure(config IAIConfig) error {
@@ -25,6 +26,7 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
engine := config.GetEngine()
proxyEndpoint := config.GetProxyEndpoint()
defaultConfig := openai.DefaultAzureConfig(token, baseURL)
orgId := config.GetOrganizationId()
defaultConfig.AzureModelMapperFunc = func(model string) string {
// If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function
@@ -48,6 +50,10 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
Transport: transport,
}
}
if orgId != "" {
defaultConfig.OrgID = orgId
}
client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating Azure OpenAI client")

View File

@@ -78,6 +78,7 @@ type IAIConfig interface {
GetMaxTokens() int
GetProviderId() string
GetCompartmentId() string
GetOrganizationId() string
}
func NewClient(provider string) IAI {
@@ -111,6 +112,7 @@ type AIProvider struct {
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
}
func (p *AIProvider) GetBaseURL() string {
@@ -164,6 +166,10 @@ func (p *AIProvider) GetCompartmentId() string {
return p.CompartmentId
}
func (p *AIProvider) GetOrganizationId() string {
return p.OrganizationId
}
var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
func NeedPassword(backend string) bool {

View File

@@ -31,6 +31,7 @@ type OpenAIClient struct {
model string
temperature float32
topP float32
organizationId string
}
const (
@@ -43,6 +44,7 @@ const (
func (c *OpenAIClient) Configure(config IAIConfig) error {
token := config.GetPassword()
defaultConfig := openai.DefaultConfig(token)
orgId := config.GetOrganizationId()
proxyEndpoint := config.GetProxyEndpoint()
baseURL := config.GetBaseURL()
@@ -64,6 +66,10 @@ func (c *OpenAIClient) Configure(config IAIConfig) error {
}
}
if orgId != "" {
defaultConfig.OrgID = orgId
}
client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating OpenAI client")