anthropic[patch]: pass back in citations in multi-turn conversations (#31882)

Also adds VCR cassettes for some heavy tests.
ccurme 2025-07-05 17:33:22 -04:00 committed by GitHub
parent 46fe09f013
commit 3f4b355eef
7 changed files with 39 additions and 14 deletions
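
For context on the fix: with citations enabled, assistant text blocks come back carrying a "citations" key, and _format_messages previously stripped that key when re-serializing history, so replaying a cited answer in a later turn failed. A minimal sketch of the multi-turn flow this patch enables (model name and document payload are illustrative, following the Anthropic citations format):

# Minimal sketch of the multi-turn citations flow; the model name and
# document contents are placeholders, not taken from this commit.
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-haiku-latest")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "document",
                "source": {
                    "type": "text",
                    "media_type": "text/plain",
                    "data": "The grass is green. The sky is blue.",
                },
                "citations": {"enabled": True},
            },
            {"type": "text", "text": "What color is the grass?"},
        ],
    }
]
response = llm.invoke(messages)  # text blocks now carry a "citations" key

# Replaying the cited response is what previously broke: the "citations"
# key was stripped from assistant blocks when formatting the next request.
next_message = {"role": "user", "content": "Can you comment on your citations?"}
_ = llm.invoke(messages + [response, next_message])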


@@ -396,7 +396,8 @@ def _format_messages(
                                 {
                                     k: v
                                     for k, v in block.items()
-                                    if k in ("type", "text", "cache_control")
+                                    if k
+                                    in ("type", "text", "cache_control", "citations")
                                 }
                             )
                     elif block["type"] == "thinking":
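
In effect, the serializer's allow-list for assistant text blocks now lets "citations" through alongside "type", "text", and "cache_control". A small sketch of the filter's behavior on a representative cited block (the citation payload is a placeholder, not from the diff):

# Representative assistant text block returned with citations enabled.
block = {
    "type": "text",
    "text": "The grass is green.",
    "citations": [{"type": "char_location", "cited_text": "The grass is green."}],
}

# The old allow-list dropped "citations"; the new one preserves it, so the
# block can be replayed to the API in the next turn without losing data.
kept = {
    k: v
    for k, v in block.items()
    if k in ("type", "text", "cache_control", "citations")
}
assert "citations" in kept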


@@ -713,14 +713,24 @@ def test_citations() -> None:
     assert any("citations" in block for block in full.content)
     assert not any("citation" in block for block in full.content)
+    # Test pass back in
+    next_message = {
+        "role": "user",
+        "content": "Can you comment on the citations you just made?",
+    }
+    _ = llm.invoke(messages + [full, next_message])
 
 
+@pytest.mark.vcr
 def test_thinking() -> None:
     llm = ChatAnthropic(
         model="claude-3-7-sonnet-latest",
         max_tokens=5_000,
         thinking={"type": "enabled", "budget_tokens": 2_000},
     )
-    response = llm.invoke("Hello")
+    input_message = {"role": "user", "content": "Hello"}
+    response = llm.invoke([input_message])
     assert any("thinking" in block for block in response.content)
     for block in response.content:
         assert isinstance(block, dict)
@@ -731,7 +741,7 @@ def test_thinking() -> None:
     # Test streaming
     full: Optional[BaseMessageChunk] = None
-    for chunk in llm.stream("Hello"):
+    for chunk in llm.stream([input_message]):
         if full is None:
             full = cast(BaseMessageChunk, chunk)
         else:
@@ -746,8 +756,12 @@ def test_thinking() -> None:
         assert block["thinking"] and isinstance(block["thinking"], str)
         assert block["signature"] and isinstance(block["signature"], str)
+    # Test pass back in
+    next_message = {"role": "user", "content": "How are you?"}
+    _ = llm.invoke([input_message, full, next_message])
 
 
 @pytest.mark.flaky(retries=3, delay=1)
+@pytest.mark.vcr
 def test_redacted_thinking() -> None:
     llm = ChatAnthropic(
         model="claude-3-7-sonnet-latest",
@@ -755,8 +769,9 @@ def test_redacted_thinking() -> None:
         thinking={"type": "enabled", "budget_tokens": 2_000},
     )
     query = "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB"  # noqa: E501
-    response = llm.invoke(query)
+    input_message = {"role": "user", "content": query}
+    response = llm.invoke([input_message])
     has_reasoning = False
     for block in response.content:
         assert isinstance(block, dict)
@@ -768,7 +783,7 @@ def test_redacted_thinking() -> None:
     # Test streaming
     full: Optional[BaseMessageChunk] = None
-    for chunk in llm.stream(query):
+    for chunk in llm.stream([input_message]):
         if full is None:
             full = cast(BaseMessageChunk, chunk)
         else:
@@ -784,6 +799,10 @@ def test_redacted_thinking() -> None:
         assert block["data"] and isinstance(block["data"], str)
     assert stream_has_reasoning
+    # Test pass back in
+    next_message = {"role": "user", "content": "What?"}
+    _ = llm.invoke([input_message, full, next_message])
 
 
 def test_structured_output_thinking_enabled() -> None:
     llm = ChatAnthropic(
@@ -882,9 +901,8 @@ def test_image_tool_calling() -> None:
     llm.bind_tools([color_picker]).invoke(messages)
 
 
-# TODO: set up VCR
+@pytest.mark.vcr
 def test_web_search() -> None:
-    pytest.skip()
     llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
     tool = {"type": "web_search_20250305", "name": "web_search", "max_uses": 1}
@@ -900,7 +918,8 @@ def test_web_search() -> None:
         ],
     }
     response = llm_with_tools.invoke([input_message])
-    block_types = {block["type"] for block in response.content}
+    assert all(isinstance(block, dict) for block in response.content)
+    block_types = {block["type"] for block in response.content}  # type: ignore[index]
     assert block_types == {"text", "server_tool_use", "web_search_tool_result"}
 
     # Test streaming
@@ -923,11 +942,12 @@ def test_web_search() -> None:
     )
 
 
+@pytest.mark.vcr
 def test_code_execution() -> None:
-    pytest.skip()
     llm = ChatAnthropic(
         model="claude-sonnet-4-20250514",
         betas=["code-execution-2025-05-22"],
+        max_tokens=10_000,
     )
     tool = {"type": "code_execution_20250522", "name": "code_execution"}
@@ -946,7 +966,8 @@ def test_code_execution() -> None:
         ],
     }
     response = llm_with_tools.invoke([input_message])
-    block_types = {block["type"] for block in response.content}
+    assert all(isinstance(block, dict) for block in response.content)
+    block_types = {block["type"] for block in response.content}  # type: ignore[index]
     assert block_types == {"text", "server_tool_use", "code_execution_tool_result"}
 
     # Test streaming
@@ -969,8 +990,8 @@ def test_code_execution() -> None:
     )
 
 
+@pytest.mark.vcr
 def test_remote_mcp() -> None:
-    pytest.skip()
     mcp_servers = [
         {
             "type": "url",
@@ -985,6 +1006,7 @@ def test_remote_mcp() -> None:
         model="claude-sonnet-4-20250514",
         betas=["mcp-client-2025-04-04"],
         mcp_servers=mcp_servers,
+        max_tokens=10_000,
     )
     input_message = {
@@ -1000,7 +1022,8 @@ def test_remote_mcp() -> None:
         ],
     }
     response = llm.invoke([input_message])
-    block_types = {block["type"] for block in response.content}
+    assert all(isinstance(block, dict) for block in response.content)
+    block_types = {block["type"] for block in response.content}  # type: ignore[index]
     assert block_types == {"text", "mcp_tool_use", "mcp_tool_result"}
 
     # Test streaming
@@ -1010,7 +1033,8 @@ def test_remote_mcp() -> None:
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessageChunk)
     assert isinstance(full.content, list)
-    block_types = {block["type"] for block in full.content}
+    assert all(isinstance(block, dict) for block in full.content)
+    block_types = {block["type"] for block in full.content}  # type: ignore[index]
     assert block_types == {"text", "mcp_tool_use", "mcp_tool_result"}
 
     # Test we can pass back in