mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
chatperplexity stream-citations in additional kwargs (#29273)
chatperplexity stream-citations in additional kwargs --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
8fad9214c7
commit
e9abe583b2
@ -223,15 +223,21 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
stream_resp = self.client.chat.completions.create(
|
stream_resp = self.client.chat.completions.create(
|
||||||
messages=message_dicts, stream=True, **params
|
messages=message_dicts, stream=True, **params
|
||||||
)
|
)
|
||||||
|
first_chunk = True
|
||||||
for chunk in stream_resp:
|
for chunk in stream_resp:
|
||||||
if not isinstance(chunk, dict):
|
if not isinstance(chunk, dict):
|
||||||
chunk = chunk.dict()
|
chunk = chunk.dict()
|
||||||
if len(chunk["choices"]) == 0:
|
if len(chunk["choices"]) == 0:
|
||||||
continue
|
continue
|
||||||
choice = chunk["choices"][0]
|
choice = chunk["choices"][0]
|
||||||
|
citations = chunk.get("citations", [])
|
||||||
|
|
||||||
chunk = self._convert_delta_to_message_chunk(
|
chunk = self._convert_delta_to_message_chunk(
|
||||||
choice["delta"], default_chunk_class
|
choice["delta"], default_chunk_class
|
||||||
)
|
)
|
||||||
|
if first_chunk:
|
||||||
|
chunk.additional_kwargs |= {"citations": citations}
|
||||||
|
first_chunk = False
|
||||||
finish_reason = choice.get("finish_reason")
|
finish_reason = choice.get("finish_reason")
|
||||||
generation_info = (
|
generation_info = (
|
||||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||||
|
@ -1,8 +1,12 @@
|
|||||||
"""Test Perplexity Chat API wrapper."""
|
"""Test Perplexity Chat API wrapper."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from langchain_community.chat_models import ChatPerplexity
|
from langchain_community.chat_models import ChatPerplexity
|
||||||
|
|
||||||
@ -40,3 +44,58 @@ def test_perplexity_initialization() -> None:
|
|||||||
]:
|
]:
|
||||||
assert model.request_timeout == 1
|
assert model.request_timeout == 1
|
||||||
assert model.pplx_api_key == "test"
|
assert model.pplx_api_key == "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None:
|
||||||
|
"""Test that the stream method includes citations in the additional_kwargs."""
|
||||||
|
llm = ChatPerplexity(
|
||||||
|
model="test",
|
||||||
|
timeout=30,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
mock_chunk_0 = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "Hello ",
|
||||||
|
},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"citations": ["example.com", "example2.com"],
|
||||||
|
}
|
||||||
|
mock_chunk_1 = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "Perplexity",
|
||||||
|
},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"citations": ["example.com", "example2.com"],
|
||||||
|
}
|
||||||
|
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
|
||||||
|
mock_stream = MagicMock()
|
||||||
|
mock_stream.__iter__.return_value = mock_chunks
|
||||||
|
patcher = mocker.patch.object(
|
||||||
|
llm.client.chat.completions, "create", return_value=mock_stream
|
||||||
|
)
|
||||||
|
stream = llm.stream("Hello langchain")
|
||||||
|
full: Optional[BaseMessageChunk] = None
|
||||||
|
for i, chunk in enumerate(stream):
|
||||||
|
full = chunk if full is None else full + chunk
|
||||||
|
assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"]
|
||||||
|
if i == 0:
|
||||||
|
assert chunk.additional_kwargs["citations"] == [
|
||||||
|
"example.com",
|
||||||
|
"example2.com",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
assert "citations" not in chunk.additional_kwargs
|
||||||
|
assert isinstance(full, AIMessageChunk)
|
||||||
|
assert full.content == "Hello Perplexity"
|
||||||
|
assert full.additional_kwargs == {"citations": ["example.com", "example2.com"]}
|
||||||
|
|
||||||
|
patcher.assert_called_once()
|
||||||
|
Loading…
Reference in New Issue
Block a user