fix(xai): count reasoning tokens in output total (#34603)

This commit is contained in:
ccurme
2026-01-05 13:25:30 -05:00
committed by GitHub
parent 730a3676f8
commit 944b43dd25
2 changed files with 46 additions and 0 deletions

View File

@@ -575,6 +575,21 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
response.citations response.citations
) )
# Unlike OpenAI, xAI reports reasoning tokens < completion tokens. So we assume
# they are not counted in output tokens, and we add them here.
if (
(not self._use_responses_api({}))
and (usage_metadata := rtn.generations[0].message.usage_metadata) # type: ignore[attr-defined]
and (
reasoning_tokens := usage_metadata.get("output_token_details", {}).get(
"reasoning"
)
)
):
rtn.generations[0].message.usage_metadata["output_tokens"] += ( # type: ignore[attr-defined]
reasoning_tokens
)
return rtn return rtn
def _convert_chunk_to_generation_chunk( def _convert_chunk_to_generation_chunk(
@@ -609,6 +624,19 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
): ):
generation_chunk.message.additional_kwargs["citations"] = citations generation_chunk.message.additional_kwargs["citations"] = citations
# Unlike OpenAI, xAI reports reasoning tokens < completion tokens. So we assume
# they are not counted in output tokens, and we add them here.
if (
generation_chunk
and (not self._use_responses_api({}))
and (usage_metadata := generation_chunk.message.usage_metadata) # type: ignore[attr-defined]
and (
reasoning_tokens := usage_metadata.get("output_token_details", {}).get(
"reasoning"
)
)
):
generation_chunk.message.usage_metadata["output_tokens"] += reasoning_tokens # type: ignore[attr-defined]
return generation_chunk return generation_chunk
def with_structured_output( def with_structured_output(

View File

@@ -37,6 +37,15 @@ def test_reasoning(output_version: Literal["", "v1"]) -> None:
assert response.content assert response.content
assert response.additional_kwargs["reasoning_content"] assert response.additional_kwargs["reasoning_content"]
## Check output tokens
usage_metadata = response.usage_metadata
assert usage_metadata
reasoning_tokens = usage_metadata.get("output_token_details", {}).get("reasoning")
total_tokens = usage_metadata.get("output_tokens")
assert total_tokens
assert reasoning_tokens
assert total_tokens > reasoning_tokens
# Test streaming # Test streaming
full: BaseMessageChunk | None = None full: BaseMessageChunk | None = None
for chunk in chat_model.stream(input_message): for chunk in chat_model.stream(input_message):
@@ -44,6 +53,15 @@ def test_reasoning(output_version: Literal["", "v1"]) -> None:
assert isinstance(full, AIMessageChunk) assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["reasoning_content"] assert full.additional_kwargs["reasoning_content"]
## Check output tokens
usage_metadata = full.usage_metadata
assert usage_metadata
reasoning_tokens = usage_metadata.get("output_token_details", {}).get("reasoning")
total_tokens = usage_metadata.get("output_tokens")
assert total_tokens
assert reasoning_tokens
assert total_tokens > reasoning_tokens
# Check that we can access reasoning content blocks # Check that we can access reasoning content blocks
assert response.content_blocks assert response.content_blocks
reasoning_content = ( reasoning_content = (