mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
Fixes #37912 `ChatPerplexity._convert_message_to_dict` raises `TypeError` on `ToolMessage` and drops `AIMessage.tool_calls`, which breaks tool-message round-trips through `ChatPerplexity` — a client-side tool-calling loop, or a shared message history across providers via `RunnableWithFallbacks`. Repro: ```python from langchain_perplexity import ChatPerplexity from langchain_core.messages import ToolMessage ChatPerplexity(model="sonar")._convert_message_to_dict( ToolMessage(content="result", tool_call_id="call_1") ) # TypeError: Got unknown type content='result' tool_call_id='call_1' ``` An `AIMessage` carrying `tool_calls` also serializes to `{"role": "assistant", "content": ...}` with the `tool_calls` silently dropped. This brings the converter to parity with `langchain-openai`: serialize `tool_calls` / `invalid_tool_calls`, send `content` as `null` when tool_calls are present, and add a `tool`-role branch for `ToolMessage`. How I verified: added unit tests for the `ToolMessage` and `AIMessage.tool_calls` / `invalid_tool_calls` cases; the perplexity package unit tests, lint, and format all pass. Scope: translating these to the Responses (Agent) API's `function_call` / `function_call_output` input items is a separate follow-up; this PR is the Chat Completions serialization parity fix. --------- Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com> Co-authored-by: Mason Daugherty <mason@langchain.dev> Co-authored-by: Mason Daugherty <github@mdrxy.com>
346 lines
11 KiB
Python
346 lines
11 KiB
Python
import json
|
|
from typing import Any, cast
|
|
from unittest.mock import MagicMock
|
|
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
AIMessageChunk,
|
|
BaseMessage,
|
|
ToolMessage,
|
|
)
|
|
from pytest_mock import MockerFixture
|
|
|
|
from langchain_perplexity import ChatPerplexity, MediaResponse, WebSearchOptions
|
|
from langchain_perplexity.chat_models import _create_usage_metadata
|
|
|
|
|
|
def test_perplexity_model_name_param() -> None:
    """The ``model`` constructor argument is readable back as ``.model``."""
    assert ChatPerplexity(model="foo").model == "foo"
|
|
|
|
|
|
def test_perplexity_model_kwargs() -> None:
    """``model_kwargs`` given at construction are readable back unchanged."""
    chat = ChatPerplexity(model="test", model_kwargs={"foo": "bar"})
    expected_kwargs = {"foo": "bar"}
    assert chat.model_kwargs == expected_kwargs
|
|
|
|
|
|
def test_perplexity_initialization() -> None:
    """Check construction with both spellings of the timeout and key keywords.

    The secret key is passed as a constructor argument rather than being
    read from an environment variable.
    """
    short_keywords = ChatPerplexity(
        model="test", timeout=1, api_key="test", temperature=0.7, verbose=True
    )
    long_keywords = ChatPerplexity(
        model="test",
        request_timeout=1,
        pplx_api_key="test",
        temperature=0.7,
        verbose=True,
    )

    # Either spelling must land on the same two attributes.
    for chat in (short_keywords, long_keywords):
        assert chat.request_timeout == 1
        secret = chat.pplx_api_key
        assert secret is not None
        assert secret.get_secret_value() == "test"
|
|
|
|
|
|
def test_perplexity_new_params() -> None:
    """Perplexity-specific options are reflected in ``_default_params``."""
    chat = ChatPerplexity(
        model="sonar-pro",
        search_mode="academic",
        web_search_options=WebSearchOptions(
            search_type="pro", search_context_size="high"
        ),
        media_response=MediaResponse(overrides={"return_videos": True}),
        return_images=True,
    )
    defaults = chat._default_params

    expected_web_search = {"search_type": "pro", "search_context_size": "high"}
    expected_media = {"overrides": {"return_videos": True}}

    assert defaults["search_mode"] == "academic"
    assert defaults["web_search_options"] == expected_web_search
    # The media response settings are looked up under ``extra_body``.
    assert defaults["extra_body"]["media_response"] == expected_media
    assert defaults["return_images"] is True
