mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(openai): infer azure chat profiles from model name (#36858)
This commit is contained in:
@@ -705,6 +705,10 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
return self
|
||||
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
    """Resolve a default model profile from the model or deployment name.

    The model name is consulted first. The deployment name is only used
    when the model name is unset or its lookup yields nothing. A falsy
    lookup result is normalized to `None`.

    Returns:
        The resolved profile, or `None` if neither name produces one.
    """
    if self.model_name is not None:
        from_model = _get_default_model_profile(self.model_name) or None
        if from_model:
            return from_model

    # Model name gave no usable profile; try the deployment name instead.
    if self.deployment_name is None:
        return None
    from_deployment = _get_default_model_profile(self.deployment_name)
    return from_deployment or None
|
||||
|
||||
@@ -16,6 +16,7 @@ def test_initialize_azure_openai() -> None:
|
||||
azure_deployment="35-turbo-dev",
|
||||
openai_api_version="2023-05-15",
|
||||
azure_endpoint="my-base-url",
|
||||
api_key=SecretStr("test"),
|
||||
)
|
||||
assert llm.deployment_name == "35-turbo-dev"
|
||||
assert llm.openai_api_version == "2023-05-15"
|
||||
@@ -45,6 +46,92 @@ def test_initialize_more() -> None:
|
||||
assert ls_params.get("ls_model_name") == "gpt-35-turbo-0125"
|
||||
|
||||
|
||||
def test_profile_resolves_from_model_name() -> None:
    """A known `model` with no deployment populates `profile`."""
    chat = AzureChatOpenAI(
        model="gpt-4o",
        azure_endpoint="my-base-url",
        api_key=SecretStr("test"),
        api_version="2023-05-15",
    )

    resolved = chat.profile
    assert resolved
    assert resolved["name"] == "GPT-4o"
    assert resolved["max_input_tokens"] == 128_000
|
||||
|
||||
|
||||
def test_profile_resolves_from_model_name_with_custom_deployment_alias() -> None:
    """The profile follows `model` even when the deployment is `35-turbo-dev`."""
    chat = AzureChatOpenAI(
        model="gpt-4o",
        azure_deployment="35-turbo-dev",
        azure_endpoint="my-base-url",
        api_key=SecretStr("test"),
        api_version="2023-05-15",
    )

    resolved = chat.profile
    assert resolved
    assert resolved["name"] == "GPT-4o"
|
||||
|
||||
|
||||
def test_profile_prefers_model_name_over_known_deployment_name() -> None:
    """With `model="gpt-4o"` and deployment `gpt-4`, the `model` profile wins."""
    chat = AzureChatOpenAI(
        model="gpt-4o",
        azure_deployment="gpt-4",
        azure_endpoint="my-base-url",
        api_key=SecretStr("test"),
        api_version="2023-05-15",
    )

    resolved = chat.profile
    assert resolved
    assert resolved["name"] == "GPT-4o"
|
||||
|
||||
|
||||
def test_profile_falls_back_to_deployment_name_with_unknown_model() -> None:
    """An unrecognized `model` falls back to the deployment name's profile."""
    llm = AzureChatOpenAI(
        model="unknown-model",
        azure_deployment="gpt-4o",
        azure_endpoint="my-base-url",
        api_key=SecretStr("test"),
        api_version="2023-05-15",
    )

    assert llm.profile
    # A truthiness check alone would pass for any non-empty profile; pin the
    # name so the test proves the profile came from the `gpt-4o` deployment.
    assert llm.profile["name"] == "GPT-4o"
|
||||
|
||||
|
||||
def test_profile_resolves_from_deployment_name_without_model() -> None:
    """With no `model` given, the deployment name `gpt-4o` supplies the profile."""
    chat = AzureChatOpenAI(
        azure_deployment="gpt-4o",
        azure_endpoint="my-base-url",
        api_key=SecretStr("test"),
        api_version="2023-05-15",
    )

    resolved = chat.profile
    assert resolved
    assert resolved["name"] == "GPT-4o"
|
||||
|
||||
|
||||
def test_profile_respects_explicit_profile() -> None:
    """A `profile` passed by the caller is kept as given, despite a known `model`."""
    chat = AzureChatOpenAI(
        model="gpt-4o",
        azure_endpoint="my-base-url",
        api_key=SecretStr("test"),
        api_version="2023-05-15",
        profile={"tool_calling": False},
    )

    kept = chat.profile
    assert kept == {"tool_calling": False}
|
||||
|
||||
|
||||
def test_profile_is_none_for_unknown_deployment_without_model() -> None:
    """No `model` plus the deployment `unknown-deployment` leaves `profile` as `None`."""
    chat = AzureChatOpenAI(
        azure_deployment="unknown-deployment",
        azure_endpoint="my-base-url",
        api_key=SecretStr("test"),
        api_version="2023-05-15",
    )

    resolved = chat.profile
    assert resolved is None
|
||||
|
||||
|
||||
def test_initialize_azure_openai_with_openai_api_base_set() -> None:
|
||||
with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}):
|
||||
llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg]
|
||||
@@ -79,6 +166,7 @@ def test_structured_output_old_model() -> None:
|
||||
azure_deployment="35-turbo-dev",
|
||||
openai_api_version="2023-05-15",
|
||||
azure_endpoint="my-base-url",
|
||||
api_key=SecretStr("test"),
|
||||
).with_structured_output(Output)
|
||||
|
||||
# assert tool calling was used instead of json_schema
|
||||
@@ -91,6 +179,7 @@ def test_max_completion_tokens_in_payload() -> None:
|
||||
azure_deployment="o1-mini",
|
||||
api_version="2024-12-01-preview",
|
||||
azure_endpoint="my-base-url",
|
||||
api_key=SecretStr("test"),
|
||||
model_kwargs={"max_completion_tokens": 300},
|
||||
)
|
||||
messages = [HumanMessage("Hello")]
|
||||
@@ -148,6 +237,7 @@ def test_max_completion_tokens_parameter() -> None:
|
||||
azure_deployment="gpt-5",
|
||||
api_version="2024-12-01-preview",
|
||||
azure_endpoint="my-base-url",
|
||||
api_key=SecretStr("test"),
|
||||
max_completion_tokens=1500,
|
||||
)
|
||||
messages = [HumanMessage("Hello")]
|
||||
@@ -165,6 +255,7 @@ def test_max_tokens_converted_to_max_completion_tokens() -> None:
|
||||
azure_deployment="gpt-5",
|
||||
api_version="2024-12-01-preview",
|
||||
azure_endpoint="my-base-url",
|
||||
api_key=SecretStr("test"),
|
||||
max_tokens=1000, # type: ignore[call-arg]
|
||||
)
|
||||
messages = [HumanMessage("Hello")]
|
||||
|
||||
Reference in New Issue
Block a user