mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
fix(fireworks): bind_tools(strict: bool) and reasoning_content (#34343)
Extract strict from kwargs and pass it to convert_to_openai_tool when converting tools. This ensures that when strict is provided, it's properly used during tool conversion and removed from kwargs before calling the parent bind method. Also extract reasoning_content from API responses and store it in additional_kwargs for AIMessage objects. Fixes https://github.com/langchain-ai/langchain/issues/34341 and https://github.com/langchain-ai/langchain/issues/34342 --------- Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
@@ -114,8 +114,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
# Also Fireworks returns None for tool invocations
|
# Also Fireworks returns None for tool invocations
|
||||||
content = _dict.get("content", "") or ""
|
content = _dict.get("content", "") or ""
|
||||||
additional_kwargs: dict = {}
|
additional_kwargs: dict = {}
|
||||||
|
if reasoning_content := _dict.get("reasoning_content"):
|
||||||
|
additional_kwargs["reasoning_content"] = reasoning_content
|
||||||
|
|
||||||
if function_call := _dict.get("function_call"):
|
if function_call := _dict.get("function_call"):
|
||||||
additional_kwargs["function_call"] = dict(function_call)
|
additional_kwargs["function_call"] = dict(function_call)
|
||||||
|
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
invalid_tool_calls = []
|
invalid_tool_calls = []
|
||||||
if raw_tool_calls := _dict.get("tool_calls"):
|
if raw_tool_calls := _dict.get("tool_calls"):
|
||||||
@@ -678,7 +682,10 @@ class ChatFireworks(BaseChatModel):
|
|||||||
**kwargs: Any additional parameters to pass to
|
**kwargs: Any additional parameters to pass to
|
||||||
`langchain_fireworks.chat_models.ChatFireworks.bind`
|
`langchain_fireworks.chat_models.ChatFireworks.bind`
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
strict = kwargs.pop("strict", None)
|
||||||
|
formatted_tools = [
|
||||||
|
convert_to_openai_tool(tool, strict=strict) for tool in tools
|
||||||
|
]
|
||||||
if tool_choice is not None and tool_choice:
|
if tool_choice is not None and tool_choice:
|
||||||
if isinstance(tool_choice, str) and (
|
if isinstance(tool_choice, str) and (
|
||||||
tool_choice not in ("auto", "any", "none")
|
tool_choice not in ("auto", "any", "none")
|
||||||
|
|||||||
@@ -18,15 +18,30 @@ from langchain_fireworks import ChatFireworks
|
|||||||
_MODEL = "accounts/fireworks/models/gpt-oss-120b"
|
_MODEL = "accounts/fireworks/models/gpt-oss-120b"
|
||||||
|
|
||||||
|
|
||||||
def test_tool_choice_bool() -> None:
|
@pytest.mark.parametrize("strict", [None, True, False])
|
||||||
"""Test that tool choice is respected just passing in True."""
|
def test_tool_choice_bool(strict: bool | None) -> None: # noqa: FBT001
|
||||||
|
"""Test that tool choice is respected with different strict values."""
|
||||||
llm = ChatFireworks(model="fireworks/kimi-k2-instruct-0905")
|
llm = ChatFireworks(model="fireworks/kimi-k2-instruct-0905")
|
||||||
|
|
||||||
class MyTool(BaseModel):
|
class MyTool(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
age: int
|
age: int
|
||||||
|
|
||||||
with_tool = llm.bind_tools([MyTool], tool_choice=True)
|
kwargs = {"tool_choice": True}
|
||||||
|
if strict is not None:
|
||||||
|
kwargs["strict"] = strict
|
||||||
|
with_tool = llm.bind_tools([MyTool], **kwargs)
|
||||||
|
|
||||||
|
# Verify that strict is correctly set in the tool definition
|
||||||
|
assert hasattr(with_tool, "kwargs")
|
||||||
|
tools = with_tool.kwargs.get("tools", [])
|
||||||
|
assert len(tools) == 1
|
||||||
|
tool_def = tools[0]
|
||||||
|
assert "function" in tool_def
|
||||||
|
if strict is None:
|
||||||
|
assert "strict" not in tool_def["function"]
|
||||||
|
else:
|
||||||
|
assert tool_def["function"].get("strict") is strict
|
||||||
|
|
||||||
resp = with_tool.invoke("Who was the 27 year old named Erick?")
|
resp = with_tool.invoke("Who was the 27 year old named Erick?")
|
||||||
assert isinstance(resp, AIMessage)
|
assert isinstance(resp, AIMessage)
|
||||||
|
|||||||
38
libs/partners/fireworks/tests/unit_tests/test_chat_models.py
Normal file
38
libs/partners/fireworks/tests/unit_tests/test_chat_models.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Unit tests for ChatFireworks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from langchain_fireworks.chat_models import _convert_dict_to_message
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_dict_to_message_with_reasoning_content() -> None:
|
||||||
|
"""Test that reasoning_content is correctly extracted from API response."""
|
||||||
|
response_dict = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "The answer is 42.",
|
||||||
|
"reasoning_content": "Let me think about this step by step...",
|
||||||
|
}
|
||||||
|
|
||||||
|
message = _convert_dict_to_message(response_dict)
|
||||||
|
|
||||||
|
assert isinstance(message, AIMessage)
|
||||||
|
assert message.content == "The answer is 42."
|
||||||
|
assert "reasoning_content" in message.additional_kwargs
|
||||||
|
expected_reasoning = "Let me think about this step by step..."
|
||||||
|
assert message.additional_kwargs["reasoning_content"] == expected_reasoning
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_dict_to_message_without_reasoning_content() -> None:
|
||||||
|
"""Test that messages without reasoning_content work correctly."""
|
||||||
|
response_dict = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "The answer is 42.",
|
||||||
|
}
|
||||||
|
|
||||||
|
message = _convert_dict_to_message(response_dict)
|
||||||
|
|
||||||
|
assert isinstance(message, AIMessage)
|
||||||
|
assert message.content == "The answer is 42."
|
||||||
|
assert "reasoning_content" not in message.additional_kwargs
|
||||||
Reference in New Issue
Block a user