diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py
index aa78391768f..7192dad7b2d 100644
--- a/libs/partners/xai/langchain_xai/chat_models.py
+++ b/libs/partners/xai/langchain_xai/chat_models.py
@@ -41,6 +41,23 @@ def _get_default_model_profile(model_name: str) -> ModelProfile:
     return default.copy()
 
 
+def _model_rejects_stop(model_name: str) -> bool:
+    """Whether an *unprofiled* xAI model rejects the `stop` parameter.
+
+    Used only as a fallback when no model profile is available; profiled models
+    defer to their `reasoning_output` flag, which is authoritative. The flag and
+    the name cannot be reconciled by string matching alone: the API accepts
+    `stop` for `grok-4.20-0309-non-reasoning` but rejects it for
+    `grok-4-fast-non-reasoning`, despite both containing `non-reasoning`.
+
+    Every current `grok-3`, `grok-4`, and `grok-code-fast` model that is not in
+    the generated profiles rejects `stop`, so the fallback drops it for those
+    families. Dropping a `stop` the API would reject degrades more gracefully
+    than letting the request fail outright.
+    """
+    return model_name.startswith(("grok-3", "grok-4")) or "grok-code-fast" in model_name
+
+
 class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     r"""ChatXAI chat model.
 
@@ -426,6 +443,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
 
     """
 
     openai_api_key: SecretStr | None = None
+    openai_api_base: str | None = None
 
     model_config = ConfigDict(
@@ -546,6 +564,35 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     def _resolve_model_profile(self) -> ModelProfile | None:
         return _get_default_model_profile(self.model_name) or None
 
+    def _get_request_payload(
+        self,
+        input_: LanguageModelInput,
+        *,
+        stop: list[str] | None = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Build the request payload, dropping `stop` for models that reject it.
+
+        The resolved model profile's `reasoning_output` flag decides when a
+        profile exists; otherwise `_model_rejects_stop` decides from the name.
+        """
+        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
+        if payload.get("stop") is not None:
+            # xAI rejects `stop` for reasoning models. The model profile is
+            # authoritative when available (it correctly distinguishes models
+            # that only differ from each other by profile, not by name); the
+            # name-based check is a fallback for aliases absent from the
+            # generated profiles.
+            model_profile = self._resolve_model_profile()
+            rejects_stop = (
+                bool(model_profile.get("reasoning_output"))
+                if model_profile is not None
+                else _model_rejects_stop(self.model_name)
+            )
+            if rejects_stop:
+                payload.pop("stop", None)
+        return payload
+
     def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
         """Route to Chat Completions or Responses API."""
         if self._use_responses_api({**kwargs, **self.model_kwargs}):
diff --git a/libs/partners/xai/tests/unit_tests/test_chat_models.py b/libs/partners/xai/tests/unit_tests/test_chat_models.py
index 0ecc4a88bbe..66c949c3f32 100644
--- a/libs/partners/xai/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/xai/tests/unit_tests/test_chat_models.py
@@ -84,6 +84,52 @@ def test_chat_xai_api_base_from_env(monkeypatch: pytest.MonkeyPatch) -> None:
     assert llm.xai_api_base == "http://env.example.test/v1"
 
 
+@pytest.mark.parametrize(
+    "model",
+    [
+        # Profiled reasoning models (`reasoning_output=True`).
+        "grok-4.3",
+        "grok-4.20-0309-reasoning",
+        # Unprofiled families that the live API rejects `stop` on. Neither the
+        # bare `grok-4` name nor `grok-4-fast-non-reasoning` signals a reasoning
+        # model, yet both reject `stop`; `grok-code-fast` is a separate family.
+        "grok-3",
+        "grok-3-mini",
+        "grok-4",
+        "grok-4-0709",
+        "grok-4-fast-reasoning",
+        "grok-4-fast-non-reasoning",
+        "grok-code-fast-1",
+    ],
+)
+def test_reasoning_model_payload_drops_stop(model: str) -> None:
+    llm = ChatXAI(
+        model=model,
+        api_key=SecretStr("test-api-key"),
+        stop_sequences=["END"],
+    )
+
+    payload = llm._get_request_payload("hello")
+
+    assert "stop" not in payload
+
+
+def test_non_reasoning_model_payload_keeps_stop() -> None:
+    # `grok-4.20-0309-non-reasoning` is profiled with `reasoning_output=False`
+    # and the live API accepts `stop` for it, even though its name contains
+    # "non-reasoning" like the unprofiled `grok-4-fast-non-reasoning` that does
+    # not. The profile must take precedence over the name-based fallback.
+    llm = ChatXAI(
+        model="grok-4.20-0309-non-reasoning",
+        api_key=SecretStr("test-api-key"),
+        stop_sequences=["END"],
+    )
+
+    payload = llm._get_request_payload("hello")
+
+    assert payload["stop"] == ["END"]
+
+
 def test_function_dict_to_message_function_message() -> None:
     content = json.dumps({"result": "Example #1"})
     name = "test_function"