|
|
|
|
|
|
def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None:
    """Citations appear on the first streamed chunk only and survive merging."""
    chat = ChatPerplexity(model="test", timeout=30, verbose=True)

    def content_payload(text: str) -> dict[str, Any]:
        # A fresh dict and citations list per call, so payloads share no state.
        return {
            "choices": [{"delta": {"content": text}, "finish_reason": None}],
            "citations": ["example.com", "example2.com"],
        }

    payloads: list[dict[str, Any]] = [
        content_payload("Hello "),
        content_payload("Perplexity"),
        {"choices": [{"delta": {}, "finish_reason": "stop"}]},
    ]
    fake_stream = MagicMock()
    fake_stream.__iter__.return_value = payloads
    create = mocker.patch.object(
        chat.client.chat.completions, "create", return_value=fake_stream
    )

    received = list(chat.stream("Hello langchain"))
    # BaseChatModel.stream() adds an extra chunk after the final chunk from _stream
    assert len(received) == 4

    merged: BaseMessage | None = None
    # zip stops after the three mocked payloads; the extra chunk is merged below.
    paired = zip(received, payloads, strict=False)
    for position, (chunk, payload) in enumerate(paired):
        if merged is None:
            merged = chunk
        else:
            merged = cast(BaseMessage, merged + chunk)
        assert chunk.content == payload["choices"][0]["delta"].get("content", "")
        if position == 0:
            assert chunk.additional_kwargs["citations"] == [
                "example.com",
                "example2.com",
            ]
        else:
            assert "citations" not in chunk.additional_kwargs

    # Fold in the fourth (extra) chunk before checking the aggregate.
    assert merged is not None
    merged = cast(BaseMessage, merged + received[3])
    assert isinstance(merged, AIMessageChunk)
    assert merged.content == "Hello Perplexity"
    assert merged.additional_kwargs == {"citations": ["example.com", "example2.com"]}

    create.assert_called_once()
|
|
|
|
|
|
def test_perplexity_stream_includes_videos_and_reasoning(mocker: MockerFixture) -> None:
    """The first streamed chunk carries ``videos`` and ``reasoning_steps``."""
    chat = ChatPerplexity(model="test", timeout=30, verbose=True)

    first_payload = {
        "choices": [{"delta": {"content": "Thinking... "}, "finish_reason": None}],
        "videos": [{"url": "http://video.com", "thumbnail_url": "http://thumb.com"}],
        "reasoning_steps": [{"thought": "I should search", "type": "web_search"}],
    }
    final_payload = {"choices": [{"delta": {}, "finish_reason": "stop"}]}
    payloads: list[dict[str, Any]] = [first_payload, final_payload]

    fake_stream = MagicMock()
    fake_stream.__iter__.return_value = payloads
    mocker.patch.object(
        chat.client.chat.completions, "create", return_value=fake_stream
    )

    # Star-unpacking drains the whole stream, just as building a list would.
    head, *_rest = chat.stream("test")
    extras = head.additional_kwargs

    assert "videos" in extras
    assert extras["videos"][0]["url"] == "http://video.com"
    assert "reasoning_steps" in extras
    assert extras["reasoning_steps"][0]["thought"] == "I should search"
|
|
|
|
|
|
def test_create_usage_metadata_basic() -> None:
    """``_create_usage_metadata`` maps raw token counts onto usage metadata keys."""
    metadata = _create_usage_metadata(
        {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
            "reasoning_tokens": 0,
            "citation_tokens": 0,
        }
    )

    # Top-level counts: prompt -> input, completion -> output, total unchanged.
    assert metadata["input_tokens"] == 10
    assert metadata["output_tokens"] == 20
    assert metadata["total_tokens"] == 30

    # Reasoning and citation counts are nested under the output details.
    output_details = metadata["output_token_details"]
    assert output_details["reasoning"] == 0
    assert output_details["citation_tokens"] == 0  # type: ignore[typeddict-item]
|
|
|
|
|
|
def test_perplexity_invoke_includes_num_search_queries(mocker: MockerFixture) -> None:
    """``invoke`` exposes search usage fields and model name in response metadata."""
    chat = ChatPerplexity(model="test", timeout=30, verbose=True)

    usage = MagicMock()
    usage.model_dump.return_value = {
        "prompt_tokens": 10,
        "completion_tokens": 20,
        "total_tokens": 30,
        "num_search_queries": 3,
        "search_context_size": "high",
    }

    reply = MagicMock(content="Test response", tool_calls=None)
    response = MagicMock(
        choices=[MagicMock(message=reply, finish_reason="stop")],
        model="test-model",
        usage=usage,
    )
    # Mock optional fields as empty/None
    for optional_field in (
        "videos",
        "reasoning_steps",
        "citations",
        "search_results",
        "images",
        "related_questions",
    ):
        setattr(response, optional_field, None)

    create = mocker.patch.object(
        chat.client.chat.completions, "create", return_value=response
    )

    metadata = chat.invoke("Test query").response_metadata

    assert metadata["num_search_queries"] == 3
    assert metadata["search_context_size"] == "high"
    assert metadata["model_name"] == "test-model"
    create.assert_called_once()
