anthropic[patch]: support citations in streaming (#29591)

This commit is contained in:
ccurme
2025-02-05 09:12:07 -05:00
committed by GitHub
parent 5ae4ed791d
commit 5cbe6aba8f
4 changed files with 59 additions and 17 deletions

View File

@@ -718,7 +718,9 @@ class ChatAnthropic(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
stream = self._client.messages.create(**payload)
coerce_content_to_string = not _tools_in_params(payload)
coerce_content_to_string = not _tools_in_params(
payload
) and not _documents_in_params(payload)
for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event,
@@ -745,7 +747,9 @@ class ChatAnthropic(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
stream = await self._async_client.messages.create(**payload)
coerce_content_to_string = not _tools_in_params(payload)
coerce_content_to_string = not _tools_in_params(
payload
) and not _documents_in_params(payload)
async for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event,
@@ -761,6 +765,16 @@ class ChatAnthropic(BaseChatModel):
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
data_dict = data.model_dump()
content = data_dict["content"]
# Remove citations if they are None - introduced in anthropic sdk 0.45
for block in content:
if (
isinstance(block, dict)
and "citations" in block
and block["citations"] is None
):
block.pop("citations")
llm_output = {
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
}
@@ -1254,6 +1268,19 @@ def _tools_in_params(params: dict) -> bool:
)
def _documents_in_params(params: dict) -> bool:
for message in params.get("messages", []):
if isinstance(message.get("content"), list):
for block in message["content"]:
if (
isinstance(block, dict)
and block.get("type") == "document"
and block.get("citations", {}).get("enabled")
):
return True
return False
class _AnthropicToolUse(TypedDict):
type: Literal["tool_use"]
name: str
@@ -1299,31 +1326,37 @@ def _make_message_chunk_from_anthropic_event(
elif (
event.type == "content_block_start"
and event.content_block is not None
and event.content_block.type == "tool_use"
and event.content_block.type in ("tool_use", "document")
):
if coerce_content_to_string:
warnings.warn("Received unexpected tool content block.")
content_block = event.content_block.model_dump()
content_block["index"] = event.index
tool_call_chunk = create_tool_call_chunk(
index=event.index,
id=event.content_block.id,
name=event.content_block.name,
args="",
)
if event.content_block.type == "tool_use":
tool_call_chunk = create_tool_call_chunk(
index=event.index,
id=event.content_block.id,
name=event.content_block.name,
args="",
)
tool_call_chunks = [tool_call_chunk]
else:
tool_call_chunks = []
message_chunk = AIMessageChunk(
content=[content_block],
tool_call_chunks=[tool_call_chunk], # type: ignore
tool_call_chunks=tool_call_chunks, # type: ignore
)
elif event.type == "content_block_delta":
if event.delta.type == "text_delta":
if coerce_content_to_string:
if event.delta.type in ("text_delta", "citations_delta"):
if coerce_content_to_string and hasattr(event.delta, "text"):
text = event.delta.text
message_chunk = AIMessageChunk(content=text)
else:
content_block = event.delta.model_dump()
content_block["index"] = event.index
content_block["type"] = "text"
if "citation" in content_block:
content_block["citations"] = [content_block.pop("citation")]
message_chunk = AIMessageChunk(content=[content_block])
elif event.delta.type == "input_json_delta":
content_block = event.delta.model_dump()