mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
anthropic[patch]: support citations in streaming (#29591)
This commit is contained in:
parent
5ae4ed791d
commit
5cbe6aba8f
@ -718,7 +718,9 @@ class ChatAnthropic(BaseChatModel):
|
||||
kwargs["stream"] = True
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
stream = self._client.messages.create(**payload)
|
||||
coerce_content_to_string = not _tools_in_params(payload)
|
||||
coerce_content_to_string = not _tools_in_params(
|
||||
payload
|
||||
) and not _documents_in_params(payload)
|
||||
for event in stream:
|
||||
msg = _make_message_chunk_from_anthropic_event(
|
||||
event,
|
||||
@ -745,7 +747,9 @@ class ChatAnthropic(BaseChatModel):
|
||||
kwargs["stream"] = True
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
stream = await self._async_client.messages.create(**payload)
|
||||
coerce_content_to_string = not _tools_in_params(payload)
|
||||
coerce_content_to_string = not _tools_in_params(
|
||||
payload
|
||||
) and not _documents_in_params(payload)
|
||||
async for event in stream:
|
||||
msg = _make_message_chunk_from_anthropic_event(
|
||||
event,
|
||||
@ -761,6 +765,16 @@ class ChatAnthropic(BaseChatModel):
|
||||
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
|
||||
data_dict = data.model_dump()
|
||||
content = data_dict["content"]
|
||||
|
||||
# Remove citations if they are None - introduced in anthropic sdk 0.45
|
||||
for block in content:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and "citations" in block
|
||||
and block["citations"] is None
|
||||
):
|
||||
block.pop("citations")
|
||||
|
||||
llm_output = {
|
||||
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
|
||||
}
|
||||
@ -1254,6 +1268,19 @@ def _tools_in_params(params: dict) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _documents_in_params(params: dict) -> bool:
|
||||
for message in params.get("messages", []):
|
||||
if isinstance(message.get("content"), list):
|
||||
for block in message["content"]:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "document"
|
||||
and block.get("citations", {}).get("enabled")
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class _AnthropicToolUse(TypedDict):
|
||||
type: Literal["tool_use"]
|
||||
name: str
|
||||
@ -1299,31 +1326,37 @@ def _make_message_chunk_from_anthropic_event(
|
||||
elif (
|
||||
event.type == "content_block_start"
|
||||
and event.content_block is not None
|
||||
and event.content_block.type == "tool_use"
|
||||
and event.content_block.type in ("tool_use", "document")
|
||||
):
|
||||
if coerce_content_to_string:
|
||||
warnings.warn("Received unexpected tool content block.")
|
||||
content_block = event.content_block.model_dump()
|
||||
content_block["index"] = event.index
|
||||
if event.content_block.type == "tool_use":
|
||||
tool_call_chunk = create_tool_call_chunk(
|
||||
index=event.index,
|
||||
id=event.content_block.id,
|
||||
name=event.content_block.name,
|
||||
args="",
|
||||
)
|
||||
tool_call_chunks = [tool_call_chunk]
|
||||
else:
|
||||
tool_call_chunks = []
|
||||
message_chunk = AIMessageChunk(
|
||||
content=[content_block],
|
||||
tool_call_chunks=[tool_call_chunk], # type: ignore
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore
|
||||
)
|
||||
elif event.type == "content_block_delta":
|
||||
if event.delta.type == "text_delta":
|
||||
if coerce_content_to_string:
|
||||
if event.delta.type in ("text_delta", "citations_delta"):
|
||||
if coerce_content_to_string and hasattr(event.delta, "text"):
|
||||
text = event.delta.text
|
||||
message_chunk = AIMessageChunk(content=text)
|
||||
else:
|
||||
content_block = event.delta.model_dump()
|
||||
content_block["index"] = event.index
|
||||
content_block["type"] = "text"
|
||||
if "citation" in content_block:
|
||||
content_block["citations"] = [content_block.pop("citation")]
|
||||
message_chunk = AIMessageChunk(content=[content_block])
|
||||
elif event.delta.type == "input_json_delta":
|
||||
content_block = event.delta.model_dump()
|
||||
|
8
libs/partners/anthropic/poetry.lock
generated
8
libs/partners/anthropic/poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
@ -544,7 +544,7 @@ url = "../../core"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-tests"
|
||||
version = "0.3.9"
|
||||
version = "0.3.10"
|
||||
description = "Standard tests for LangChain implementations"
|
||||
optional = false
|
||||
python-versions = ">=3.9,<4.0"
|
||||
@ -553,7 +553,7 @@ develop = true
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.25.0,<1"
|
||||
langchain-core = "^0.3.31"
|
||||
langchain-core = "^0.3.33"
|
||||
numpy = [
|
||||
{version = ">=1.24.0,<2.0.0", markers = "python_version < \"3.12\""},
|
||||
{version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""},
|
||||
@ -1558,4 +1558,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "311f60adde43535d7e5dc14d93167440a8fb838868d544ac1b89b0faac174efe"
|
||||
content-hash = "e92c081312cc5199ac660851355ac9bc59f0333a4a2e61375b5ecfaffe4725b7"
|
||||
|
@ -21,7 +21,7 @@ plugins = ['pydantic.mypy']
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9,<4.0"
|
||||
anthropic = ">=0.41.0,<1"
|
||||
anthropic = ">=0.45.0,<1"
|
||||
langchain-core = "^0.3.33"
|
||||
pydantic = "^2.7.4"
|
||||
|
||||
|
@ -649,3 +649,12 @@ def test_citations() -> None:
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, list)
|
||||
assert any("citations" in block for block in response.content)
|
||||
|
||||
# Test streaming
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in llm.stream(messages):
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert isinstance(full.content, list)
|
||||
assert any("citations" in block for block in full.content)
|
||||
assert not any("citation" in block for block in full.content)
|
||||
|
Loading…
Reference in New Issue
Block a user