mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
anthropic[patch]: support citations in streaming (#29591)
This commit is contained in:
parent
5ae4ed791d
commit
5cbe6aba8f
@ -718,7 +718,9 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
kwargs["stream"] = True
|
kwargs["stream"] = True
|
||||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||||
stream = self._client.messages.create(**payload)
|
stream = self._client.messages.create(**payload)
|
||||||
coerce_content_to_string = not _tools_in_params(payload)
|
coerce_content_to_string = not _tools_in_params(
|
||||||
|
payload
|
||||||
|
) and not _documents_in_params(payload)
|
||||||
for event in stream:
|
for event in stream:
|
||||||
msg = _make_message_chunk_from_anthropic_event(
|
msg = _make_message_chunk_from_anthropic_event(
|
||||||
event,
|
event,
|
||||||
@ -745,7 +747,9 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
kwargs["stream"] = True
|
kwargs["stream"] = True
|
||||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||||
stream = await self._async_client.messages.create(**payload)
|
stream = await self._async_client.messages.create(**payload)
|
||||||
coerce_content_to_string = not _tools_in_params(payload)
|
coerce_content_to_string = not _tools_in_params(
|
||||||
|
payload
|
||||||
|
) and not _documents_in_params(payload)
|
||||||
async for event in stream:
|
async for event in stream:
|
||||||
msg = _make_message_chunk_from_anthropic_event(
|
msg = _make_message_chunk_from_anthropic_event(
|
||||||
event,
|
event,
|
||||||
@ -761,6 +765,16 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
|
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
|
||||||
data_dict = data.model_dump()
|
data_dict = data.model_dump()
|
||||||
content = data_dict["content"]
|
content = data_dict["content"]
|
||||||
|
|
||||||
|
# Remove citations if they are None - introduced in anthropic sdk 0.45
|
||||||
|
for block in content:
|
||||||
|
if (
|
||||||
|
isinstance(block, dict)
|
||||||
|
and "citations" in block
|
||||||
|
and block["citations"] is None
|
||||||
|
):
|
||||||
|
block.pop("citations")
|
||||||
|
|
||||||
llm_output = {
|
llm_output = {
|
||||||
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
|
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
|
||||||
}
|
}
|
||||||
@ -1254,6 +1268,19 @@ def _tools_in_params(params: dict) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _documents_in_params(params: dict) -> bool:
|
||||||
|
for message in params.get("messages", []):
|
||||||
|
if isinstance(message.get("content"), list):
|
||||||
|
for block in message["content"]:
|
||||||
|
if (
|
||||||
|
isinstance(block, dict)
|
||||||
|
and block.get("type") == "document"
|
||||||
|
and block.get("citations", {}).get("enabled")
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class _AnthropicToolUse(TypedDict):
|
class _AnthropicToolUse(TypedDict):
|
||||||
type: Literal["tool_use"]
|
type: Literal["tool_use"]
|
||||||
name: str
|
name: str
|
||||||
@ -1299,31 +1326,37 @@ def _make_message_chunk_from_anthropic_event(
|
|||||||
elif (
|
elif (
|
||||||
event.type == "content_block_start"
|
event.type == "content_block_start"
|
||||||
and event.content_block is not None
|
and event.content_block is not None
|
||||||
and event.content_block.type == "tool_use"
|
and event.content_block.type in ("tool_use", "document")
|
||||||
):
|
):
|
||||||
if coerce_content_to_string:
|
if coerce_content_to_string:
|
||||||
warnings.warn("Received unexpected tool content block.")
|
warnings.warn("Received unexpected tool content block.")
|
||||||
content_block = event.content_block.model_dump()
|
content_block = event.content_block.model_dump()
|
||||||
content_block["index"] = event.index
|
content_block["index"] = event.index
|
||||||
tool_call_chunk = create_tool_call_chunk(
|
if event.content_block.type == "tool_use":
|
||||||
index=event.index,
|
tool_call_chunk = create_tool_call_chunk(
|
||||||
id=event.content_block.id,
|
index=event.index,
|
||||||
name=event.content_block.name,
|
id=event.content_block.id,
|
||||||
args="",
|
name=event.content_block.name,
|
||||||
)
|
args="",
|
||||||
|
)
|
||||||
|
tool_call_chunks = [tool_call_chunk]
|
||||||
|
else:
|
||||||
|
tool_call_chunks = []
|
||||||
message_chunk = AIMessageChunk(
|
message_chunk = AIMessageChunk(
|
||||||
content=[content_block],
|
content=[content_block],
|
||||||
tool_call_chunks=[tool_call_chunk], # type: ignore
|
tool_call_chunks=tool_call_chunks, # type: ignore
|
||||||
)
|
)
|
||||||
elif event.type == "content_block_delta":
|
elif event.type == "content_block_delta":
|
||||||
if event.delta.type == "text_delta":
|
if event.delta.type in ("text_delta", "citations_delta"):
|
||||||
if coerce_content_to_string:
|
if coerce_content_to_string and hasattr(event.delta, "text"):
|
||||||
text = event.delta.text
|
text = event.delta.text
|
||||||
message_chunk = AIMessageChunk(content=text)
|
message_chunk = AIMessageChunk(content=text)
|
||||||
else:
|
else:
|
||||||
content_block = event.delta.model_dump()
|
content_block = event.delta.model_dump()
|
||||||
content_block["index"] = event.index
|
content_block["index"] = event.index
|
||||||
content_block["type"] = "text"
|
content_block["type"] = "text"
|
||||||
|
if "citation" in content_block:
|
||||||
|
content_block["citations"] = [content_block.pop("citation")]
|
||||||
message_chunk = AIMessageChunk(content=[content_block])
|
message_chunk = AIMessageChunk(content=[content_block])
|
||||||
elif event.delta.type == "input_json_delta":
|
elif event.delta.type == "input_json_delta":
|
||||||
content_block = event.delta.model_dump()
|
content_block = event.delta.model_dump()
|
||||||
|
8
libs/partners/anthropic/poetry.lock
generated
8
libs/partners/anthropic/poetry.lock
generated
@ -1,4 +1,4 @@
|
|||||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "annotated-types"
|
name = "annotated-types"
|
||||||
@ -544,7 +544,7 @@ url = "../../core"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-tests"
|
name = "langchain-tests"
|
||||||
version = "0.3.9"
|
version = "0.3.10"
|
||||||
description = "Standard tests for LangChain implementations"
|
description = "Standard tests for LangChain implementations"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9,<4.0"
|
python-versions = ">=3.9,<4.0"
|
||||||
@ -553,7 +553,7 @@ develop = true
|
|||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
httpx = ">=0.25.0,<1"
|
httpx = ">=0.25.0,<1"
|
||||||
langchain-core = "^0.3.31"
|
langchain-core = "^0.3.33"
|
||||||
numpy = [
|
numpy = [
|
||||||
{version = ">=1.24.0,<2.0.0", markers = "python_version < \"3.12\""},
|
{version = ">=1.24.0,<2.0.0", markers = "python_version < \"3.12\""},
|
||||||
{version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""},
|
{version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""},
|
||||||
@ -1558,4 +1558,4 @@ cffi = ["cffi (>=1.11)"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.9,<4.0"
|
python-versions = ">=3.9,<4.0"
|
||||||
content-hash = "311f60adde43535d7e5dc14d93167440a8fb838868d544ac1b89b0faac174efe"
|
content-hash = "e92c081312cc5199ac660851355ac9bc59f0333a4a2e61375b5ecfaffe4725b7"
|
||||||
|
@ -21,7 +21,7 @@ plugins = ['pydantic.mypy']
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.9,<4.0"
|
python = ">=3.9,<4.0"
|
||||||
anthropic = ">=0.41.0,<1"
|
anthropic = ">=0.45.0,<1"
|
||||||
langchain-core = "^0.3.33"
|
langchain-core = "^0.3.33"
|
||||||
pydantic = "^2.7.4"
|
pydantic = "^2.7.4"
|
||||||
|
|
||||||
|
@ -649,3 +649,12 @@ def test_citations() -> None:
|
|||||||
assert isinstance(response, AIMessage)
|
assert isinstance(response, AIMessage)
|
||||||
assert isinstance(response.content, list)
|
assert isinstance(response.content, list)
|
||||||
assert any("citations" in block for block in response.content)
|
assert any("citations" in block for block in response.content)
|
||||||
|
|
||||||
|
# Test streaming
|
||||||
|
full: Optional[BaseMessageChunk] = None
|
||||||
|
for chunk in llm.stream(messages):
|
||||||
|
full = chunk if full is None else full + chunk
|
||||||
|
assert isinstance(full, AIMessageChunk)
|
||||||
|
assert isinstance(full.content, list)
|
||||||
|
assert any("citations" in block for block in full.content)
|
||||||
|
assert not any("citation" in block for block in full.content)
|
||||||
|
Loading…
Reference in New Issue
Block a user