Files
langchain/libs/partners/groq/tests/unit_tests/test_chat_models.py
Mason Daugherty 8c15649127 fix(openai,groq,openrouter): use is-not-None checks in usage metadata token extraction (#36500)
Python's `or` operator treats `0` as falsy, so
`token_usage.get("total_tokens") or fallback` silently replaces a
provider-reported `total_tokens=0` with the computed sum of input +
output tokens. Providers can legitimately report zero tokens (e.g.,
cached responses, empty completions).

The same pattern exists in the dual-key lookups for
`input_tokens`/`output_tokens` in Groq and OpenRouter. While current
APIs don't return both key formats simultaneously (making the `or`-chain
functionally correct today), the semantics are still wrong; `0` should
not fall through to a fallback.

## Changes

- Replace `x.get(key) or fallback` with explicit `is not None` checks in
`_create_usage_metadata` across `langchain-openai`, `langchain-groq`,
and `langchain-openrouter` for `input_tokens`, `output_tokens`, and
`total_tokens`
- Fix a concrete bug in the `total_tokens` path: a provider-reported `0`
was silently replaced by the computed sum
- Harden dual-key lookups in Groq and OpenRouter to correctly preserve
zero values from the preferred key, should both key formats ever coexist
- Update OpenAI's single-key extraction for consistency — the old `or 0`
pattern happened to produce correct results (`0 or 0 == 0`) but was
semantically wrong
2026-04-03 11:46:36 -04:00

1098 lines
35 KiB
Python

"""Test Groq Chat API wrapper."""
import json
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import langchain_core.load as lc_load
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
FunctionMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
)
from langchain_core.runnables import RunnableBinding, RunnableSequence
from pydantic import BaseModel
from langchain_groq.chat_models import (
ChatGroq,
_convert_chunk_to_message_chunk,
_convert_dict_to_message,
_create_usage_metadata,
_format_message_content,
)
if "GROQ_API_KEY" not in os.environ:
os.environ["GROQ_API_KEY"] = "fake-key"
def test_groq_model_param() -> None:
llm = ChatGroq(model="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"
assert llm.model == "foo"
llm = ChatGroq(model_name="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"
assert llm.model == "foo"
def test_function_message_dict_to_function_message() -> None:
content = json.dumps({"result": "Example #1"})
name = "test_function"
result = _convert_dict_to_message(
{
"role": "function",
"name": name,
"content": content,
}
)
assert isinstance(result, FunctionMessage)
assert result.name == name
assert result.content == content
def test__convert_dict_to_message_human() -> None:
message = {"role": "user", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
content="foo", response_metadata={"model_provider": "groq"}
)
assert result == expected_output
def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
"type": "function",
}
message = {"role": "assistant", "content": None, "tool_calls": [raw_tool_call]}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": [raw_tool_call]},
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
type="tool_call",
)
],
response_metadata={"model_provider": "groq"},
)
assert result == expected_output
# Test malformed tool call
raw_tool_calls = [
{
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": "oops",
"name": "GenerateUsername",
},
"type": "function",
},
{
"id": "call_abc123",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
"type": "function",
},
]
message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": raw_tool_calls},
invalid_tool_calls=[
InvalidToolCall(
name="GenerateUsername",
args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)\nFor troubleshooting, visit: https://docs.langchain.com/oss/python/langchain/errors/OUTPUT_PARSING_FAILURE ", # noqa: E501
type="invalid_tool_call",
),
],
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_abc123",
type="tool_call",
),
],
response_metadata={"model_provider": "groq"},
)
assert result == expected_output
def test__convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output
@pytest.fixture
def mock_completion() -> dict:
return {
"id": "chatcmpl-7fcZavknQda3SQ",
"object": "chat.completion",
"created": 1689989000,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Bar Baz",
},
"finish_reason": "stop",
}
],
}
def test_groq_invoke(mock_completion: dict) -> None:
llm = ChatGroq(model="foo")
mock_client = MagicMock()
completed = False
def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"client",
mock_client,
):
res = llm.invoke("bar")
assert res.content == "Bar Baz"
assert type(res) is AIMessage
assert completed
async def test_groq_ainvoke(mock_completion: dict) -> None:
llm = ChatGroq(model="foo")
mock_client = AsyncMock()
completed = False
async def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"async_client",
mock_client,
):
res = await llm.ainvoke("bar")
assert res.content == "Bar Baz"
assert type(res) is AIMessage
assert completed
def test_chat_groq_extra_kwargs() -> None:
"""Test extra kwargs to chat groq."""
# Check that foo is saved in extra_kwargs.
with pytest.warns(UserWarning) as record:
llm = ChatGroq(model="foo", foo=3, max_tokens=10) # type: ignore[call-arg]
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}
assert len(record) == 1
assert type(record[0].message) is UserWarning
assert "foo is not default parameter" in record[0].message.args[0]
# Test that if extra_kwargs are provided, they are added to it.
with pytest.warns(UserWarning) as record:
llm = ChatGroq(model="foo", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
assert llm.model_kwargs == {"foo": 3, "bar": 2}
assert len(record) == 1
assert type(record[0].message) is UserWarning
assert "foo is not default parameter" in record[0].message.args[0]
# Test that if provided twice it errors
with pytest.raises(ValueError):
ChatGroq(model="foo", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
# Test that if explicit param is specified in kwargs it errors
with pytest.raises(ValueError):
ChatGroq(model="foo", model_kwargs={"temperature": 0.2})
# Test that "model" cannot be specified in kwargs
with pytest.raises(ValueError):
ChatGroq(model="foo", model_kwargs={"model": "test-model"})
def test_chat_groq_invalid_streaming_params() -> None:
"""Test that an error is raised if streaming is invoked with n>1."""
with pytest.raises(ValueError):
ChatGroq(
model="foo",
max_tokens=10,
streaming=True,
temperature=0,
n=5,
)
def test_with_structured_output_json_schema_strict() -> None:
class Response(BaseModel):
"""Response schema."""
foo: str
structured_model = ChatGroq(model="openai/gpt-oss-20b").with_structured_output(
Response, method="json_schema", strict=True
)
assert isinstance(structured_model, RunnableSequence)
first_step = structured_model.steps[0]
assert isinstance(first_step, RunnableBinding)
response_format = first_step.kwargs["response_format"]
assert response_format["type"] == "json_schema"
json_schema = response_format["json_schema"]
assert json_schema["strict"] is True
assert json_schema["name"] == "Response"
assert json_schema["schema"]["properties"]["foo"]["type"] == "string"
assert "foo" in json_schema["schema"]["required"]
assert json_schema["schema"]["additionalProperties"] is False
def test_with_structured_output_json_schema_strict_ignored_on_unsupported_model() -> (
None
):
class Response(BaseModel):
"""Response schema."""
foo: str
structured_model = ChatGroq(model="llama-3.1-8b-instant").with_structured_output(
Response, method="json_schema", strict=True
)
assert isinstance(structured_model, RunnableSequence)
first_step = structured_model.steps[0]
assert isinstance(first_step, RunnableBinding)
response_format = first_step.kwargs["response_format"]
assert response_format["type"] == "json_schema"
assert "strict" not in response_format["json_schema"]
def test_chat_groq_secret() -> None:
"""Test that secret is not printed."""
secret = "secretKey" # noqa: S105
not_secret = "safe" # noqa: S105
llm = ChatGroq(model="foo", api_key=secret, model_kwargs={"not_secret": not_secret}) # type: ignore[call-arg, arg-type]
stringified = str(llm)
assert not_secret in stringified
assert secret not in stringified
@pytest.mark.filterwarnings("ignore:The function `loads` is in beta")
def test_groq_serialization() -> None:
"""Test that ChatGroq can be successfully serialized and deserialized."""
api_key1 = "top secret"
api_key2 = "topest secret"
llm = ChatGroq(model="foo", api_key=api_key1, temperature=0.5) # type: ignore[call-arg, arg-type]
dump = lc_load.dumps(llm)
llm2 = lc_load.loads(
dump,
valid_namespaces=["langchain_groq"],
secrets_map={"GROQ_API_KEY": api_key2},
allowed_objects="all",
)
assert type(llm2) is ChatGroq
# Ensure api key wasn't dumped and instead was read from secret map.
assert llm.groq_api_key is not None
assert llm.groq_api_key.get_secret_value() not in dump
assert llm2.groq_api_key is not None
assert llm2.groq_api_key.get_secret_value() == api_key2
# Ensure a non-secret field was preserved
assert llm.temperature == llm2.temperature
# Ensure a None was preserved
assert llm.groq_api_base == llm2.groq_api_base
def test_create_usage_metadata_basic() -> None:
"""Test basic usage metadata creation without details."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
}
result = _create_usage_metadata(token_usage)
assert isinstance(result, dict)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "input_token_details" not in result
assert "output_token_details" not in result
def test_create_usage_metadata_responses_api_format() -> None:
"""Test usage metadata creation with new Responses API format."""
token_usage = {
"input_tokens": 1590,
"output_tokens": 77,
"total_tokens": 1667,
"input_tokens_details": {"cached_tokens": 1536},
"output_tokens_details": {"reasoning_tokens": 0},
}
result = _create_usage_metadata(token_usage)
assert isinstance(result, dict)
assert result["input_tokens"] == 1590
assert result["output_tokens"] == 77
assert result["total_tokens"] == 1667
assert result.get("input_token_details", {}).get("cache_read") == 1536
# reasoning_tokens is 0, so filtered out
assert "output_token_details" not in result
def test_create_usage_metadata_chat_completions_with_details() -> None:
"""Test usage metadata with hypothetical Chat Completions API format."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"prompt_tokens_details": {"cached_tokens": 80},
"completion_tokens_details": {"reasoning_tokens": 25},
}
result = _create_usage_metadata(token_usage)
assert isinstance(result, dict)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert result.get("input_token_details", {}).get("cache_read") == 80
assert result.get("output_token_details", {}).get("reasoning") == 25
def test_create_usage_metadata_with_cached_tokens() -> None:
"""Test usage metadata with prompt caching."""
token_usage = {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"input_tokens_details": {"cached_tokens": 1920},
}
result = _create_usage_metadata(token_usage)
assert isinstance(result, dict)
assert result["input_tokens"] == 2006
assert result["output_tokens"] == 300
assert result["total_tokens"] == 2306
assert "input_token_details" in result
assert isinstance(result["input_token_details"], dict)
assert result["input_token_details"]["cache_read"] == 1920
assert "output_token_details" not in result
def test_create_usage_metadata_with_all_details() -> None:
"""Test usage metadata with all available details."""
token_usage = {
"prompt_tokens": 2006,
"completion_tokens": 450,
"total_tokens": 2456,
"input_tokens_details": {"cached_tokens": 1920},
"output_tokens_details": {"reasoning_tokens": 200},
}
result = _create_usage_metadata(token_usage)
assert isinstance(result, dict)
assert result["input_tokens"] == 2006
assert result["output_tokens"] == 450
assert result["total_tokens"] == 2456
assert "input_token_details" in result
assert isinstance(result["input_token_details"], dict)
assert result["input_token_details"]["cache_read"] == 1920
assert "output_token_details" in result
assert isinstance(result["output_token_details"], dict)
assert result["output_token_details"]["reasoning"] == 200
def test_create_usage_metadata_missing_total_tokens() -> None:
"""Test that total_tokens is calculated when missing."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
}
result = _create_usage_metadata(token_usage)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
def test_create_usage_metadata_zero_total_tokens() -> None:
"""Test that explicit total_tokens=0 is preserved, not replaced by sum."""
token_usage = {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 0,
}
result = _create_usage_metadata(token_usage)
assert result["total_tokens"] == 0
def test_create_usage_metadata_zero_input_tokens_preferred_key() -> None:
"""Test that input_tokens=0 is not overridden by prompt_tokens fallback."""
token_usage = {
"input_tokens": 0,
"prompt_tokens": 50,
"completion_tokens": 5,
"total_tokens": 55,
}
result = _create_usage_metadata(token_usage)
assert result["input_tokens"] == 0
def test_create_usage_metadata_zero_output_tokens_preferred_key() -> None:
"""Test that output_tokens=0 is not overridden by completion_tokens fallback."""
token_usage = {
"input_tokens": 10,
"output_tokens": 0,
"completion_tokens": 50,
"total_tokens": 60,
}
result = _create_usage_metadata(token_usage)
assert result["output_tokens"] == 0
def test_create_usage_metadata_empty_details() -> None:
"""Test that empty detail dicts don't create token detail objects."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"input_tokens_details": {},
}
result = _create_usage_metadata(token_usage)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "input_token_details" not in result
assert "output_token_details" not in result
def test_create_usage_metadata_zero_cached_tokens() -> None:
"""Test that zero cached tokens are not included (falsy)."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"input_tokens_details": {"cached_tokens": 0},
}
result = _create_usage_metadata(token_usage)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "input_token_details" not in result
def test_create_usage_metadata_with_reasoning_tokens() -> None:
"""Test usage metadata with reasoning tokens."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 450,
"total_tokens": 550,
"output_tokens_details": {"reasoning_tokens": 200},
}
result = _create_usage_metadata(token_usage)
assert isinstance(result, dict)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 450
assert result["total_tokens"] == 550
assert "output_token_details" in result
assert isinstance(result["output_token_details"], dict)
assert result["output_token_details"]["reasoning"] == 200
assert "input_token_details" not in result
def test_create_usage_metadata_with_cached_and_reasoning_tokens() -> None:
"""Test usage metadata with both cached and reasoning tokens."""
token_usage = {
"prompt_tokens": 2006,
"completion_tokens": 450,
"total_tokens": 2456,
"input_tokens_details": {"cached_tokens": 1920},
"output_tokens_details": {"reasoning_tokens": 200},
}
result = _create_usage_metadata(token_usage)
assert isinstance(result, dict)
assert result["input_tokens"] == 2006
assert result["output_tokens"] == 450
assert result["total_tokens"] == 2456
assert "input_token_details" in result
assert isinstance(result["input_token_details"], dict)
assert result["input_token_details"]["cache_read"] == 1920
assert "output_token_details" in result
assert isinstance(result["output_token_details"], dict)
assert result["output_token_details"]["reasoning"] == 200
def test_create_usage_metadata_zero_reasoning_tokens() -> None:
"""Test that zero reasoning tokens are not included (falsy)."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"output_tokens_details": {"reasoning_tokens": 0},
}
result = _create_usage_metadata(token_usage)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "output_token_details" not in result
def test_create_usage_metadata_empty_completion_details() -> None:
"""Test that empty output_tokens_details don't create output_token_details."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"output_tokens_details": {},
}
result = _create_usage_metadata(token_usage)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "output_token_details" not in result
def test_chat_result_with_usage_metadata() -> None:
"""Test that _create_chat_result properly includes usage metadata."""
llm = ChatGroq(model="test-model")
mock_response = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"input_tokens_details": {"cached_tokens": 1920},
},
}
result = llm._create_chat_result(mock_response, {})
assert len(result.generations) == 1
message = result.generations[0].message
assert isinstance(message, AIMessage)
assert message.content == "Test response"
assert message.usage_metadata is not None
assert isinstance(message.usage_metadata, dict)
assert message.usage_metadata["input_tokens"] == 2006
assert message.usage_metadata["output_tokens"] == 300
assert message.usage_metadata["total_tokens"] == 2306
assert "input_token_details" in message.usage_metadata
assert message.usage_metadata["input_token_details"]["cache_read"] == 1920
assert "output_token_details" not in message.usage_metadata
def test_chat_result_with_reasoning_tokens() -> None:
"""Test that _create_chat_result properly includes reasoning tokens."""
llm = ChatGroq(model="test-model")
mock_response = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Test reasoning response",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 100,
"completion_tokens": 450,
"total_tokens": 550,
"output_tokens_details": {"reasoning_tokens": 200},
},
}
result = llm._create_chat_result(mock_response, {})
assert len(result.generations) == 1
message = result.generations[0].message
assert isinstance(message, AIMessage)
assert message.content == "Test reasoning response"
assert message.usage_metadata is not None
assert isinstance(message.usage_metadata, dict)
assert message.usage_metadata["input_tokens"] == 100
assert message.usage_metadata["output_tokens"] == 450
assert message.usage_metadata["total_tokens"] == 550
assert "output_token_details" in message.usage_metadata
assert message.usage_metadata["output_token_details"]["reasoning"] == 200
assert "input_token_details" not in message.usage_metadata
def test_chat_result_with_cached_and_reasoning_tokens() -> None:
"""Test that _create_chat_result includes both cached and reasoning tokens."""
llm = ChatGroq(model="test-model")
mock_response = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response with both",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 450,
"total_tokens": 2456,
"input_tokens_details": {"cached_tokens": 1920},
"output_tokens_details": {"reasoning_tokens": 200},
},
}
result = llm._create_chat_result(mock_response, {})
assert len(result.generations) == 1
message = result.generations[0].message
assert isinstance(message, AIMessage)
assert message.content == "Test response with both"
assert message.usage_metadata is not None
assert isinstance(message.usage_metadata, dict)
assert message.usage_metadata["input_tokens"] == 2006
assert message.usage_metadata["output_tokens"] == 450
assert message.usage_metadata["total_tokens"] == 2456
assert "input_token_details" in message.usage_metadata
assert message.usage_metadata["input_token_details"]["cache_read"] == 1920
assert "output_token_details" in message.usage_metadata
assert message.usage_metadata["output_token_details"]["reasoning"] == 200
def test_chat_result_backward_compatibility() -> None:
"""Test that responses without new fields still work."""
llm = ChatGroq(model="test-model")
mock_response = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
},
}
result = llm._create_chat_result(mock_response, {})
assert len(result.generations) == 1
message = result.generations[0].message
assert isinstance(message, AIMessage)
assert message.usage_metadata is not None
assert message.usage_metadata["input_tokens"] == 100
assert message.usage_metadata["output_tokens"] == 50
assert message.usage_metadata["total_tokens"] == 150
assert "input_token_details" not in message.usage_metadata
assert "output_token_details" not in message.usage_metadata
def test_streaming_with_usage_metadata() -> None:
"""Test that streaming properly includes usage metadata."""
chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello",
},
"finish_reason": None,
}
],
"x_groq": {
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"input_tokens_details": {"cached_tokens": 1920},
}
},
}
result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
assert isinstance(result, AIMessageChunk)
assert result.content == "Hello"
assert result.usage_metadata is not None
assert isinstance(result.usage_metadata, dict)
assert result.usage_metadata["input_tokens"] == 2006
assert result.usage_metadata["output_tokens"] == 300
assert result.usage_metadata["total_tokens"] == 2306
assert "input_token_details" in result.usage_metadata
assert result.usage_metadata["input_token_details"]["cache_read"] == 1920
assert "output_token_details" not in result.usage_metadata
def test_streaming_with_reasoning_tokens() -> None:
"""Test that streaming properly includes reasoning tokens in usage metadata."""
chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello",
},
"finish_reason": None,
}
],
"x_groq": {
"usage": {
"prompt_tokens": 100,
"completion_tokens": 450,
"total_tokens": 550,
"output_tokens_details": {"reasoning_tokens": 200},
}
},
}
result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
assert isinstance(result, AIMessageChunk)
assert result.content == "Hello"
assert result.usage_metadata is not None
assert isinstance(result.usage_metadata, dict)
assert result.usage_metadata["input_tokens"] == 100
assert result.usage_metadata["output_tokens"] == 450
assert result.usage_metadata["total_tokens"] == 550
assert "output_token_details" in result.usage_metadata
assert result.usage_metadata["output_token_details"]["reasoning"] == 200
assert "input_token_details" not in result.usage_metadata
def test_streaming_with_cached_and_reasoning_tokens() -> None:
"""Test that streaming includes both cached and reasoning tokens."""
chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello",
},
"finish_reason": None,
}
],
"x_groq": {
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 450,
"total_tokens": 2456,
"input_tokens_details": {"cached_tokens": 1920},
"output_tokens_details": {"reasoning_tokens": 200},
}
},
}
result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
assert isinstance(result, AIMessageChunk)
assert result.content == "Hello"
assert result.usage_metadata is not None
assert isinstance(result.usage_metadata, dict)
assert result.usage_metadata["input_tokens"] == 2006
assert result.usage_metadata["output_tokens"] == 450
assert result.usage_metadata["total_tokens"] == 2456
assert "input_token_details" in result.usage_metadata
assert result.usage_metadata["input_token_details"]["cache_read"] == 1920
assert "output_token_details" in result.usage_metadata
assert result.usage_metadata["output_token_details"]["reasoning"] == 200
def test_streaming_without_usage_metadata() -> None:
"""Test that streaming works without usage metadata (backward compatibility)."""
chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello",
},
"finish_reason": None,
}
],
}
result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
assert isinstance(result, AIMessageChunk)
assert result.content == "Hello"
assert result.usage_metadata is None
def test_combine_llm_outputs_with_token_details() -> None:
"""Test that _combine_llm_outputs properly combines nested token details."""
llm = ChatGroq(model="test-model")
llm_outputs: list[dict[str, Any] | None] = [
{
"token_usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"input_tokens_details": {"cached_tokens": 80},
"output_tokens_details": {"reasoning_tokens": 20},
},
"model_name": "test-model",
"system_fingerprint": "fp_123",
},
{
"token_usage": {
"prompt_tokens": 200,
"completion_tokens": 100,
"total_tokens": 300,
"input_tokens_details": {"cached_tokens": 150},
"output_tokens_details": {"reasoning_tokens": 40},
},
"model_name": "test-model",
"system_fingerprint": "fp_123",
},
]
result = llm._combine_llm_outputs(llm_outputs)
assert result["token_usage"]["prompt_tokens"] == 300
assert result["token_usage"]["completion_tokens"] == 150
assert result["token_usage"]["total_tokens"] == 450
assert result["token_usage"]["input_tokens_details"]["cached_tokens"] == 230
assert result["token_usage"]["output_tokens_details"]["reasoning_tokens"] == 60
assert result["model_name"] == "test-model"
assert result["system_fingerprint"] == "fp_123"
def test_combine_llm_outputs_with_missing_details() -> None:
"""Test _combine_llm_outputs when some outputs have details and others don't."""
llm = ChatGroq(model="test-model")
llm_outputs: list[dict[str, Any] | None] = [
{
"token_usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
},
"model_name": "test-model",
},
{
"token_usage": {
"prompt_tokens": 200,
"completion_tokens": 100,
"total_tokens": 300,
"output_tokens_details": {"reasoning_tokens": 40},
},
"model_name": "test-model",
},
]
result = llm._combine_llm_outputs(llm_outputs)
assert result["token_usage"]["prompt_tokens"] == 300
assert result["token_usage"]["completion_tokens"] == 150
assert result["token_usage"]["total_tokens"] == 450
assert result["token_usage"]["output_tokens_details"]["reasoning_tokens"] == 40
assert "input_tokens_details" not in result["token_usage"]
def test_profile() -> None:
model = ChatGroq(model="openai/gpt-oss-20b")
assert model.profile
def test_format_message_content_string() -> None:
"""Test that string content is passed through unchanged."""
content = "hello"
assert content == _format_message_content(content)
def test_format_message_content_none() -> None:
"""Test that None content is passed through unchanged."""
content = None
assert content == _format_message_content(content)
def test_format_message_content_empty_list() -> None:
"""Test that empty list is passed through unchanged."""
content: list = []
assert content == _format_message_content(content)
def test_format_message_content_text_and_image_url() -> None:
"""Test that existing image_url format is passed through unchanged."""
content = [
{"type": "text", "text": "What is in this image?"},
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
]
assert content == _format_message_content(content)
def test_format_message_content_langchain_image_base64() -> None:
"""Test that LangChain image blocks with base64 are converted."""
content = {"type": "image", "base64": "<base64 data>", "mime_type": "image/png"}
expected = [
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,<base64 data>"},
}
]
assert expected == _format_message_content([content])
def test_format_message_content_langchain_image_url() -> None:
"""Test that LangChain image blocks with URL are converted."""
content = {"type": "image", "url": "https://example.com/image.jpg"}
expected = [
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
]
assert expected == _format_message_content([content])
def test_format_message_content_mixed() -> None:
"""Test that mixed content with text and image is handled correctly."""
content = [
{"type": "text", "text": "Describe this image"},
{"type": "image", "base64": "<data>", "mime_type": "image/png"},
]
expected = [
{"type": "text", "text": "Describe this image"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<data>"}},
]
assert expected == _format_message_content(content)