mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 19:49:09 +00:00
partners[azure]: fix having openai_api_base set for other packages (#22068)
This fix is for #21726. When having other packages installed that require the `openai_api_base` environment variable, users are not able to instantiate the AzureChatModels or AzureEmbeddings. This PR adds a new value `ignore_openai_api_base` which is a bool. When set to True, it sets `openai_api_base` to `None` Two new tests were added for the `test_azure` and a new file `test_azure_embeddings` A different approach may be better for this. If you can think of better logic, let me know and I can adjust it. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
b36e95caa9
commit
04bc5f1a95
@ -1,4 +1,5 @@
|
|||||||
"""Azure OpenAI chat wrapper."""
|
"""Azure OpenAI chat wrapper."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -548,8 +549,10 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
|||||||
values["openai_api_key"] = (
|
values["openai_api_key"] = (
|
||||||
convert_to_secret_str(openai_api_key) if openai_api_key else None
|
convert_to_secret_str(openai_api_key) if openai_api_key else None
|
||||||
)
|
)
|
||||||
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
|
values["openai_api_base"] = (
|
||||||
"OPENAI_API_BASE"
|
values["openai_api_base"]
|
||||||
|
if "openai_api_base" in values
|
||||||
|
else os.getenv("OPENAI_API_BASE")
|
||||||
)
|
)
|
||||||
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
|
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
|
||||||
"OPENAI_API_VERSION"
|
"OPENAI_API_VERSION"
|
||||||
@ -601,12 +604,16 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
|||||||
"api_version": values["openai_api_version"],
|
"api_version": values["openai_api_version"],
|
||||||
"azure_endpoint": values["azure_endpoint"],
|
"azure_endpoint": values["azure_endpoint"],
|
||||||
"azure_deployment": values["deployment_name"],
|
"azure_deployment": values["deployment_name"],
|
||||||
"api_key": values["openai_api_key"].get_secret_value()
|
"api_key": (
|
||||||
if values["openai_api_key"]
|
values["openai_api_key"].get_secret_value()
|
||||||
else None,
|
if values["openai_api_key"]
|
||||||
"azure_ad_token": values["azure_ad_token"].get_secret_value()
|
else None
|
||||||
if values["azure_ad_token"]
|
),
|
||||||
else None,
|
"azure_ad_token": (
|
||||||
|
values["azure_ad_token"].get_secret_value()
|
||||||
|
if values["azure_ad_token"]
|
||||||
|
else None
|
||||||
|
),
|
||||||
"azure_ad_token_provider": values["azure_ad_token_provider"],
|
"azure_ad_token_provider": values["azure_ad_token_provider"],
|
||||||
"organization": values["openai_organization"],
|
"organization": values["openai_organization"],
|
||||||
"base_url": values["openai_api_base"],
|
"base_url": values["openai_api_base"],
|
||||||
|
@ -75,8 +75,10 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
|||||||
values["openai_api_key"] = (
|
values["openai_api_key"] = (
|
||||||
convert_to_secret_str(openai_api_key) if openai_api_key else None
|
convert_to_secret_str(openai_api_key) if openai_api_key else None
|
||||||
)
|
)
|
||||||
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
|
values["openai_api_base"] = (
|
||||||
"OPENAI_API_BASE"
|
values["openai_api_base"]
|
||||||
|
if "openai_api_base" in values
|
||||||
|
else os.getenv("OPENAI_API_BASE")
|
||||||
)
|
)
|
||||||
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
|
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
|
||||||
"OPENAI_API_VERSION", default="2023-05-15"
|
"OPENAI_API_VERSION", default="2023-05-15"
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Test Azure OpenAI Chat API wrapper."""
|
"""Test Azure OpenAI Chat API wrapper."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
from langchain_openai import AzureChatOpenAI
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
@ -32,3 +34,25 @@ def test_initialize_more() -> None:
|
|||||||
ls_params = llm._get_ls_params()
|
ls_params = llm._get_ls_params()
|
||||||
assert ls_params["ls_provider"] == "azure"
|
assert ls_params["ls_provider"] == "azure"
|
||||||
assert ls_params["ls_model_name"] == "35-turbo-dev"
|
assert ls_params["ls_model_name"] == "35-turbo-dev"
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_azure_openai_with_openai_api_base_set() -> None:
|
||||||
|
os.environ["OPENAI_API_BASE"] = "https://api.openai.com"
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
api_key="xyz",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
azure_deployment="35-turbo-dev",
|
||||||
|
openai_api_version="2023-05-15",
|
||||||
|
temperature=0,
|
||||||
|
openai_api_base=None,
|
||||||
|
)
|
||||||
|
assert llm.openai_api_key is not None
|
||||||
|
assert llm.openai_api_key.get_secret_value() == "xyz"
|
||||||
|
assert llm.azure_endpoint == "my-base-url"
|
||||||
|
assert llm.deployment_name == "35-turbo-dev"
|
||||||
|
assert llm.openai_api_version == "2023-05-15"
|
||||||
|
assert llm.temperature == 0
|
||||||
|
|
||||||
|
ls_params = llm._get_ls_params()
|
||||||
|
assert ls_params["ls_provider"] == "azure"
|
||||||
|
assert ls_params["ls_model_name"] == "35-turbo-dev"
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from langchain_openai import AzureOpenAIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_azure_openai() -> None:
|
||||||
|
embeddings = AzureOpenAIEmbeddings(
|
||||||
|
model="text-embedding-large",
|
||||||
|
api_key="xyz",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
azure_deployment="35-turbo-dev",
|
||||||
|
openai_api_version="2023-05-15",
|
||||||
|
)
|
||||||
|
assert embeddings.model == "text-embedding-large"
|
||||||
|
|
||||||
|
|
||||||
|
def test_intialize_azure_openai_with_base_set() -> None:
|
||||||
|
os.environ["OPENAI_API_BASE"] = "https://api.openai.com"
|
||||||
|
embeddings = AzureOpenAIEmbeddings(
|
||||||
|
model="text-embedding-large",
|
||||||
|
api_key="xyz",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
azure_deployment="35-turbo-dev",
|
||||||
|
openai_api_version="2023-05-15",
|
||||||
|
openai_api_base=None,
|
||||||
|
)
|
||||||
|
assert embeddings.model == "text-embedding-large"
|
Loading…
Reference in New Issue
Block a user