diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py index 00a99c4e6a5..0fe7df227d3 100644 --- a/libs/partners/xai/langchain_xai/chat_models.py +++ b/libs/partners/xai/langchain_xai/chat_models.py @@ -14,6 +14,8 @@ from typing_extensions import Self from langchain_xai.data._profiles import _PROFILES if TYPE_CHECKING: + from collections.abc import AsyncIterator, Iterator + from langchain_core.language_models import ( ModelProfile, ModelProfileRegistry, @@ -475,7 +477,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] def _get_ls_params( self, stop: list[str] | None = None, - **kwargs: Any, # noqa: ANN401 + **kwargs: Any, ) -> LangSmithParams: """Get the parameters used to invoke the model.""" params = super()._get_ls_params(stop=stop, **kwargs) @@ -552,6 +554,23 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] return params + def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]: + """Route to Chat Completions or Responses API.""" + if self._use_responses_api({**kwargs, **self.model_kwargs}): + return super()._stream_responses(*args, **kwargs) + return super()._stream(*args, **kwargs) + + async def _astream( + self, *args: Any, **kwargs: Any + ) -> AsyncIterator[ChatGenerationChunk]: + """Route to Chat Completions or Responses API.""" + if self._use_responses_api({**kwargs, **self.model_kwargs}): + async for chunk in super()._astream_responses(*args, **kwargs): + yield chunk + else: + async for chunk in super()._astream(*args, **kwargs): + yield chunk + def _create_chat_result( self, response: dict | openai.BaseModel, @@ -648,7 +667,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] ] = "function_calling", include_raw: bool = False, strict: bool | None = None, - **kwargs: Any, # noqa: ANN401 + **kwargs: Any, ) -> Runnable[LanguageModelInput, _DictOrPydantic]: """Model wrapper that returns outputs formatted to match the given schema. diff --git a/libs/partners/xai/pyproject.toml b/libs/partners/xai/pyproject.toml index 6d8e6bcbf23..e17c3dc7675 100644 --- a/libs/partners/xai/pyproject.toml +++ b/libs/partners/xai/pyproject.toml @@ -65,6 +65,7 @@ docstring-code-line-length = 100 [tool.ruff.lint] select = ["ALL"] ignore = [ + "ANN401", # Allow annotating `Any` "COM812", # Messes with the formatter "ISC001", # Messes with the formatter "PERF203", # Rarely useful