mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(xai): fix routing of chat completions vs. responses apis during streaming (#34868)
This commit is contained in:
@@ -14,6 +14,8 @@ from typing_extensions import Self
|
||||
from langchain_xai.data._profiles import _PROFILES
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
|
||||
from langchain_core.language_models import (
|
||||
ModelProfile,
|
||||
ModelProfileRegistry,
|
||||
@@ -475,7 +477,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
def _get_ls_params(
|
||||
self,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
**kwargs: Any,
|
||||
) -> LangSmithParams:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||
@@ -552,6 +554,23 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
|
||||
return params
|
||||
|
||||
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
    """Stream chunks through the Responses API or the Chat Completions API.

    The decision is delegated to `self._use_responses_api`, which is handed
    the per-call keyword arguments merged with `self.model_kwargs`.

    Args:
        *args: Positional arguments passed through unchanged to the parent
            streaming implementation that gets selected.
        **kwargs: Keyword arguments passed through unchanged; they also feed
            the routing decision.

    Returns:
        The iterator produced by the parent's `_stream_responses` when the
        Responses API is selected, otherwise by the parent's `_stream`.
    """
    # On key collisions, entries from `model_kwargs` win over per-call kwargs
    # for the purposes of the routing check only; the parent call still
    # receives the original `kwargs`.
    routing_payload = {**kwargs, **self.model_kwargs}
    if not self._use_responses_api(routing_payload):
        return super()._stream(*args, **kwargs)
    return super()._stream_responses(*args, **kwargs)
|
||||
|
||||
async def _astream(
    self, *args: Any, **kwargs: Any
) -> AsyncIterator[ChatGenerationChunk]:
    """Asynchronously stream chunks through the Responses or Chat Completions API.

    The decision is delegated to `self._use_responses_api`, which is handed
    the per-call keyword arguments merged with `self.model_kwargs`.

    Args:
        *args: Positional arguments passed through unchanged to the parent
            async streaming implementation that gets selected.
        **kwargs: Keyword arguments passed through unchanged; they also feed
            the routing decision.

    Yields:
        Each chunk produced by the parent's `_astream_responses` when the
        Responses API is selected, otherwise by the parent's `_astream`.
    """
    # On key collisions, entries from `model_kwargs` win over per-call kwargs
    # for the purposes of the routing check only; the parent call still
    # receives the original `kwargs`.
    routing_payload = {**kwargs, **self.model_kwargs}
    if self._use_responses_api(routing_payload):
        chunk_source = super()._astream_responses(*args, **kwargs)
    else:
        chunk_source = super()._astream(*args, **kwargs)

    # This method is itself an async generator, so chunks are re-yielded one
    # at a time rather than returning the parent's iterator.
    async for chunk in chunk_source:
        yield chunk
|
||||
|
||||
def _create_chat_result(
|
||||
self,
|
||||
response: dict | openai.BaseModel,
|
||||
@@ -648,7 +667,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
strict: bool | None = None,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ docstring-code-line-length = 100
|
||||
[tool.ruff.lint]
|
||||
select = ["ALL"]
|
||||
ignore = [
|
||||
"ANN401", # Allow annotating `Any`
|
||||
"COM812", # Messes with the formatter
|
||||
"ISC001", # Messes with the formatter
|
||||
"PERF203", # Rarely useful
|
||||
|
||||
Reference in New Issue
Block a user