anthropic[patch]: support citations in streaming (#29591)

This commit is contained in:
ccurme 2025-02-05 09:12:07 -05:00 committed by GitHub
parent 5ae4ed791d
commit 5cbe6aba8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 17 deletions

View File

@ -718,7 +718,9 @@ class ChatAnthropic(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
stream = self._client.messages.create(**payload)
coerce_content_to_string = not _tools_in_params(payload)
coerce_content_to_string = not _tools_in_params(
payload
) and not _documents_in_params(payload)
for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event,
@ -745,7 +747,9 @@ class ChatAnthropic(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
stream = await self._async_client.messages.create(**payload)
coerce_content_to_string = not _tools_in_params(payload)
coerce_content_to_string = not _tools_in_params(
payload
) and not _documents_in_params(payload)
async for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event,
@ -761,6 +765,16 @@ class ChatAnthropic(BaseChatModel):
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
data_dict = data.model_dump()
content = data_dict["content"]
# Remove citations if they are None - introduced in anthropic sdk 0.45
for block in content:
if (
isinstance(block, dict)
and "citations" in block
and block["citations"] is None
):
block.pop("citations")
llm_output = {
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
}
@ -1254,6 +1268,19 @@ def _tools_in_params(params: dict) -> bool:
)
def _documents_in_params(params: dict) -> bool:
for message in params.get("messages", []):
if isinstance(message.get("content"), list):
for block in message["content"]:
if (
isinstance(block, dict)
and block.get("type") == "document"
and block.get("citations", {}).get("enabled")
):
return True
return False
class _AnthropicToolUse(TypedDict):
type: Literal["tool_use"]
name: str
@ -1299,31 +1326,37 @@ def _make_message_chunk_from_anthropic_event(
elif (
event.type == "content_block_start"
and event.content_block is not None
and event.content_block.type == "tool_use"
and event.content_block.type in ("tool_use", "document")
):
if coerce_content_to_string:
warnings.warn("Received unexpected tool content block.")
content_block = event.content_block.model_dump()
content_block["index"] = event.index
if event.content_block.type == "tool_use":
tool_call_chunk = create_tool_call_chunk(
index=event.index,
id=event.content_block.id,
name=event.content_block.name,
args="",
)
tool_call_chunks = [tool_call_chunk]
else:
tool_call_chunks = []
message_chunk = AIMessageChunk(
content=[content_block],
tool_call_chunks=[tool_call_chunk], # type: ignore
tool_call_chunks=tool_call_chunks, # type: ignore
)
elif event.type == "content_block_delta":
if event.delta.type == "text_delta":
if coerce_content_to_string:
if event.delta.type in ("text_delta", "citations_delta"):
if coerce_content_to_string and hasattr(event.delta, "text"):
text = event.delta.text
message_chunk = AIMessageChunk(content=text)
else:
content_block = event.delta.model_dump()
content_block["index"] = event.index
content_block["type"] = "text"
if "citation" in content_block:
content_block["citations"] = [content_block.pop("citation")]
message_chunk = AIMessageChunk(content=[content_block])
elif event.delta.type == "input_json_delta":
content_block = event.delta.model_dump()

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -544,7 +544,7 @@ url = "../../core"
[[package]]
name = "langchain-tests"
version = "0.3.9"
version = "0.3.10"
description = "Standard tests for LangChain implementations"
optional = false
python-versions = ">=3.9,<4.0"
@ -553,7 +553,7 @@ develop = true
[package.dependencies]
httpx = ">=0.25.0,<1"
langchain-core = "^0.3.31"
langchain-core = "^0.3.33"
numpy = [
{version = ">=1.24.0,<2.0.0", markers = "python_version < \"3.12\""},
{version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""},
@ -1558,4 +1558,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4.0"
content-hash = "311f60adde43535d7e5dc14d93167440a8fb838868d544ac1b89b0faac174efe"
content-hash = "e92c081312cc5199ac660851355ac9bc59f0333a4a2e61375b5ecfaffe4725b7"

View File

@ -21,7 +21,7 @@ plugins = ['pydantic.mypy']
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
anthropic = ">=0.41.0,<1"
anthropic = ">=0.45.0,<1"
langchain-core = "^0.3.33"
pydantic = "^2.7.4"

View File

@ -649,3 +649,12 @@ def test_citations() -> None:
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
assert any("citations" in block for block in response.content)
# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(messages):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert isinstance(full.content, list)
assert any("citations" in block for block in full.content)
assert not any("citation" in block for block in full.content)