fix(xai): stream usage metadata by default (#34531)

This commit is contained in:
ccurme
2025-12-31 16:30:52 -05:00
committed by GitHub
parent 5517ef37fb
commit 0b91774263
5 changed files with 16 additions and 3 deletions

View File

@@ -526,6 +526,11 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
**client_params, **client_params,
**async_specific, **async_specific,
) )
# Enable streaming usage metadata by default
if self.stream_usage is not False:
self.stream_usage = True
return self return self
@model_validator(mode="after") @model_validator(mode="after")

View File

@@ -35,7 +35,6 @@ class TestXAIStandard(ChatModelIntegrationTests):
return { return {
"model": MODEL_NAME, "model": MODEL_NAME,
"rate_limiter": rate_limiter, "rate_limiter": rate_limiter,
"stream_usage": True,
} }
@pytest.mark.xfail( @pytest.mark.xfail(

View File

@@ -13,6 +13,7 @@
'request_timeout': 60.0, 'request_timeout': 60.0,
'stop': list([ 'stop': list([
]), ]),
'stream_usage': True,
'temperature': 0.0, 'temperature': 0.0,
'xai_api_base': 'https://api.x.ai/v1/', 'xai_api_base': 'https://api.x.ai/v1/',
'xai_api_key': dict({ 'xai_api_key': dict({

View File

@@ -134,3 +134,11 @@ def test_convert_dict_to_message_tool() -> None:
expected_output = ToolMessage(content="foo", tool_call_id="bar") expected_output = ToolMessage(content="foo", tool_call_id="bar")
assert result == expected_output assert result == expected_output
assert _convert_message_to_dict(expected_output) == message assert _convert_message_to_dict(expected_output) == message
def test_stream_usage_metadata() -> None:
model = ChatXAI(model=MODEL_NAME)
assert model.stream_usage is True
model = ChatXAI(model=MODEL_NAME, stream_usage=False)
assert model.stream_usage is False

View File

@@ -1,5 +1,5 @@
version = 1 version = 1
revision = 3 revision = 2
requires-python = ">=3.10.0, <4.0.0" requires-python = ">=3.10.0, <4.0.0"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -621,7 +621,7 @@ wheels = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "1.2.4" version = "1.2.5"
source = { editable = "../../core" } source = { editable = "../../core" }
dependencies = [ dependencies = [
{ name = "jsonpatch" }, { name = "jsonpatch" },