chatperplexity stream-citations in additional kwargs (#29273)

chatperplexity stream-citations in additional kwargs

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
ThomasSaulou 2025-01-18 23:31:10 +01:00 committed by GitHub
parent 8fad9214c7
commit e9abe583b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 65 additions and 0 deletions

View File

@ -223,15 +223,21 @@ class ChatPerplexity(BaseChatModel):
stream_resp = self.client.chat.completions.create( stream_resp = self.client.chat.completions.create(
messages=message_dicts, stream=True, **params messages=message_dicts, stream=True, **params
) )
first_chunk = True
for chunk in stream_resp: for chunk in stream_resp:
if not isinstance(chunk, dict): if not isinstance(chunk, dict):
chunk = chunk.dict() chunk = chunk.dict()
if len(chunk["choices"]) == 0: if len(chunk["choices"]) == 0:
continue continue
choice = chunk["choices"][0] choice = chunk["choices"][0]
citations = chunk.get("citations", [])
chunk = self._convert_delta_to_message_chunk( chunk = self._convert_delta_to_message_chunk(
choice["delta"], default_chunk_class choice["delta"], default_chunk_class
) )
if first_chunk:
chunk.additional_kwargs |= {"citations": citations}
first_chunk = False
finish_reason = choice.get("finish_reason") finish_reason = choice.get("finish_reason")
generation_info = ( generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None dict(finish_reason=finish_reason) if finish_reason is not None else None

View File

@ -1,8 +1,12 @@
"""Test Perplexity Chat API wrapper.""" """Test Perplexity Chat API wrapper."""
import os import os
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock
import pytest import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from pytest_mock import MockerFixture
from langchain_community.chat_models import ChatPerplexity from langchain_community.chat_models import ChatPerplexity
@ -40,3 +44,58 @@ def test_perplexity_initialization() -> None:
]: ]:
assert model.request_timeout == 1 assert model.request_timeout == 1
assert model.pplx_api_key == "test" assert model.pplx_api_key == "test"
@pytest.mark.requires("openai")
def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None:
"""Test that the stream method includes citations in the additional_kwargs."""
llm = ChatPerplexity(
model="test",
timeout=30,
verbose=True,
)
mock_chunk_0 = {
"choices": [
{
"delta": {
"content": "Hello ",
},
"finish_reason": None,
}
],
"citations": ["example.com", "example2.com"],
}
mock_chunk_1 = {
"choices": [
{
"delta": {
"content": "Perplexity",
},
"finish_reason": None,
}
],
"citations": ["example.com", "example2.com"],
}
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_stream = MagicMock()
mock_stream.__iter__.return_value = mock_chunks
patcher = mocker.patch.object(
llm.client.chat.completions, "create", return_value=mock_stream
)
stream = llm.stream("Hello langchain")
full: Optional[BaseMessageChunk] = None
for i, chunk in enumerate(stream):
full = chunk if full is None else full + chunk
assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"]
if i == 0:
assert chunk.additional_kwargs["citations"] == [
"example.com",
"example2.com",
]
else:
assert "citations" not in chunk.additional_kwargs
assert isinstance(full, AIMessageChunk)
assert full.content == "Hello Perplexity"
assert full.additional_kwargs == {"citations": ["example.com", "example2.com"]}
patcher.assert_called_once()