fix(xai): fix routing of chat completions vs. responses apis during streaming (#34868)

This commit is contained in:
Sholto Armstrong
2026-01-26 05:58:11 +02:00
committed by GitHub
parent f0ca2c4675
commit 666bb6fe53
2 changed files with 22 additions and 2 deletions

View File

@@ -14,6 +14,8 @@ from typing_extensions import Self
from langchain_xai.data._profiles import _PROFILES
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator
from langchain_core.language_models import (
ModelProfile,
ModelProfileRegistry,
@@ -475,7 +477,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
def _get_ls_params(
self,
stop: list[str] | None = None,
**kwargs: Any, # noqa: ANN401
**kwargs: Any,
) -> LangSmithParams:
"""Get the parameters used to invoke the model."""
params = super()._get_ls_params(stop=stop, **kwargs)
@@ -552,6 +554,23 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
return params
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
"""Route to Chat Completions or Responses API."""
if self._use_responses_api({**kwargs, **self.model_kwargs}):
return super()._stream_responses(*args, **kwargs)
return super()._stream(*args, **kwargs)
async def _astream(
self, *args: Any, **kwargs: Any
) -> AsyncIterator[ChatGenerationChunk]:
"""Route to Chat Completions or Responses API."""
if self._use_responses_api({**kwargs, **self.model_kwargs}):
async for chunk in super()._astream_responses(*args, **kwargs):
yield chunk
else:
async for chunk in super()._astream(*args, **kwargs):
yield chunk
def _create_chat_result(
self,
response: dict | openai.BaseModel,
@@ -648,7 +667,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
] = "function_calling",
include_raw: bool = False,
strict: bool | None = None,
**kwargs: Any, # noqa: ANN401
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema.

View File

@@ -65,6 +65,7 @@ docstring-code-line-length = 100
[tool.ruff.lint]
select = ["ALL"]
ignore = [
"ANN401", # Allow annotating `Any`
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful