Chester Curme 2025-05-08 11:05:09 -04:00
parent fc797fadf4
commit 18c1d8a50c

@@ -8,14 +8,16 @@ from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
 MODEL_NAME = "claude-3-5-haiku-latest"
 
 
-async def test_astream() -> None:
+def test_stream() -> None:
     """Test streaming tokens from Anthropic."""
     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
     full: Optional[BaseMessageChunk] = None
     chunks_with_input_token_counts = 0
     chunks_with_output_token_counts = 0
-    async for token in llm.astream("I'm Pickle Rick"):
+    chunks_with_model_name = 0
+    for token in llm.stream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
         full = token if full is None else full + token
         assert isinstance(token, AIMessageChunk)
@@ -24,12 +26,14 @@ async def test_astream() -> None:
                 chunks_with_input_token_counts += 1
             if token.usage_metadata.get("output_tokens"):
                 chunks_with_output_token_counts += 1
+        chunks_with_model_name += int("model_name" in token.response_metadata)
     if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
         raise AssertionError(
             "Expected exactly one chunk with input or output token counts. "
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
+    assert chunks_with_model_name == 1
     # check token usage is populated
     assert isinstance(full, AIMessageChunk)
     assert full.usage_metadata is not None
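The test insists on exactly one token-counting chunk because AIMessageChunk addition merges usage_metadata by summing the counts, so a provider that attached counts to every chunk would inflate the aggregated totals. A minimal standalone sketch of that aggregation behavior (an illustration, not part of this commit; assumes current langchain_core semantics, and the token values are made up):

import from langchain_core.messages import AIMessageChunk

# One chunk carries the input token count, a later one the output count,
# mirroring how the Anthropic stream reports usage.
first = AIMessageChunk(
    content="Hello",
    usage_metadata={"input_tokens": 10, "output_tokens": 1, "total_tokens": 11},
)
second = AIMessageChunk(
    content=" world",
    usage_metadata={"input_tokens": 0, "output_tokens": 5, "total_tokens": 5},
)

# Chunk addition sums the counts; duplicated usage would be double-counted.
merged = first + second
assert merged.usage_metadata == {
    "input_tokens": 10,
    "output_tokens": 6,
    "total_tokens": 16,
}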
@@ -42,27 +46,7 @@ async def test_astream() -> None:
     )
     assert "stop_reason" in full.response_metadata
     assert "stop_sequence" in full.response_metadata
+    assert "model_name" in full.response_metadata
-
-    # Check expected raw API output
-    async_client = llm._async_client
-    params: dict = {
-        "model": MODEL_NAME,
-        "max_tokens": 1024,
-        "messages": [{"role": "user", "content": "hi"}],
-        "temperature": 0.0,
-    }
-    stream = await async_client.messages.create(**params, stream=True)
-    async for event in stream:
-        if event.type == "message_start":
-            assert event.message.usage.input_tokens > 1
-            # Note: this single output token included in message start event
-            # does not appear to contribute to overall output token counts. It
-            # is excluded from the total token count.
-            assert event.message.usage.output_tokens == 1
-        elif event.type == "message_delta":
-            assert event.usage.output_tokens > 1
-        else:
-            pass
 
 
 async def test_stream_usage() -> None:
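The block deleted above reached into the private llm._async_client to pin down raw API behavior: Anthropic's message_start event reports the input token count along with a single placeholder output token that is excluded from the running total, while message_delta events carry the cumulative output count. Should that check be needed again, a rough standalone equivalent against the public anthropic SDK might look like this (a sketch, assuming ANTHROPIC_API_KEY is set in the environment):

import anthropic

client = anthropic.Anthropic()
stream = client.messages.create(
    model="claude-3-5-haiku-latest",
    max_tokens=1024,
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.0,
    stream=True,
)
for event in stream:
    if event.type == "message_start":
        # The prompt size is known immediately; the lone output token here
        # is a placeholder and does not contribute to the final total.
        assert event.message.usage.input_tokens > 1
        assert event.message.usage.output_tokens == 1
    elif event.type == "message_delta":
        # Cumulative output tokens for the generated response so far.
        assert event.usage.output_tokens > 1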