mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
anthropic[patch]: support citations in streaming (#29591)
This commit is contained in:
@@ -718,7 +718,9 @@ class ChatAnthropic(BaseChatModel):
|
||||
kwargs["stream"] = True
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
stream = self._client.messages.create(**payload)
|
||||
coerce_content_to_string = not _tools_in_params(payload)
|
||||
coerce_content_to_string = not _tools_in_params(
|
||||
payload
|
||||
) and not _documents_in_params(payload)
|
||||
for event in stream:
|
||||
msg = _make_message_chunk_from_anthropic_event(
|
||||
event,
|
||||
@@ -745,7 +747,9 @@ class ChatAnthropic(BaseChatModel):
|
||||
kwargs["stream"] = True
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
stream = await self._async_client.messages.create(**payload)
|
||||
coerce_content_to_string = not _tools_in_params(payload)
|
||||
coerce_content_to_string = not _tools_in_params(
|
||||
payload
|
||||
) and not _documents_in_params(payload)
|
||||
async for event in stream:
|
||||
msg = _make_message_chunk_from_anthropic_event(
|
||||
event,
|
||||
@@ -761,6 +765,16 @@ class ChatAnthropic(BaseChatModel):
|
||||
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
|
||||
data_dict = data.model_dump()
|
||||
content = data_dict["content"]
|
||||
|
||||
# Remove citations if they are None - introduced in anthropic sdk 0.45
|
||||
for block in content:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and "citations" in block
|
||||
and block["citations"] is None
|
||||
):
|
||||
block.pop("citations")
|
||||
|
||||
llm_output = {
|
||||
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
|
||||
}
|
||||
@@ -1254,6 +1268,19 @@ def _tools_in_params(params: dict) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _documents_in_params(params: dict) -> bool:
|
||||
for message in params.get("messages", []):
|
||||
if isinstance(message.get("content"), list):
|
||||
for block in message["content"]:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "document"
|
||||
and block.get("citations", {}).get("enabled")
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class _AnthropicToolUse(TypedDict):
|
||||
type: Literal["tool_use"]
|
||||
name: str
|
||||
@@ -1299,31 +1326,37 @@ def _make_message_chunk_from_anthropic_event(
|
||||
elif (
|
||||
event.type == "content_block_start"
|
||||
and event.content_block is not None
|
||||
and event.content_block.type == "tool_use"
|
||||
and event.content_block.type in ("tool_use", "document")
|
||||
):
|
||||
if coerce_content_to_string:
|
||||
warnings.warn("Received unexpected tool content block.")
|
||||
content_block = event.content_block.model_dump()
|
||||
content_block["index"] = event.index
|
||||
tool_call_chunk = create_tool_call_chunk(
|
||||
index=event.index,
|
||||
id=event.content_block.id,
|
||||
name=event.content_block.name,
|
||||
args="",
|
||||
)
|
||||
if event.content_block.type == "tool_use":
|
||||
tool_call_chunk = create_tool_call_chunk(
|
||||
index=event.index,
|
||||
id=event.content_block.id,
|
||||
name=event.content_block.name,
|
||||
args="",
|
||||
)
|
||||
tool_call_chunks = [tool_call_chunk]
|
||||
else:
|
||||
tool_call_chunks = []
|
||||
message_chunk = AIMessageChunk(
|
||||
content=[content_block],
|
||||
tool_call_chunks=[tool_call_chunk], # type: ignore
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore
|
||||
)
|
||||
elif event.type == "content_block_delta":
|
||||
if event.delta.type == "text_delta":
|
||||
if coerce_content_to_string:
|
||||
if event.delta.type in ("text_delta", "citations_delta"):
|
||||
if coerce_content_to_string and hasattr(event.delta, "text"):
|
||||
text = event.delta.text
|
||||
message_chunk = AIMessageChunk(content=text)
|
||||
else:
|
||||
content_block = event.delta.model_dump()
|
||||
content_block["index"] = event.index
|
||||
content_block["type"] = "text"
|
||||
if "citation" in content_block:
|
||||
content_block["citations"] = [content_block.pop("citation")]
|
||||
message_chunk = AIMessageChunk(content=[content_block])
|
||||
elif event.delta.type == "input_json_delta":
|
||||
content_block = event.delta.model_dump()
|
||||
|
||||
Reference in New Issue
Block a user