fix(deepseek): convert tool_choice to "required" for Azure deployments when a specific function dict is given (#34848)

Mohammad Mohtashim
2026-02-20 22:50:25 +05:00
committed by GitHub
parent 0081deae96
commit 03826061be
2 changed files with 127 additions and 0 deletions
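
For context, a rough sketch of the user-facing scenario this fix addresses, assuming a DeepSeek model deployed behind an Azure endpoint; the endpoint URL, key, and Movie schema below are illustrative placeholders, not taken from the change itself:

from pydantic import BaseModel, SecretStr

from langchain_deepseek import ChatDeepSeek


class Movie(BaseModel):
    """Illustrative output schema; any Pydantic model works here."""

    title: str
    year: int


llm = ChatDeepSeek(
    model="deepseek-chat",
    api_key=SecretStr("<azure-api-key>"),
    # Azure-hosted deployment; previously the dict-form tool_choice sent by
    # with_structured_output made Azure return a 422 error.
    api_base="https://my-resource.services.ai.azure.com/",
)

# with_structured_output binds the schema as a tool and sets tool_choice to the
# schema name, which the parent class turns into the dict form. With this fix the
# request payload carries tool_choice="required" instead, which Azure accepts.
structured_llm = llm.with_structured_output(Movie)
result = structured_llm.invoke("Name a 1990s sci-fi movie and its release year.")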


@@ -194,6 +194,11 @@ class ChatDeepSeek(BaseChatOpenAI):
    model_config = ConfigDict(populate_by_name=True)

    @property
    def _is_azure_endpoint(self) -> bool:
        """Check if the configured endpoint is an Azure deployment."""
        return "azure.com" in (self.api_base or "").lower()

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
@@ -276,6 +281,17 @@ class ChatDeepSeek(BaseChatOpenAI):
                    if isinstance(block, dict) and block.get("type") == "text"
                ]
                message["content"] = "".join(text_parts) if text_parts else ""

        # Azure-hosted DeepSeek does not support the dict/object form of
        # tool_choice (e.g. {"type": "function", "function": {"name": "..."}}).
        # It only accepts string values: "none", "auto", or "required".
        # Convert the unsupported dict form to "required", which is the closest
        # string equivalent: it forces the model to call a tool without
        # constraining which one. In the common with_structured_output() case
        # only a single tool is bound, so the behavior is effectively identical.
        if self._is_azure_endpoint and isinstance(payload.get("tool_choice"), dict):
            payload["tool_choice"] = "required"

        return payload

    def _create_chat_result(
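
Stated in isolation, the rule applied above amounts to the following sketch; _normalize_tool_choice and the get_weather tool name are made up for illustration and are not part of the change:

def _normalize_tool_choice(tool_choice: object, is_azure: bool) -> object:
    """Standalone restatement of the conversion rule, not the library code."""
    # Azure-hosted DeepSeek only accepts "none" / "auto" / "required", so the
    # dict form is collapsed to "required"; everything else passes through.
    if is_azure and isinstance(tool_choice, dict):
        return "required"
    return tool_choice


dict_form = {"type": "function", "function": {"name": "get_weather"}}

assert _normalize_tool_choice(dict_form, is_azure=True) == "required"
assert _normalize_tool_choice("auto", is_azure=True) == "auto"  # strings preserved
assert _normalize_tool_choice(dict_form, is_azure=False) is dict_form  # non-Azure untouched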


@@ -309,6 +309,117 @@ class TestChatDeepSeekStrictMode:
        assert structured_model is not None


class TestChatDeepSeekAzureToolChoice:
    """Tests for Azure-hosted DeepSeek tool_choice compatibility.

    Azure-hosted DeepSeek does not support the dict/object form of tool_choice
    (e.g. {"type": "function", "function": {"name": "..."}}) and returns a 422
    error. Only string values ("none", "auto", "required") are accepted.

    The fix converts the unsupported dict form to "required" at the payload
    level in _get_request_payload, which is the last stop before the API call.
    String values are preserved as-is.
    """

    def _get_azure_model(
        self,
        endpoint: str = "https://my-resource.openai.azure.com/",
    ) -> ChatDeepSeek:
        """Create a ChatDeepSeek instance pointed at an Azure endpoint."""
        return ChatDeepSeek(
            model="deepseek-chat",
            api_key=SecretStr("test_key"),
            api_base=endpoint,
        )

    def test_is_azure_endpoint_detection(self) -> None:
        """Test that _is_azure_endpoint correctly identifies Azure URLs."""
        azure_endpoints = [
            "https://my-resource.openai.azure.com/",
            "https://my-resource.openai.azure.com/openai/deployments/deepseek",
            "https://RESOURCE.OPENAI.AZURE.COM/",  # case insensitivity
            "https://test.services.ai.azure.com/",
        ]
        for endpoint in azure_endpoints:
            llm = self._get_azure_model(endpoint)
            assert llm._is_azure_endpoint, f"Expected Azure for {endpoint}"

        non_azure_endpoints = [
            DEFAULT_API_BASE,
            "https://api.openai.com/v1",
            "https://custom-endpoint.com/api",
        ]
        for endpoint in non_azure_endpoints:
            llm = ChatDeepSeek(
                model="deepseek-chat",
                api_key=SecretStr("test_key"),
                api_base=endpoint,
            )
            assert not llm._is_azure_endpoint, f"Expected non-Azure for {endpoint}"

    def test_payload_converts_dict_tool_choice_on_azure(self) -> None:
        """Test that dict-form tool_choice is converted to 'required' in payload."""
        llm = self._get_azure_model()
        # Simulate with_structured_output flow: bind_tools converts a tool name
        # string into the dict form {"type": "function", "function": {"name": ...}}
        bound = llm.bind_tools([SampleTool], tool_choice="SampleTool")
        messages = [("user", "test")]
        bound_kwargs = bound.kwargs  # type: ignore[attr-defined]
        # At bind_tools level, the parent converts the tool name to dict form
        assert isinstance(bound_kwargs.get("tool_choice"), dict)
        # But _get_request_payload should convert it to "required"
        request_payload = llm._get_request_payload(messages, **bound_kwargs)
        assert request_payload.get("tool_choice") == "required"

    def test_payload_preserves_string_tool_choice_on_azure(self) -> None:
        """Test that valid string tool_choice values are NOT overridden on Azure."""
        llm = self._get_azure_model()
        messages = [("user", "test")]
        for choice in ("auto", "none", "required"):
            bound = llm.bind_tools([SampleTool], tool_choice=choice)
            request_payload = llm._get_request_payload(
                messages,
                **bound.kwargs,  # type: ignore[attr-defined]
            )
            assert request_payload.get("tool_choice") == choice, (
                f"Expected '{choice}' to be preserved, got "
                f"{request_payload.get('tool_choice')!r}"
            )

    def test_payload_preserves_dict_tool_choice_on_non_azure(self) -> None:
        """Test that dict-form tool_choice is NOT converted on non-Azure endpoints."""
        llm = ChatDeepSeek(
            model="deepseek-chat",
            api_key=SecretStr("test_key"),
        )
        bound = llm.bind_tools([SampleTool], tool_choice="SampleTool")
        messages = [("user", "test")]
        request_payload = llm._get_request_payload(
            messages,
            **bound.kwargs,  # type: ignore[attr-defined]
        )
        # On non-Azure, the dict form should be preserved
        assert isinstance(request_payload.get("tool_choice"), dict)

    def test_with_structured_output_on_azure(self) -> None:
        """Test that with_structured_output works on Azure (the original bug)."""
        llm = self._get_azure_model()
        # with_structured_output internally calls bind_tools with the schema
        # name as tool_choice, which gets converted to the dict form.
        structured = llm.with_structured_output(SampleTool)
        assert structured is not None

    def test_bind_tools_azure_with_strict_mode(self) -> None:
        """Test Azure endpoint with strict mode enabled."""
        llm = self._get_azure_model()
        bound_model = llm.bind_tools([SampleTool], strict=True)
        assert bound_model is not None


def test_profile() -> None:
    """Test that model profile is loaded correctly."""
    model = ChatDeepSeek(model="deepseek-reasoner", api_key=SecretStr("test_key"))