ChatPerplexity: surface streaming citations in additional_kwargs (#29273)

ChatPerplexity: surface streaming citations in additional_kwargs

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
ThomasSaulou 2025-01-18 23:31:10 +01:00 committed by GitHub
parent 8fad9214c7
commit e9abe583b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 65 additions and 0 deletions

View File

@ -223,15 +223,21 @@ class ChatPerplexity(BaseChatModel):
stream_resp = self.client.chat.completions.create(
messages=message_dicts, stream=True, **params
)
first_chunk = True
for chunk in stream_resp:
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
citations = chunk.get("citations", [])
chunk = self._convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
if first_chunk:
chunk.additional_kwargs |= {"citations": citations}
first_chunk = False
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None

View File

@ -1,8 +1,12 @@
"""Test Perplexity Chat API wrapper."""
import os
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from pytest_mock import MockerFixture
from langchain_community.chat_models import ChatPerplexity
@ -40,3 +44,58 @@ def test_perplexity_initialization() -> None:
]:
assert model.request_timeout == 1
assert model.pplx_api_key == "test"
@pytest.mark.requires("openai")
def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None:
    """Citations are attached to additional_kwargs of the first streamed chunk only."""
    llm = ChatPerplexity(
        model="test",
        timeout=30,
        verbose=True,
    )

    def _payload(text: str) -> Dict[str, Any]:
        # Every raw API chunk carries the citations list; the model is expected
        # to surface it on the first emitted message chunk only.
        return {
            "choices": [
                {
                    "delta": {
                        "content": text,
                    },
                    "finish_reason": None,
                }
            ],
            "citations": ["example.com", "example2.com"],
        }

    raw_chunks: List[Dict[str, Any]] = [_payload("Hello "), _payload("Perplexity")]
    fake_stream = MagicMock()
    fake_stream.__iter__.return_value = raw_chunks
    create_patch = mocker.patch.object(
        llm.client.chat.completions, "create", return_value=fake_stream
    )

    accumulated: Optional[BaseMessageChunk] = None
    for idx, message_chunk in enumerate(llm.stream("Hello langchain")):
        accumulated = (
            message_chunk if accumulated is None else accumulated + message_chunk
        )
        expected_text = raw_chunks[idx]["choices"][0]["delta"]["content"]
        assert message_chunk.content == expected_text
        if idx == 0:
            # First chunk must expose the citations from the raw payload.
            assert message_chunk.additional_kwargs["citations"] == [
                "example.com",
                "example2.com",
            ]
        else:
            # Subsequent chunks must not repeat them.
            assert "citations" not in message_chunk.additional_kwargs

    assert isinstance(accumulated, AIMessageChunk)
    assert accumulated.content == "Hello Perplexity"
    assert accumulated.additional_kwargs == {
        "citations": ["example.com", "example2.com"]
    }
    create_patch.assert_called_once()