|
|
|
|
|
|
def test_profile() -> None:
    """A ``sonar`` model exposes a truthy ``profile``."""
    assert ChatPerplexity(model="sonar").profile
|
|
|
|
|
|
def test_convert_tool_message_to_dict() -> None:
    """Serialize a ``ToolMessage`` to a ``tool``-role dict.

    This is what lets tool results be fed back to the model in a
    client-side tool-calling loop.
    """
    chat = ChatPerplexity(model="test", api_key="test")
    tool_result = ToolMessage(content="result text", tool_call_id="call_123")
    expected = {"role": "tool", "content": "result text", "tool_call_id": "call_123"}
    assert chat._convert_message_to_dict(tool_result) == expected
|
|
|
|
|
|
def test_convert_ai_message_with_tool_calls_to_dict() -> None:
    """Tool calls on an ``AIMessage`` are kept in the serialized dict."""
    chat = ChatPerplexity(model="test", api_key="test")
    ai_message = AIMessage(
        content="",
        tool_calls=[
            {
                "id": "call_123",
                "name": "search",
                "args": {"query": "langchain"},
                "type": "tool_call",
            }
        ],
    )

    serialized = chat._convert_message_to_dict(ai_message)

    # Parsed args are expected back as a JSON-encoded string.
    encoded_args = json.dumps({"query": "langchain"})
    expected_call = {
        "id": "call_123",
        "type": "function",
        "function": {"name": "search", "arguments": encoded_args},
    }

    assert serialized["role"] == "assistant"
    # Empty content alongside tool_calls must be sent as null, not "".
    assert serialized["content"] is None
    assert serialized["tool_calls"] == [expected_call]
|
|
|
|
|
|
def test_convert_ai_message_with_invalid_tool_calls_to_dict() -> None:
    """Invalid tool calls keep their raw, unparsed argument string."""
    chat = ChatPerplexity(model="test", api_key="test")
    ai_message = AIMessage(
        content="",
        invalid_tool_calls=[
            {
                "id": "call_bad",
                "name": "search",
                "args": "{not valid json",
                "error": "could not parse args",
                "type": "invalid_tool_call",
            }
        ],
    )

    serialized = chat._convert_message_to_dict(ai_message)

    # The expected entry has no "error" key; arguments stay a raw string.
    expected_call = {
        "id": "call_bad",
        "type": "function",
        "function": {"name": "search", "arguments": "{not valid json"},
    }
    assert serialized["tool_calls"] == [expected_call]
|
|
|
|
|
|
def test_convert_ai_message_preserves_content_alongside_tool_calls() -> None:
    """Non-empty content is kept as-is even when tool calls are attached."""
    chat = ChatPerplexity(model="test", api_key="test")
    spoken_text = "Let me look that up."
    ai_message = AIMessage(
        content=spoken_text,
        tool_calls=[
            {
                "id": "call_123",
                "name": "search",
                "args": {"query": "weather"},
                "type": "tool_call",
            }
        ],
    )

    serialized = chat._convert_message_to_dict(ai_message)

    assert serialized["content"] == "Let me look that up."
|
|
|
|
|
|
def test_convert_ai_message_with_valid_and_invalid_tool_calls_to_dict() -> None:
    """Valid and invalid tool calls are serialized into one list, valid first."""
    chat = ChatPerplexity(model="test", api_key="test")
    ai_message = AIMessage(
        content="",
        tool_calls=[
            {
                "id": "call_ok",
                "name": "search",
                "args": {"query": "weather"},
                "type": "tool_call",
            }
        ],
        invalid_tool_calls=[
            {
                "id": "call_bad",
                "name": "search",
                "args": "{not valid json",
                "error": "could not parse args",
                "type": "invalid_tool_call",
            }
        ],
    )

    serialized = chat._convert_message_to_dict(ai_message)

    # Valid call: parsed args are expected back as a JSON-encoded string.
    expected_valid = {
        "id": "call_ok",
        "type": "function",
        "function": {
            "name": "search",
            "arguments": json.dumps({"query": "weather"}),
        },
    }
    # Invalid call: the raw argument string is expected verbatim.
    expected_invalid = {
        "id": "call_bad",
        "type": "function",
        "function": {"name": "search", "arguments": "{not valid json"},
    }
    assert serialized["tool_calls"] == [expected_valid, expected_invalid]
|