feat(mistralai): support stop sequences (#38047)

`ChatMistralAI` now supports `stop` sequences.

Previously, a `stop` value passed to the model was silently discarded:
the code carried a stale "not yet supported" note, dropped the parameter
before the request, and logged a warning. Mistral's chat completions API
does accept `stop` (a string or list of strings, up to 4 sequences), so
anyone setting `stop` and expecting generation to halt was getting no
effect.

Now `stop` is a first-class parameter. It can be set on the constructor
(`ChatMistralAI(stop=[...])`) or per call (`model.invoke(prompt,
stop=[...])`) and is forwarded to the API. A per-call value overrides
the instance default, and an empty list is treated as "no stop
sequences" — omitted from the request rather than sent as an empty array
(which the API rejects).

Verified against the live Mistral API: with `stop=["5"]`, "Count from 1
to 10" returns `1 2 3 4 ` instead of the full sequence. Because the API
rejects genuinely unknown fields with a 422 `extra_forbidden` response,
the fact that a request carrying `stop` is accepted confirms that `stop`
is a real schema field rather than one that is silently ignored.

This PR also folds in some test hygiene: the base-URL env test uses
`monkeypatch.setenv` so `MISTRAL_BASE_URL=boo` no longer leaks into
later serialization tests, and `test_extra_kwargs` asserts the
intentional unknown-kwarg warning with `pytest.warns`.

## Review notes
- Behavior change worth a careful look: `stop` now reaches the API
instead of being dropped. This changes request payloads for anyone
previously passing `stop`. It is the intended fix, but flagging it
explicitly.
- Coverage: `test_stop_sequence` (integration) exercises the end-to-end
behavior; unit tests cover parameter wiring, per-call-vs-instance
precedence, and the empty-list case.
This commit is contained in:
Mason Daugherty
2026-06-10 20:42:16 -04:00
committed by GitHub
parent 21eeadf274
commit fcaa61636e
4 changed files with 90 additions and 18 deletions

View File

@@ -546,6 +546,14 @@ class ChatMistralAI(BaseChatModel):
max_tokens: int | None = None
stop: list[str] | None = None
"""Default stop sequences.
Generation stops when any of these strings is produced; the stop sequence itself
is not included in the output. Can be overridden per call via the `stop` argument.
Mistral accepts up to 4 stop sequences.
"""
top_p: float = 1
"""Decode using nucleus sampling: consider the smallest set of tokens whose
probability sum is at least `top_p`. Must be in the closed interval
@@ -599,7 +607,7 @@ class ChatMistralAI(BaseChatModel):
)
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None):
if ls_stop := stop or self.stop or params.get("stop", None):
ls_params["ls_stop"] = ls_stop
return ls_params
@@ -749,12 +757,9 @@ class ChatMistralAI(BaseChatModel):
self, messages: list[BaseMessage], stop: list[str] | None
) -> tuple[list[dict], dict[str, Any]]:
params = self._client_params
if stop is not None or "stop" in params:
if "stop" in params:
params.pop("stop")
logger.warning(
"Parameter `stop` not yet supported (https://docs.mistral.ai/api)"
)
stop = stop if stop is not None else self.stop
if stop:
params["stop"] = stop
message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages]
return message_dicts, params

View File

@@ -199,3 +199,24 @@ def test_reasoning_v1() -> None:
next_message = {"role": "user", "content": "What is my name?"}
_ = model.invoke([input_message, full, next_message])
def test_stop_sequence() -> None:
    """Check that `stop` truncates the reply, both per call and per instance.

    Three invocations use the same counting prompt: one without `stop` (the
    reply must contain "5"), one with a per-call `stop=["5"]`, and one on a
    model constructed with `stop=["5"]`. In the latter two the reply must not
    contain "5".
    """
    model_name = "ministral-8b-latest"
    question = "Count from 1 to 10, separated by spaces. Reply with only the numbers."

    plain_llm = ChatMistralAI(model=model_name, rate_limiter=rate_limiter)  # type: ignore[call-arg]

    # Baseline: no stop sequence, so the reply is expected to reach "5".
    full_reply = plain_llm.invoke(question)
    assert isinstance(full_reply.text, str)
    assert "5" in full_reply.text

    # Per-call stop: the reply must not contain the stop sequence.
    truncated_reply = plain_llm.invoke(question, stop=["5"])
    assert "5" not in truncated_reply.text

    # Instance-level stop: same expectation with `stop` set at construction.
    preset_llm = ChatMistralAI(  # type: ignore[call-arg]
        model=model_name, stop=["5"], rate_limiter=rate_limiter
    )
    preset_reply = preset_llm.invoke(question)
    assert "5" not in preset_reply.text

View File

@@ -8,7 +8,7 @@
'ChatMistralAI',
]),
'kwargs': dict({
'endpoint': 'boo',
'endpoint': 'https://api.mistral.ai/v1',
'max_concurrent_requests': 64,
'max_retries': 2,
'max_tokens': 100,
@@ -20,10 +20,8 @@
'type': 'secret',
}),
'model': 'mistral-small',
'model_kwargs': dict({
'stop': list([
]),
}),
'stop': list([
]),
'temperature': 0.0,
'timeout': 60,
'top_p': 1,

View File

@@ -90,12 +90,12 @@ def test_mistralai_initialization_baseurl(
("MISTRAL_BASE_URL"),
],
)
def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
def test_mistralai_initialization_baseurl_env(
env_var_name: str, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized using env variable
import os
os.environ[env_var_name] = "boo"
monkeypatch.setenv(env_var_name, "boo")
model = ChatMistralAI(model="test") # type: ignore[call-arg]
assert model.endpoint == "boo"
@@ -513,12 +513,14 @@ def test_tool_id_conversion() -> None:
def test_extra_kwargs() -> None:
# Check that foo is saved in extra_kwargs.
llm = ChatMistralAI(model="my-model", foo=3, max_tokens=10) # type: ignore[call-arg]
with pytest.warns(UserWarning, match="foo is not default parameter"):
llm = ChatMistralAI(model="my-model", foo=3, max_tokens=10) # type: ignore[call-arg]
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}
# Test that if extra_kwargs are provided, they are added to it.
llm = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
with pytest.warns(UserWarning, match="foo is not default parameter"):
llm = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
assert llm.model_kwargs == {"foo": 3, "bar": 2}
# Test that if provided twice it errors
@@ -526,6 +528,52 @@ def test_extra_kwargs() -> None:
ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
def test_stop_stored_as_field() -> None:
    """A constructor `stop` is kept on the `stop` attribute, not in `model_kwargs`."""
    chat = ChatMistralAI(  # type: ignore[call-arg]
        model="my-model",
        stop=["END"],
    )
    stored_stop = chat.stop
    assert stored_stop == ["END"]
    # The value must not have been diverted into the catch-all kwargs dict.
    extra_kwargs = chat.model_kwargs
    assert "stop" not in extra_kwargs
def test_create_message_dicts_sends_instance_stop() -> None:
    """With no per-call `stop`, the instance-level value appears in the params."""
    chat = ChatMistralAI(  # type: ignore[call-arg]
        model="my-model",
        stop=["END"],
    )
    _message_dicts, request_params = chat._create_message_dicts(
        messages=[HumanMessage("hi")], stop=None
    )
    assert request_params["stop"] == ["END"]
def test_create_message_dicts_per_call_stop_overrides_instance() -> None:
    """A per-call `stop` takes priority over the instance value, even when empty."""
    chat = ChatMistralAI(  # type: ignore[call-arg]
        model="my-model",
        stop=["END"],
    )

    # Case 1: a non-empty per-call list replaces the instance default.
    _message_dicts, overridden_params = chat._create_message_dicts(
        messages=[HumanMessage("hi")], stop=["STOP"]
    )
    assert overridden_params["stop"] == ["STOP"]

    # Case 2: an explicit empty list also replaces the instance default. It
    # means "no stop sequences", so the key is left out of the request instead
    # of being sent as an empty array (which the API would reject).
    _message_dicts, cleared_params = chat._create_message_dicts(
        messages=[HumanMessage("hi")], stop=[]
    )
    assert "stop" not in cleared_params
def test_create_message_dicts_omits_stop_when_unset() -> None:
    """Without an instance `stop` or a per-call value, the params carry no `stop`."""
    chat = ChatMistralAI(model="my-model")  # type: ignore[call-arg]
    _message_dicts, request_params = chat._create_message_dicts(
        messages=[HumanMessage("hi")], stop=None
    )
    assert "stop" not in request_params
def test_get_ls_params_stop_precedence() -> None:
    """`_get_ls_params` reports the instance `stop`; a per-call value replaces it."""
    with_default = ChatMistralAI(  # type: ignore[call-arg]
        model="my-model",
        stop=["END"],
    )

    # No per-call value: the instance default is what gets recorded.
    default_params = with_default._get_ls_params()
    assert default_params.get("ls_stop") == ["END"]

    # A per-call value wins over the instance default.
    override_params = with_default._get_ls_params(stop=["STOP"])
    assert override_params.get("ls_stop") == ["STOP"]

    # Neither an instance default nor a per-call value: `ls_stop` is absent.
    without_default = ChatMistralAI(model="my-model")  # type: ignore[call-arg]
    assert "ls_stop" not in without_default._get_ls_params()
def test_retry_with_failure_then_success() -> None:
"""Test retry mechanism works correctly when fiest request fails, second succeed."""
# Create a real ChatMistralAI instance