mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(xai): fix routing of chat completions vs. responses apis during streaming (#34868)
This commit is contained in:
@@ -14,6 +14,8 @@ from typing_extensions import Self
|
|||||||
from langchain_xai.data._profiles import _PROFILES
|
from langchain_xai.data._profiles import _PROFILES
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
|
||||||
from langchain_core.language_models import (
|
from langchain_core.language_models import (
|
||||||
ModelProfile,
|
ModelProfile,
|
||||||
ModelProfileRegistry,
|
ModelProfileRegistry,
|
||||||
@@ -475,7 +477,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
|||||||
def _get_ls_params(
|
def _get_ls_params(
|
||||||
self,
|
self,
|
||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
**kwargs: Any, # noqa: ANN401
|
**kwargs: Any,
|
||||||
) -> LangSmithParams:
|
) -> LangSmithParams:
|
||||||
"""Get the parameters used to invoke the model."""
|
"""Get the parameters used to invoke the model."""
|
||||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||||
@@ -552,6 +554,23 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
|||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
|
||||||
|
"""Route to Chat Completions or Responses API."""
|
||||||
|
if self._use_responses_api({**kwargs, **self.model_kwargs}):
|
||||||
|
return super()._stream_responses(*args, **kwargs)
|
||||||
|
return super()._stream(*args, **kwargs)
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self, *args: Any, **kwargs: Any
|
||||||
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
|
"""Route to Chat Completions or Responses API."""
|
||||||
|
if self._use_responses_api({**kwargs, **self.model_kwargs}):
|
||||||
|
async for chunk in super()._astream_responses(*args, **kwargs):
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
async for chunk in super()._astream(*args, **kwargs):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
def _create_chat_result(
|
def _create_chat_result(
|
||||||
self,
|
self,
|
||||||
response: dict | openai.BaseModel,
|
response: dict | openai.BaseModel,
|
||||||
@@ -648,7 +667,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
|||||||
] = "function_calling",
|
] = "function_calling",
|
||||||
include_raw: bool = False,
|
include_raw: bool = False,
|
||||||
strict: bool | None = None,
|
strict: bool | None = None,
|
||||||
**kwargs: Any, # noqa: ANN401
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||||
|
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ docstring-code-line-length = 100
|
|||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["ALL"]
|
select = ["ALL"]
|
||||||
ignore = [
|
ignore = [
|
||||||
|
"ANN401", # Allow annotating `Any`
|
||||||
"COM812", # Messes with the formatter
|
"COM812", # Messes with the formatter
|
||||||
"ISC001", # Messes with the formatter
|
"ISC001", # Messes with the formatter
|
||||||
"PERF203", # Rarely useful
|
"PERF203", # Rarely useful
|
||||||
|
|||||||
Reference in New Issue
Block a user