mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
fix(deepseek): Tool Choice to required for Azure Deployment in case specific function dict is given (#34848)
This commit is contained in:
committed by
GitHub
parent
0081deae96
commit
03826061be
@@ -194,6 +194,11 @@ class ChatDeepSeek(BaseChatOpenAI):
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
@property
|
||||
def _is_azure_endpoint(self) -> bool:
|
||||
"""Check if the configured endpoint is an Azure deployment."""
|
||||
return "azure.com" in (self.api_base or "").lower()
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
@@ -276,6 +281,17 @@ class ChatDeepSeek(BaseChatOpenAI):
|
||||
if isinstance(block, dict) and block.get("type") == "text"
|
||||
]
|
||||
message["content"] = "".join(text_parts) if text_parts else ""
|
||||
|
||||
# Azure-hosted DeepSeek does not support the dict/object form of
|
||||
# tool_choice (e.g. {"type": "function", "function": {"name": "..."}}).
|
||||
# It only accepts string values: "none", "auto", or "required".
|
||||
# Convert the unsupported dict form to "required", which is the closest
|
||||
# string equivalent — it forces the model to call a tool without
|
||||
# constraining which one. In the common with_structured_output() case
|
||||
# only a single tool is bound, so the behavior is effectively identical.
|
||||
if self._is_azure_endpoint and isinstance(payload.get("tool_choice"), dict):
|
||||
payload["tool_choice"] = "required"
|
||||
|
||||
return payload
|
||||
|
||||
def _create_chat_result(
|
||||
|
||||
@@ -309,6 +309,117 @@ class TestChatDeepSeekStrictMode:
|
||||
assert structured_model is not None
|
||||
|
||||
|
||||
class TestChatDeepSeekAzureToolChoice:
|
||||
"""Tests for Azure-hosted DeepSeek tool_choice compatibility.
|
||||
|
||||
Azure-hosted DeepSeek does not support the dict/object form of tool_choice
|
||||
(e.g. {"type": "function", "function": {"name": "..."}}) and returns a 422
|
||||
error. Only string values ("none", "auto", "required") are accepted.
|
||||
|
||||
The fix converts the unsupported dict form to "required" at the payload
|
||||
level in _get_request_payload, which is the last stop before the API call.
|
||||
String values are preserved as-is.
|
||||
"""
|
||||
|
||||
def _get_azure_model(
|
||||
self,
|
||||
endpoint: str = "https://my-resource.openai.azure.com/",
|
||||
) -> ChatDeepSeek:
|
||||
"""Create a ChatDeepSeek instance pointed at an Azure endpoint."""
|
||||
return ChatDeepSeek(
|
||||
model="deepseek-chat",
|
||||
api_key=SecretStr("test_key"),
|
||||
api_base=endpoint,
|
||||
)
|
||||
|
||||
def test_is_azure_endpoint_detection(self) -> None:
|
||||
"""Test that _is_azure_endpoint correctly identifies Azure URLs."""
|
||||
azure_endpoints = [
|
||||
"https://my-resource.openai.azure.com/",
|
||||
"https://my-resource.openai.azure.com/openai/deployments/deepseek",
|
||||
"https://RESOURCE.OPENAI.AZURE.COM/", # case insensitivity
|
||||
"https://test.services.ai.azure.com/",
|
||||
]
|
||||
for endpoint in azure_endpoints:
|
||||
llm = self._get_azure_model(endpoint)
|
||||
assert llm._is_azure_endpoint, f"Expected Azure for {endpoint}"
|
||||
|
||||
non_azure_endpoints = [
|
||||
DEFAULT_API_BASE,
|
||||
"https://api.openai.com/v1",
|
||||
"https://custom-endpoint.com/api",
|
||||
]
|
||||
for endpoint in non_azure_endpoints:
|
||||
llm = ChatDeepSeek(
|
||||
model="deepseek-chat",
|
||||
api_key=SecretStr("test_key"),
|
||||
api_base=endpoint,
|
||||
)
|
||||
assert not llm._is_azure_endpoint, f"Expected non-Azure for {endpoint}"
|
||||
|
||||
def test_payload_converts_dict_tool_choice_on_azure(self) -> None:
|
||||
"""Test that dict-form tool_choice is converted to 'required' in payload."""
|
||||
llm = self._get_azure_model()
|
||||
# Simulate with_structured_output flow: bind_tools converts a tool name
|
||||
# string into the dict form {"type": "function", "function": {"name": ...}}
|
||||
bound = llm.bind_tools([SampleTool], tool_choice="SampleTool")
|
||||
messages = [("user", "test")]
|
||||
bound_kwargs = bound.kwargs # type: ignore[attr-defined]
|
||||
|
||||
# At bind_tools level, the parent converts the tool name to dict form
|
||||
assert isinstance(bound_kwargs.get("tool_choice"), dict)
|
||||
|
||||
# But _get_request_payload should convert it to "required"
|
||||
request_payload = llm._get_request_payload(messages, **bound_kwargs)
|
||||
assert request_payload.get("tool_choice") == "required"
|
||||
|
||||
def test_payload_preserves_string_tool_choice_on_azure(self) -> None:
|
||||
"""Test that valid string tool_choice values are NOT overridden on Azure."""
|
||||
llm = self._get_azure_model()
|
||||
messages = [("user", "test")]
|
||||
|
||||
for choice in ("auto", "none", "required"):
|
||||
bound = llm.bind_tools([SampleTool], tool_choice=choice)
|
||||
request_payload = llm._get_request_payload(
|
||||
messages,
|
||||
**bound.kwargs, # type: ignore[attr-defined]
|
||||
)
|
||||
assert request_payload.get("tool_choice") == choice, (
|
||||
f"Expected '{choice}' to be preserved, got "
|
||||
f"{request_payload.get('tool_choice')!r}"
|
||||
)
|
||||
|
||||
def test_payload_preserves_dict_tool_choice_on_non_azure(self) -> None:
|
||||
"""Test that dict-form tool_choice is NOT converted on non-Azure endpoints."""
|
||||
llm = ChatDeepSeek(
|
||||
model="deepseek-chat",
|
||||
api_key=SecretStr("test_key"),
|
||||
)
|
||||
bound = llm.bind_tools([SampleTool], tool_choice="SampleTool")
|
||||
messages = [("user", "test")]
|
||||
request_payload = llm._get_request_payload(
|
||||
messages,
|
||||
**bound.kwargs, # type: ignore[attr-defined]
|
||||
)
|
||||
# On non-Azure, the dict form should be preserved
|
||||
assert isinstance(request_payload.get("tool_choice"), dict)
|
||||
|
||||
def test_with_structured_output_on_azure(self) -> None:
|
||||
"""Test that with_structured_output works on Azure (the original bug)."""
|
||||
llm = self._get_azure_model()
|
||||
|
||||
# with_structured_output internally calls bind_tools with the schema
|
||||
# name as tool_choice, which gets converted to the dict form.
|
||||
structured = llm.with_structured_output(SampleTool)
|
||||
assert structured is not None
|
||||
|
||||
def test_bind_tools_azure_with_strict_mode(self) -> None:
|
||||
"""Test Azure endpoint with strict mode enabled."""
|
||||
llm = self._get_azure_model()
|
||||
bound_model = llm.bind_tools([SampleTool], strict=True)
|
||||
assert bound_model is not None
|
||||
|
||||
|
||||
def test_profile() -> None:
|
||||
"""Test that model profile is loaded correctly."""
|
||||
model = ChatDeepSeek(model="deepseek-reasoner", api_key=SecretStr("test_key"))
|
||||
|
||||
Reference in New Issue
Block a user