From 4f231b46eeb99b680964d5799f2384edfa2913e2 Mon Sep 17 00:00:00 2001 From: st01cs <42166106+st01cs@users.noreply.github.com> Date: Thu, 13 Apr 2023 23:35:36 +0800 Subject: [PATCH] Add openai.api_base to support openapi proxy (#2823) I need access openai api through a proxy, so to add openai.api_base to support this method. Co-authored-by: bijia --- langchain/llms/openai.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index b6219f7eda6..a13beb7367a 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -151,6 +151,7 @@ class BaseOpenAI(BaseLLM): model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" openai_api_key: Optional[str] = None + openai_api_base: Optional[str] = None openai_organization: Optional[str] = None batch_size: int = 20 """Batch size to use when passing multiple documents to generate.""" @@ -205,6 +206,12 @@ class BaseOpenAI(BaseLLM): openai_api_key = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) + openai_api_base = get_from_dict_or_env( + values, + "openai_api_base", + "OPENAI_API_BASE", + default="", + ) openai_organization = get_from_dict_or_env( values, "openai_organization", @@ -215,6 +222,10 @@ class BaseOpenAI(BaseLLM): import openai openai.api_key = openai_api_key + if openai_api_base: + print("USING API_BASE: ") + print(openai_api_base) + openai.api_base = openai_api_base if openai_organization: print("USING ORGANIZATION: ") print(openai_organization) @@ -567,6 +578,7 @@ class OpenAIChat(BaseLLM): model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" openai_api_key: Optional[str] = None + openai_api_base: Optional[str] = None max_retries: int = 6 """Maximum number of retries to make when generating.""" prefix_messages: List = Field(default_factory=list) @@ -599,6 +611,12 @@ class OpenAIChat(BaseLLM): openai_api_key = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) + openai_api_base = get_from_dict_or_env( + values, + "openai_api_base", + "OPENAI_API_BASE", + default="", + ) openai_organization = get_from_dict_or_env( values, "openai_organization", "OPENAI_ORGANIZATION", default="" ) @@ -606,6 +624,8 @@ class OpenAIChat(BaseLLM): import openai openai.api_key = openai_api_key + if openai_api_base: + openai.api_base = openai_api_base if openai_organization: openai.organization = openai_organization except ImportError: