mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 06:42:37 +00:00
feat(mistralai): support stop sequences (#38047)
`ChatMistralAI` now supports `stop` sequences.

Previously, a `stop` value passed to the model was silently discarded: the code carried a stale "not yet supported" note, dropped the parameter before the request, and logged a warning. Mistral's chat completions API does accept `stop` (a string or list of strings, up to 4 sequences), so anyone setting `stop` and expecting generation to halt was getting no effect.

Now `stop` is a first-class parameter. It can be set on the constructor (`ChatMistralAI(stop=[...])`) or per call (`model.invoke(prompt, stop=[...])`) and is forwarded to the API. A per-call value overrides the instance default, and an empty list is treated as "no stop sequences" — omitted from the request rather than sent as an empty array (which the API rejects).

Verified against the live Mistral API: with `stop=["5"]`, "Count from 1 to 10" returns `1 2 3 4 ` instead of the full sequence. The 422 `extra_forbidden` response the API returns for genuinely unknown fields confirms `stop` is a real schema field, not silently ignored.

This PR also folds in some test hygiene: the base-URL env test uses `monkeypatch.setenv` so `MISTRAL_BASE_URL=boo` no longer leaks into later serialization tests, and `test_extra_kwargs` asserts the intentional unknown-kwarg warning with `pytest.warns`.

## Review notes

- Behavior change worth a careful look: `stop` now reaches the API instead of being dropped. This changes request payloads for anyone previously passing `stop`. It is the intended fix, but flagging it explicitly.
- Coverage: `test_stop_sequence` (integration) exercises the end-to-end behavior; unit tests cover parameter wiring, per-call-vs-instance precedence, and the empty-list case.
This commit is contained in:
@@ -546,6 +546,14 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
max_tokens: int | None = None
|
||||
|
||||
stop: list[str] | None = None
|
||||
"""Default stop sequences.
|
||||
|
||||
Generation stops when any of these strings is produced; the stop sequence itself
|
||||
is not included in the output. Can be overridden per call via the `stop` argument.
|
||||
Mistral accepts up to 4 stop sequences.
|
||||
"""
|
||||
|
||||
top_p: float = 1
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least `top_p`. Must be in the closed interval
|
||||
@@ -599,7 +607,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
)
|
||||
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
|
||||
ls_params["ls_max_tokens"] = ls_max_tokens
|
||||
if ls_stop := stop or params.get("stop", None):
|
||||
if ls_stop := stop or self.stop or params.get("stop", None):
|
||||
ls_params["ls_stop"] = ls_stop
|
||||
return ls_params
|
||||
|
||||
@@ -749,12 +757,9 @@ class ChatMistralAI(BaseChatModel):
|
||||
self, messages: list[BaseMessage], stop: list[str] | None
|
||||
) -> tuple[list[dict], dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None or "stop" in params:
|
||||
if "stop" in params:
|
||||
params.pop("stop")
|
||||
logger.warning(
|
||||
"Parameter `stop` not yet supported (https://docs.mistral.ai/api)"
|
||||
)
|
||||
stop = stop if stop is not None else self.stop
|
||||
if stop:
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
|
||||
@@ -199,3 +199,24 @@ def test_reasoning_v1() -> None:
|
||||
|
||||
next_message = {"role": "user", "content": "What is my name?"}
|
||||
_ = model.invoke([input_message, full, next_message])
|
||||
|
||||
|
||||
def test_stop_sequence() -> None:
|
||||
"""Mistral honors `stop`: generation halts and the sequence is excluded."""
|
||||
model = ChatMistralAI(model="ministral-8b-latest", rate_limiter=rate_limiter) # type: ignore[call-arg]
|
||||
prompt = "Count from 1 to 10, separated by spaces. Reply with only the numbers."
|
||||
|
||||
# Without a stop sequence the full count is produced.
|
||||
baseline = model.invoke(prompt)
|
||||
assert isinstance(baseline.text, str)
|
||||
assert "5" in baseline.text
|
||||
|
||||
# With stop=["5"], generation halts before "5" and the sequence is excluded.
|
||||
stopped = model.invoke(prompt, stop=["5"])
|
||||
assert "5" not in stopped.text
|
||||
|
||||
# An instance-level `stop` is honored identically.
|
||||
stopped_instance = ChatMistralAI( # type: ignore[call-arg]
|
||||
model="ministral-8b-latest", stop=["5"], rate_limiter=rate_limiter
|
||||
).invoke(prompt)
|
||||
assert "5" not in stopped_instance.text
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
'ChatMistralAI',
|
||||
]),
|
||||
'kwargs': dict({
|
||||
'endpoint': 'boo',
|
||||
'endpoint': 'https://api.mistral.ai/v1',
|
||||
'max_concurrent_requests': 64,
|
||||
'max_retries': 2,
|
||||
'max_tokens': 100,
|
||||
@@ -20,10 +20,8 @@
|
||||
'type': 'secret',
|
||||
}),
|
||||
'model': 'mistral-small',
|
||||
'model_kwargs': dict({
|
||||
'stop': list([
|
||||
]),
|
||||
}),
|
||||
'stop': list([
|
||||
]),
|
||||
'temperature': 0.0,
|
||||
'timeout': 60,
|
||||
'top_p': 1,
|
||||
|
||||
@@ -90,12 +90,12 @@ def test_mistralai_initialization_baseurl(
|
||||
("MISTRAL_BASE_URL"),
|
||||
],
|
||||
)
|
||||
def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
|
||||
def test_mistralai_initialization_baseurl_env(
|
||||
env_var_name: str, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Test ChatMistralAI initialization."""
|
||||
# Verify that ChatMistralAI can be initialized using env variable
|
||||
import os
|
||||
|
||||
os.environ[env_var_name] = "boo"
|
||||
monkeypatch.setenv(env_var_name, "boo")
|
||||
model = ChatMistralAI(model="test") # type: ignore[call-arg]
|
||||
assert model.endpoint == "boo"
|
||||
|
||||
@@ -513,12 +513,14 @@ def test_tool_id_conversion() -> None:
|
||||
|
||||
def test_extra_kwargs() -> None:
|
||||
# Check that foo is saved in extra_kwargs.
|
||||
llm = ChatMistralAI(model="my-model", foo=3, max_tokens=10) # type: ignore[call-arg]
|
||||
with pytest.warns(UserWarning, match="foo is not default parameter"):
|
||||
llm = ChatMistralAI(model="my-model", foo=3, max_tokens=10) # type: ignore[call-arg]
|
||||
assert llm.max_tokens == 10
|
||||
assert llm.model_kwargs == {"foo": 3}
|
||||
|
||||
# Test that if extra_kwargs are provided, they are added to it.
|
||||
llm = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
|
||||
with pytest.warns(UserWarning, match="foo is not default parameter"):
|
||||
llm = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
|
||||
assert llm.model_kwargs == {"foo": 3, "bar": 2}
|
||||
|
||||
# Test that if provided twice it errors
|
||||
@@ -526,6 +528,52 @@ def test_extra_kwargs() -> None:
|
||||
ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_stop_stored_as_field() -> None:
|
||||
"""`stop` is a first-class field, not routed into `model_kwargs`."""
|
||||
llm = ChatMistralAI(model="my-model", stop=["END"]) # type: ignore[call-arg]
|
||||
assert llm.stop == ["END"]
|
||||
assert "stop" not in llm.model_kwargs
|
||||
|
||||
|
||||
def test_create_message_dicts_sends_instance_stop() -> None:
|
||||
"""Instance-level `stop` is forwarded to the request params."""
|
||||
llm = ChatMistralAI(model="my-model", stop=["END"]) # type: ignore[call-arg]
|
||||
_, params = llm._create_message_dicts([HumanMessage("hi")], None)
|
||||
assert params["stop"] == ["END"]
|
||||
|
||||
|
||||
def test_create_message_dicts_per_call_stop_overrides_instance() -> None:
|
||||
"""A per-call `stop` (including an empty list) overrides the instance value."""
|
||||
llm = ChatMistralAI(model="my-model", stop=["END"]) # type: ignore[call-arg]
|
||||
# A non-empty per-call value wins over the instance default.
|
||||
_, params = llm._create_message_dicts([HumanMessage("hi")], ["STOP"])
|
||||
assert params["stop"] == ["STOP"]
|
||||
|
||||
# An explicit empty list overrides the instance default and is treated as
|
||||
# "no stop sequences", so it is omitted from the request rather than sent
|
||||
# as an empty array (which the API would reject).
|
||||
_, params = llm._create_message_dicts([HumanMessage("hi")], [])
|
||||
assert "stop" not in params
|
||||
|
||||
|
||||
def test_create_message_dicts_omits_stop_when_unset() -> None:
|
||||
"""No `stop` field and no per-call value means `stop` is not sent."""
|
||||
llm = ChatMistralAI(model="my-model") # type: ignore[call-arg]
|
||||
_, params = llm._create_message_dicts([HumanMessage("hi")], None)
|
||||
assert "stop" not in params
|
||||
|
||||
|
||||
def test_get_ls_params_stop_precedence() -> None:
|
||||
"""`_get_ls_params` records instance `stop` and lets a per-call value win."""
|
||||
llm = ChatMistralAI(model="my-model", stop=["END"]) # type: ignore[call-arg]
|
||||
assert llm._get_ls_params().get("ls_stop") == ["END"]
|
||||
assert llm._get_ls_params(stop=["STOP"]).get("ls_stop") == ["STOP"]
|
||||
|
||||
# Without an instance default and no per-call value, `ls_stop` is omitted.
|
||||
llm_no_stop = ChatMistralAI(model="my-model") # type: ignore[call-arg]
|
||||
assert "ls_stop" not in llm_no_stop._get_ls_params()
|
||||
|
||||
|
||||
def test_retry_with_failure_then_success() -> None:
|
||||
"""Test retry mechanism works correctly when fiest request fails, second succeed."""
|
||||
# Create a real ChatMistralAI instance
|
||||
|
||||
Reference in New Issue
Block a user