From 9f6431924ff9d457e0b627f7ef03d0ce3f5d5d15 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Mon, 15 Sep 2025 14:09:20 -0400 Subject: [PATCH] feat(openai): add `max_tokens` to `AzureChatOpenAI` (#32959) Fixes #32949 This pattern is [present in `ChatOpenAI`](https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/langchain_openai/chat_models/base.py#L2821) but wasn't carried over to Azure. [CI](https://github.com/langchain-ai/langchain/actions/runs/17741751797/job/50417180998) --- .../langchain_openai/chat_models/azure.py | 12 +++++++ .../unit_tests/chat_models/test_azure.py | 34 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py index af108e10c97..2a2f73709cf 100644 --- a/libs/partners/openai/langchain_openai/chat_models/azure.py +++ b/libs/partners/openai/langchain_openai/chat_models/azure.py @@ -578,6 +578,9 @@ class AzureChatOpenAI(BaseChatOpenAI): ``'parallel_tools_calls'`` will be disabled. """ + max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens") # type: ignore[assignment] + """Maximum number of tokens to generate.""" + @classmethod def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" @@ -699,6 +702,15 @@ class AzureChatOpenAI(BaseChatOpenAI): "openai_api_version": self.openai_api_version, } + @property + def _default_params(self) -> dict[str, Any]: + """Get the default parameters for calling Azure OpenAI API.""" + params = super()._default_params + if "max_tokens" in params: + params["max_completion_tokens"] = params.pop("max_tokens") + + return params + def _get_ls_params( self, stop: Optional[list[str]] = None, **kwargs: Any ) -> LangSmithParams: diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py b/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py index 1640c120544..a42cc69008e 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py @@ -139,3 +139,37 @@ def test_chat_completions_api_uses_model_name() -> None: assert payload["model"] == "gpt-5" assert "messages" in payload # Chat Completions API uses 'messages' assert "input" not in payload + + +def test_max_completion_tokens_parameter() -> None: + """Test that max_completion_tokens can be used as a direct parameter.""" + llm = AzureChatOpenAI( + azure_deployment="gpt-5", + api_version="2024-12-01-preview", + azure_endpoint="my-base-url", + max_completion_tokens=1500, + ) + messages = [HumanMessage("Hello")] + payload = llm._get_request_payload(messages) + + # Should use max_completion_tokens instead of max_tokens + assert "max_completion_tokens" in payload + assert payload["max_completion_tokens"] == 1500 + assert "max_tokens" not in payload + + +def test_max_tokens_converted_to_max_completion_tokens() -> None: + """Test that max_tokens is converted to max_completion_tokens.""" + llm = AzureChatOpenAI( + azure_deployment="gpt-5", + api_version="2024-12-01-preview", + azure_endpoint="my-base-url", + max_tokens=1000, # type: ignore[call-arg] + ) + messages = [HumanMessage("Hello")] + payload = llm._get_request_payload(messages) + + # max_tokens should be converted to max_completion_tokens + assert "max_completion_tokens" in payload + assert payload["max_completion_tokens"] == 1000 + assert "max_tokens" not in payload