diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index e920fc6784c..615616d3c9b 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -546,6 +546,14 @@ class ChatMistralAI(BaseChatModel): max_tokens: int | None = None + stop: list[str] | None = None + """Default stop sequences. + + Generation stops when any of these strings is produced; the stop sequence itself + is not included in the output. Can be overridden per call via the `stop` argument. + Mistral accepts up to 4 stop sequences. + """ + top_p: float = 1 """Decode using nucleus sampling: consider the smallest set of tokens whose probability sum is at least `top_p`. Must be in the closed interval @@ -599,7 +607,7 @@ class ChatMistralAI(BaseChatModel): ) if ls_max_tokens := params.get("max_tokens", self.max_tokens): ls_params["ls_max_tokens"] = ls_max_tokens - if ls_stop := stop or params.get("stop", None): + if ls_stop := stop or self.stop or params.get("stop", None): ls_params["ls_stop"] = ls_stop return ls_params @@ -749,12 +757,9 @@ class ChatMistralAI(BaseChatModel): self, messages: list[BaseMessage], stop: list[str] | None ) -> tuple[list[dict], dict[str, Any]]: params = self._client_params - if stop is not None or "stop" in params: - if "stop" in params: - params.pop("stop") - logger.warning( - "Parameter `stop` not yet supported (https://docs.mistral.ai/api)" - ) + stop = stop if stop is not None else self.stop + if stop: + params["stop"] = stop message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages] return message_dicts, params diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index a1d42a47ca2..4ca23b5aeb4 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -199,3 +199,24 @@ def test_reasoning_v1() -> None: next_message = {"role": "user", "content": "What is my name?"} _ = model.invoke([input_message, full, next_message]) + + +def test_stop_sequence() -> None: + """Mistral honors `stop`: generation halts and the sequence is excluded.""" + model = ChatMistralAI(model="ministral-8b-latest", rate_limiter=rate_limiter) # type: ignore[call-arg] + prompt = "Count from 1 to 10, separated by spaces. Reply with only the numbers." + + # Without a stop sequence the full count is produced. + baseline = model.invoke(prompt) + assert isinstance(baseline.text, str) + assert "5" in baseline.text + + # With stop=["5"], generation halts before "5" and the sequence is excluded. + stopped = model.invoke(prompt, stop=["5"]) + assert "5" not in stopped.text + + # An instance-level `stop` is honored identically. + stopped_instance = ChatMistralAI( # type: ignore[call-arg] + model="ministral-8b-latest", stop=["5"], rate_limiter=rate_limiter + ).invoke(prompt) + assert "5" not in stopped_instance.text diff --git a/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr index 66b802e97e9..d45470f440e 100644 --- a/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr +++ b/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr @@ -8,7 +8,7 @@ 'ChatMistralAI', ]), 'kwargs': dict({ - 'endpoint': 'boo', + 'endpoint': 'https://api.mistral.ai/v1', 'max_concurrent_requests': 64, 'max_retries': 2, 'max_tokens': 100, @@ -20,10 +20,8 @@ 'type': 'secret', }), 'model': 'mistral-small', - 'model_kwargs': dict({ - 'stop': list([ - ]), - }), + 'stop': list([ + ]), 'temperature': 0.0, 'timeout': 60, 'top_p': 1, diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 38de8cdb074..2ebc4f6b2ba 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -90,12 +90,12 @@ def test_mistralai_initialization_baseurl( ("MISTRAL_BASE_URL"), ], ) -def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None: +def test_mistralai_initialization_baseurl_env( + env_var_name: str, monkeypatch: pytest.MonkeyPatch +) -> None: """Test ChatMistralAI initialization.""" # Verify that ChatMistralAI can be initialized using env variable - import os - - os.environ[env_var_name] = "boo" + monkeypatch.setenv(env_var_name, "boo") model = ChatMistralAI(model="test") # type: ignore[call-arg] assert model.endpoint == "boo" @@ -513,12 +513,14 @@ def test_tool_id_conversion() -> None: def test_extra_kwargs() -> None: # Check that foo is saved in extra_kwargs. - llm = ChatMistralAI(model="my-model", foo=3, max_tokens=10) # type: ignore[call-arg] + with pytest.warns(UserWarning, match="foo is not default parameter"): + llm = ChatMistralAI(model="my-model", foo=3, max_tokens=10) # type: ignore[call-arg] assert llm.max_tokens == 10 assert llm.model_kwargs == {"foo": 3} # Test that if extra_kwargs are provided, they are added to it. - llm = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg] + with pytest.warns(UserWarning, match="foo is not default parameter"): + llm = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg] assert llm.model_kwargs == {"foo": 3, "bar": 2} # Test that if provided twice it errors @@ -526,6 +528,52 @@ def test_extra_kwargs() -> None: ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg] +def test_stop_stored_as_field() -> None: + """`stop` is a first-class field, not routed into `model_kwargs`.""" + llm = ChatMistralAI(model="my-model", stop=["END"]) # type: ignore[call-arg] + assert llm.stop == ["END"] + assert "stop" not in llm.model_kwargs + + +def test_create_message_dicts_sends_instance_stop() -> None: + """Instance-level `stop` is forwarded to the request params.""" + llm = ChatMistralAI(model="my-model", stop=["END"]) # type: ignore[call-arg] + _, params = llm._create_message_dicts([HumanMessage("hi")], None) + assert params["stop"] == ["END"] + + +def test_create_message_dicts_per_call_stop_overrides_instance() -> None: + """A per-call `stop` (including an empty list) overrides the instance value.""" + llm = ChatMistralAI(model="my-model", stop=["END"]) # type: ignore[call-arg] + # A non-empty per-call value wins over the instance default. + _, params = llm._create_message_dicts([HumanMessage("hi")], ["STOP"]) + assert params["stop"] == ["STOP"] + + # An explicit empty list overrides the instance default and is treated as + # "no stop sequences", so it is omitted from the request rather than sent + # as an empty array (which the API would reject). + _, params = llm._create_message_dicts([HumanMessage("hi")], []) + assert "stop" not in params + + +def test_create_message_dicts_omits_stop_when_unset() -> None: + """No `stop` field and no per-call value means `stop` is not sent.""" + llm = ChatMistralAI(model="my-model") # type: ignore[call-arg] + _, params = llm._create_message_dicts([HumanMessage("hi")], None) + assert "stop" not in params + + +def test_get_ls_params_stop_precedence() -> None: + """`_get_ls_params` records instance `stop` and lets a per-call value win.""" + llm = ChatMistralAI(model="my-model", stop=["END"]) # type: ignore[call-arg] + assert llm._get_ls_params().get("ls_stop") == ["END"] + assert llm._get_ls_params(stop=["STOP"]).get("ls_stop") == ["STOP"] + + # Without an instance default and no per-call value, `ls_stop` is omitted. + llm_no_stop = ChatMistralAI(model="my-model") # type: ignore[call-arg] + assert "ls_stop" not in llm_no_stop._get_ls_params() + + def test_retry_with_failure_then_success() -> None: """Test retry mechanism works correctly when fiest request fails, second succeed.""" # Create a real ChatMistralAI instance