anthropic[patch]: pass back in citations in multi-turn conversations (#31882)

Also adds VCR cassettes for some heavy tests.
ccurme 2025-07-05 17:33:22 -04:00, committed by GitHub
parent 46fe09f013
commit 3f4b355eef
7 changed files with 39 additions and 14 deletions
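In short: citation annotations returned on AI message content blocks are now preserved when that message is sent back to the API in a follow-up turn. A minimal sketch of the round-trip this enables (model choice and document payload are illustrative, mirroring the updated `test_citations` below):

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-haiku-latest")  # illustrative model

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "document",
                "source": {
                    "type": "text",
                    "media_type": "text/plain",
                    "data": "The grass is green. The sky is blue.",
                },
                "citations": {"enabled": True},
            },
            {"type": "text", "text": "What color is the grass?"},
        ],
    }
]
full = llm.invoke(messages)  # text blocks now include "citations" entries

# Previously, _format_messages dropped the "citations" key when this message
# was replayed, so the follow-up turn lost its citation context.
next_message = {"role": "user", "content": "Can you comment on the citations?"}
_ = llm.invoke(messages + [full, next_message])
```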

libs/partners/anthropic/langchain_anthropic/chat_models.py

@@ -396,7 +396,8 @@ def _format_messages(
                         {
                             k: v
                             for k, v in block.items()
-                            if k in ("type", "text", "cache_control")
+                            if k
+                            in ("type", "text", "cache_control", "citations")
                         }
                     )
                 elif block["type"] == "thinking":
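This allow-list is the whole fix: when replaying a prior assistant turn, `_format_messages` copies only whitelisted keys from each text block, and `citations` was missing. A self-contained sketch of the effect (block values are made up):

```python
block = {
    "type": "text",
    "text": "The grass is green.",
    "citations": [{"type": "char_location", "cited_text": "The grass is green."}],
    "internal_extra": "always dropped",
}

# Old allow-list: citations were silently stripped before re-sending.
old = {k: v for k, v in block.items() if k in ("type", "text", "cache_control")}

# New allow-list: citations survive the round-trip.
new = {
    k: v
    for k, v in block.items()
    if k in ("type", "text", "cache_control", "citations")
}

assert "citations" not in old
assert "citations" in new
```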

libs/partners/anthropic/tests/integration_tests/test_chat_models.py

@@ -713,14 +713,24 @@ def test_citations() -> None:
     assert any("citations" in block for block in full.content)
     assert not any("citation" in block for block in full.content)
 
+    # Test pass back in
+    next_message = {
+        "role": "user",
+        "content": "Can you comment on the citations you just made?",
+    }
+    _ = llm.invoke(messages + [full, next_message])
+
 
 
+@pytest.mark.vcr
 def test_thinking() -> None:
     llm = ChatAnthropic(
         model="claude-3-7-sonnet-latest",
         max_tokens=5_000,
         thinking={"type": "enabled", "budget_tokens": 2_000},
     )
-    response = llm.invoke("Hello")
+    input_message = {"role": "user", "content": "Hello"}
+    response = llm.invoke([input_message])
     assert any("thinking" in block for block in response.content)
     for block in response.content:
         assert isinstance(block, dict)
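`@pytest.mark.vcr` is the pytest-recording marker: the first run records HTTP traffic to a cassette, later runs replay it offline. A minimal conftest sketch, assuming pytest-recording; the repo's actual cassette configuration may differ:

```python
# conftest.py (sketch)
import pytest


@pytest.fixture(scope="session")
def vcr_config():
    # Keep API credentials out of the recorded cassettes.
    return {"filter_headers": ["authorization", "x-api-key"]}
```

With that in place, `pytest --record-mode=once` records any missing cassettes and a plain `pytest` run replays them without hitting the Anthropic API.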
@@ -731,7 +741,7 @@ def test_thinking() -> None:
 
     # Test streaming
     full: Optional[BaseMessageChunk] = None
-    for chunk in llm.stream("Hello"):
+    for chunk in llm.stream([input_message]):
         if full is None:
             full = cast(BaseMessageChunk, chunk)
         else:
@@ -746,8 +756,12 @@ def test_thinking() -> None:
             assert block["thinking"] and isinstance(block["thinking"], str)
             assert block["signature"] and isinstance(block["signature"], str)
 
+    # Test pass back in
+    next_message = {"role": "user", "content": "How are you?"}
+    _ = llm.invoke([input_message, full, next_message])
+
 
-@pytest.mark.flaky(retries=3, delay=1)
+@pytest.mark.vcr
 def test_redacted_thinking() -> None:
     llm = ChatAnthropic(
         model="claude-3-7-sonnet-latest",
@@ -755,8 +769,9 @@ def test_redacted_thinking() -> None:
         thinking={"type": "enabled", "budget_tokens": 2_000},
     )
     query = "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB"  # noqa: E501
-    response = llm.invoke(query)
+    input_message = {"role": "user", "content": query}
+    response = llm.invoke([input_message])
     has_reasoning = False
     for block in response.content:
         assert isinstance(block, dict)
@@ -768,7 +783,7 @@ def test_redacted_thinking() -> None:
 
     # Test streaming
     full: Optional[BaseMessageChunk] = None
-    for chunk in llm.stream(query):
+    for chunk in llm.stream([input_message]):
         if full is None:
             full = cast(BaseMessageChunk, chunk)
         else:
@@ -784,6 +799,10 @@ def test_redacted_thinking() -> None:
             assert block["data"] and isinstance(block["data"], str)
     assert stream_has_reasoning
 
+    # Test pass back in
+    next_message = {"role": "user", "content": "What?"}
+    _ = llm.invoke([input_message, full, next_message])
+
 
 def test_structured_output_thinking_enabled() -> None:
     llm = ChatAnthropic(
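For reference, redacted thinking surfaces as opaque blocks rather than readable text; roughly (the `data` value here is invented):

```python
# Shape of a redacted_thinking content block; "data" is an opaque,
# encrypted payload that must be passed back verbatim.
block = {
    "type": "redacted_thinking",
    "data": "EqkBCkYIARgCKkB3VFlmZ...",  # illustrative ciphertext
}
assert block["data"] and isinstance(block["data"], str)
```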
@@ -882,9 +901,8 @@ def test_image_tool_calling() -> None:
     llm.bind_tools([color_picker]).invoke(messages)
 
 
-# TODO: set up VCR
+@pytest.mark.vcr
 def test_web_search() -> None:
-    pytest.skip()
     llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
 
     tool = {"type": "web_search_20250305", "name": "web_search", "max_uses": 1}
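A usage sketch for the now-recorded web search test (the prompt is illustrative; the tool spec is taken from the diff):

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
tool = {"type": "web_search_20250305", "name": "web_search", "max_uses": 1}
llm_with_tools = llm.bind_tools([tool])

response = llm_with_tools.invoke("What major tech news broke this week?")
# Server-side tools return content blocks rather than client tool_calls:
assert all(isinstance(block, dict) for block in response.content)
print({block["type"] for block in response.content})
# -> {"text", "server_tool_use", "web_search_tool_result"}
```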
@@ -900,7 +918,8 @@ def test_web_search() -> None:
         ],
     }
     response = llm_with_tools.invoke([input_message])
-    block_types = {block["type"] for block in response.content}
+    assert all(isinstance(block, dict) for block in response.content)
+    block_types = {block["type"] for block in response.content}  # type: ignore[index]
     assert block_types == {"text", "server_tool_use", "web_search_tool_result"}
 
     # Test streaming
@@ -923,11 +942,12 @@ def test_web_search() -> None:
     )
 
 
+@pytest.mark.vcr
 def test_code_execution() -> None:
-    pytest.skip()
     llm = ChatAnthropic(
         model="claude-sonnet-4-20250514",
         betas=["code-execution-2025-05-22"],
+        max_tokens=10_000,
     )
 
     tool = {"type": "code_execution_20250522", "name": "code_execution"}
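And for code execution, using the beta flag and the `max_tokens` bump added above (prompt illustrative):

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(
    model="claude-sonnet-4-20250514",
    betas=["code-execution-2025-05-22"],
    max_tokens=10_000,
)
tool = {"type": "code_execution_20250522", "name": "code_execution"}
llm_with_tools = llm.bind_tools([tool])

response = llm_with_tools.invoke(
    "Use Python to compute the mean of [1, 2, 3, 4, 5]."
)
# Expected block types: {"text", "server_tool_use", "code_execution_tool_result"}
```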
@@ -946,7 +966,8 @@ def test_code_execution() -> None:
         ],
     }
     response = llm_with_tools.invoke([input_message])
-    block_types = {block["type"] for block in response.content}
+    assert all(isinstance(block, dict) for block in response.content)
+    block_types = {block["type"] for block in response.content}  # type: ignore[index]
     assert block_types == {"text", "server_tool_use", "code_execution_tool_result"}
 
     # Test streaming
@@ -969,8 +990,8 @@ def test_code_execution() -> None:
     )
 
 
+@pytest.mark.vcr
 def test_remote_mcp() -> None:
-    pytest.skip()
     mcp_servers = [
         {
             "type": "url",
@@ -985,6 +1006,7 @@ def test_remote_mcp() -> None:
         model="claude-sonnet-4-20250514",
         betas=["mcp-client-2025-04-04"],
         mcp_servers=mcp_servers,
+        max_tokens=10_000,
     )
 
     input_message = {
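The `mcp_servers` entries follow Anthropic's MCP connector schema; a sketch with a placeholder server (URL and name invented):

```python
from langchain_anthropic import ChatAnthropic

mcp_servers = [
    {
        "type": "url",
        "url": "https://mcp.example.com/sse",  # placeholder endpoint
        "name": "example-server",              # placeholder name
    }
]

llm = ChatAnthropic(
    model="claude-sonnet-4-20250514",
    betas=["mcp-client-2025-04-04"],
    mcp_servers=mcp_servers,
    max_tokens=10_000,
)
# Remote tool calls surface as "mcp_tool_use" / "mcp_tool_result" blocks.
```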
@@ -1000,7 +1022,8 @@ def test_remote_mcp() -> None:
         ],
     }
     response = llm.invoke([input_message])
-    block_types = {block["type"] for block in response.content}
+    assert all(isinstance(block, dict) for block in response.content)
+    block_types = {block["type"] for block in response.content}  # type: ignore[index]
     assert block_types == {"text", "mcp_tool_use", "mcp_tool_result"}
 
     # Test streaming
@@ -1010,7 +1033,8 @@ def test_remote_mcp() -> None:
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessageChunk)
     assert isinstance(full.content, list)
-    block_types = {block["type"] for block in full.content}
+    assert all(isinstance(block, dict) for block in full.content)
+    block_types = {block["type"] for block in full.content}  # type: ignore[index]
     assert block_types == {"text", "mcp_tool_use", "mcp_tool_result"}
 
     # Test we can pass back in
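The repeated `assert all(isinstance(block, dict) ...)` lines added above are presumably there for mypy: message `content` is typed as a list of `str | dict`, so indexing a block by key needs both a runtime check and the `# type: ignore[index]` escape hatch (narrowing inside `all()` does not propagate to the later comprehension). A standalone sketch of the idiom:

```python
from typing import Union

content: list[Union[str, dict]] = [
    {"type": "text", "text": "hi"},
    {"type": "server_tool_use", "id": "tool_1"},  # illustrative block
]

# The runtime check documents the invariant; the ignore silences the
# indexing error mypy would still raise for the str arm of the union.
assert all(isinstance(block, dict) for block in content)
block_types = {block["type"] for block in content}  # type: ignore[index]
assert block_types == {"text", "server_tool_use"}
```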