feat(openai): add max_tokens to AzureChatOpenAI (#32959)
Fixes #32949. This pattern is [present in `ChatOpenAI`](https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/langchain_openai/chat_models/base.py#L2821) but wasn't carried over to Azure. [CI run](https://github.com/langchain-ai/langchain/actions/runs/17741751797/job/50417180998)
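A minimal usage sketch of what this change enables (the deployment name and endpoint below are placeholders, not values from the PR): with the alias in place, either parameter spelling is accepted at construction time, and the payload sent to Azure carries only `max_completion_tokens`.

```python
from langchain_openai import AzureChatOpenAI

# Either parameter name works after this PR; the request payload always
# uses `max_completion_tokens`. Deployment/endpoint are placeholders.
llm = AzureChatOpenAI(
    azure_deployment="my-deployment",
    api_version="2024-12-01-preview",
    azure_endpoint="https://example.openai.azure.com",
    max_tokens=256,  # accepted via the new alias
)
```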
@@ -578,6 +578,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
         ``'parallel_tools_calls'`` will be disabled.
     """
 
+    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")  # type: ignore[assignment]
+    """Maximum number of tokens to generate."""
+
     @classmethod
     def get_lc_namespace(cls) -> list[str]:
         """Get the namespace of the langchain object."""
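The field reuses the Pydantic alias pattern from `ChatOpenAI`: the field is declared as `max_tokens` but aliased to `max_completion_tokens`, so callers may pass either name. A self-contained sketch of that mechanism (a standalone model for illustration, not the actual LangChain class, which enables an equivalent `populate_by_name` setting in its model config):

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class Sketch(BaseModel):
    # populate_by_name=True lets callers use the field name *or* the alias.
    model_config = ConfigDict(populate_by_name=True)

    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")


assert Sketch(max_tokens=100).max_tokens == 100
assert Sketch(max_completion_tokens=100).max_tokens == 100
```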
@@ -699,6 +702,15 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "openai_api_version": self.openai_api_version,
         }
 
+    @property
+    def _default_params(self) -> dict[str, Any]:
+        """Get the default parameters for calling Azure OpenAI API."""
+        params = super()._default_params
+        if "max_tokens" in params:
+            params["max_completion_tokens"] = params.pop("max_tokens")
+
+        return params
+
     def _get_ls_params(
         self, stop: Optional[list[str]] = None, **kwargs: Any
     ) -> LangSmithParams:
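The override exists because the superclass serializes the field under its Python name, so `max_tokens` must be renamed to `max_completion_tokens` just before the request payload is assembled. A minimal sketch of that pop-and-rename idiom, detached from LangChain (the helper name is mine, for illustration):

```python
from typing import Any


def rename_key(params: dict[str, Any], old: str, new: str) -> dict[str, Any]:
    """Move `old` to `new` if present; other keys are untouched."""
    if old in params:
        params[new] = params.pop(old)
    return params


params = {"max_tokens": 1000, "temperature": 0.2}
assert rename_key(params, "max_tokens", "max_completion_tokens") == {
    "temperature": 0.2,
    "max_completion_tokens": 1000,
}
```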
@@ -139,3 +139,37 @@ def test_chat_completions_api_uses_model_name() -> None:
     assert payload["model"] == "gpt-5"
     assert "messages" in payload  # Chat Completions API uses 'messages'
     assert "input" not in payload
+
+
+def test_max_completion_tokens_parameter() -> None:
+    """Test that max_completion_tokens can be used as a direct parameter."""
+    llm = AzureChatOpenAI(
+        azure_deployment="gpt-5",
+        api_version="2024-12-01-preview",
+        azure_endpoint="my-base-url",
+        max_completion_tokens=1500,
+    )
+    messages = [HumanMessage("Hello")]
+    payload = llm._get_request_payload(messages)
+
+    # Should use max_completion_tokens instead of max_tokens
+    assert "max_completion_tokens" in payload
+    assert payload["max_completion_tokens"] == 1500
+    assert "max_tokens" not in payload
+
+
+def test_max_tokens_converted_to_max_completion_tokens() -> None:
+    """Test that max_tokens is converted to max_completion_tokens."""
+    llm = AzureChatOpenAI(
+        azure_deployment="gpt-5",
+        api_version="2024-12-01-preview",
+        azure_endpoint="my-base-url",
+        max_tokens=1000,  # type: ignore[call-arg]
+    )
+    messages = [HumanMessage("Hello")]
+    payload = llm._get_request_payload(messages)
+
+    # max_tokens should be converted to max_completion_tokens
+    assert "max_completion_tokens" in payload
+    assert payload["max_completion_tokens"] == 1000
+    assert "max_tokens" not in payload
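Both tests build the payload locally via `_get_request_payload`, so no network call is made. Assuming the repo's usual test layout (the file path below is an assumption, not taken from the PR), they could be run from Python like so:

```python
import pytest

# Path is assumed from the repo's usual layout; -k selects the new tests.
pytest.main([
    "libs/partners/openai/tests/unit_tests/chat_models/test_azure.py",
    "-k", "max",
])
```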