fix(openai): infer azure chat profiles from model name (#36858)

This commit is contained in:
Thomas
2026-04-19 11:06:26 -04:00
committed by GitHub
parent 02991cb4cf
commit 8fec4e7cee
2 changed files with 95 additions and 0 deletions

View File

@@ -705,6 +705,10 @@ class AzureChatOpenAI(BaseChatOpenAI):
return self return self
def _resolve_model_profile(self) -> ModelProfile | None: def _resolve_model_profile(self) -> ModelProfile | None:
if (self.model_name is not None) and (
profile := _get_default_model_profile(self.model_name) or None
):
return profile
if self.deployment_name is not None: if self.deployment_name is not None:
return _get_default_model_profile(self.deployment_name) or None return _get_default_model_profile(self.deployment_name) or None
return None return None

View File

@@ -16,6 +16,7 @@ def test_initialize_azure_openai() -> None:
azure_deployment="35-turbo-dev", azure_deployment="35-turbo-dev",
openai_api_version="2023-05-15", openai_api_version="2023-05-15",
azure_endpoint="my-base-url", azure_endpoint="my-base-url",
api_key=SecretStr("test"),
) )
assert llm.deployment_name == "35-turbo-dev" assert llm.deployment_name == "35-turbo-dev"
assert llm.openai_api_version == "2023-05-15" assert llm.openai_api_version == "2023-05-15"
@@ -45,6 +46,92 @@ def test_initialize_more() -> None:
assert ls_params.get("ls_model_name") == "gpt-35-turbo-0125" assert ls_params.get("ls_model_name") == "gpt-35-turbo-0125"
def test_profile_resolves_from_model_name() -> None:
llm = AzureChatOpenAI(
model="gpt-4o",
azure_endpoint="my-base-url",
api_key=SecretStr("test"),
api_version="2023-05-15",
)
assert llm.profile
assert llm.profile["name"] == "GPT-4o"
assert llm.profile["max_input_tokens"] == 128_000
def test_profile_resolves_from_model_name_with_custom_deployment_alias() -> None:
llm = AzureChatOpenAI(
model="gpt-4o",
azure_deployment="35-turbo-dev",
azure_endpoint="my-base-url",
api_key=SecretStr("test"),
api_version="2023-05-15",
)
assert llm.profile
assert llm.profile["name"] == "GPT-4o"
def test_profile_prefers_model_name_over_known_deployment_name() -> None:
llm = AzureChatOpenAI(
model="gpt-4o",
azure_deployment="gpt-4",
azure_endpoint="my-base-url",
api_key=SecretStr("test"),
api_version="2023-05-15",
)
assert llm.profile
assert llm.profile["name"] == "GPT-4o"
def test_profile_falls_back_to_deployment_name_with_unknown_model() -> None:
llm = AzureChatOpenAI(
model="unknown-model",
azure_deployment="gpt-4o",
azure_endpoint="my-base-url",
api_key=SecretStr("test"),
api_version="2023-05-15",
)
assert llm.profile
def test_profile_resolves_from_deployment_name_without_model() -> None:
llm = AzureChatOpenAI(
azure_deployment="gpt-4o",
azure_endpoint="my-base-url",
api_key=SecretStr("test"),
api_version="2023-05-15",
)
assert llm.profile
assert llm.profile["name"] == "GPT-4o"
def test_profile_respects_explicit_profile() -> None:
llm = AzureChatOpenAI(
model="gpt-4o",
azure_endpoint="my-base-url",
api_key=SecretStr("test"),
api_version="2023-05-15",
profile={"tool_calling": False},
)
assert llm.profile == {"tool_calling": False}
def test_profile_is_none_for_unknown_deployment_without_model() -> None:
llm = AzureChatOpenAI(
azure_deployment="unknown-deployment",
azure_endpoint="my-base-url",
api_key=SecretStr("test"),
api_version="2023-05-15",
)
assert llm.profile is None
def test_initialize_azure_openai_with_openai_api_base_set() -> None: def test_initialize_azure_openai_with_openai_api_base_set() -> None:
with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}): with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}):
llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg] llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg]
@@ -79,6 +166,7 @@ def test_structured_output_old_model() -> None:
azure_deployment="35-turbo-dev", azure_deployment="35-turbo-dev",
openai_api_version="2023-05-15", openai_api_version="2023-05-15",
azure_endpoint="my-base-url", azure_endpoint="my-base-url",
api_key=SecretStr("test"),
).with_structured_output(Output) ).with_structured_output(Output)
# assert tool calling was used instead of json_schema # assert tool calling was used instead of json_schema
@@ -91,6 +179,7 @@ def test_max_completion_tokens_in_payload() -> None:
azure_deployment="o1-mini", azure_deployment="o1-mini",
api_version="2024-12-01-preview", api_version="2024-12-01-preview",
azure_endpoint="my-base-url", azure_endpoint="my-base-url",
api_key=SecretStr("test"),
model_kwargs={"max_completion_tokens": 300}, model_kwargs={"max_completion_tokens": 300},
) )
messages = [HumanMessage("Hello")] messages = [HumanMessage("Hello")]
@@ -148,6 +237,7 @@ def test_max_completion_tokens_parameter() -> None:
azure_deployment="gpt-5", azure_deployment="gpt-5",
api_version="2024-12-01-preview", api_version="2024-12-01-preview",
azure_endpoint="my-base-url", azure_endpoint="my-base-url",
api_key=SecretStr("test"),
max_completion_tokens=1500, max_completion_tokens=1500,
) )
messages = [HumanMessage("Hello")] messages = [HumanMessage("Hello")]
@@ -165,6 +255,7 @@ def test_max_tokens_converted_to_max_completion_tokens() -> None:
azure_deployment="gpt-5", azure_deployment="gpt-5",
api_version="2024-12-01-preview", api_version="2024-12-01-preview",
azure_endpoint="my-base-url", azure_endpoint="my-base-url",
api_key=SecretStr("test"),
max_tokens=1000, # type: ignore[call-arg] max_tokens=1000, # type: ignore[call-arg]
) )
messages = [HumanMessage("Hello")] messages = [HumanMessage("Hello")]