diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 4e304e29e0d..4ab31613ed8 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -68,9 +68,10 @@ from langchain_core.output_parsers.openai_tools import ( from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_core.utils import secret_from_env +from langchain_core.utils import get_pydantic_field_names, secret_from_env from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_core.utils.utils import _build_model_kwargs from pydantic import ( BaseModel, ConfigDict, @@ -392,12 +393,22 @@ class ChatMistralAI(BaseChatModel): random_seed: Optional[int] = None safe_mode: Optional[bool] = None streaming: bool = False + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any invocation parameters not explicitly specified.""" model_config = ConfigDict( populate_by_name=True, arbitrary_types_allowed=True, ) + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: Dict[str, Any]) -> Any: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + values = _build_model_kwargs(values, all_required_field_names) + return values + @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling the API.""" @@ -408,6 +419,7 @@ class ChatMistralAI(BaseChatModel): "top_p": self.top_p, "random_seed": self.random_seed, "safe_prompt": self.safe_mode, + **self.model_kwargs, } filtered = {k: v for k, v in defaults.items() if v is not None} return filtered diff --git a/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr index f7986097c47..66b802e97e9 100644 --- a/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr +++ b/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr @@ -20,6 +20,10 @@ 'type': 'secret', }), 'model': 'mistral-small', + 'model_kwargs': dict({ + 'stop': list([ + ]), + }), 'temperature': 0.0, 'timeout': 60, 'top_p': 1, diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 27bd73f5deb..4dc251832e7 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -255,3 +255,18 @@ def test_tool_id_conversion() -> None: for input_id, expected_output in result_map.items(): assert _convert_tool_call_id_to_mistral_compatible(input_id) == expected_output assert _is_valid_mistral_tool_call_id(expected_output) + + +def test_extra_kwargs() -> None: + # Check that foo is saved in extra_kwargs. + llm = ChatMistralAI(model="my-model", foo=3, max_tokens=10) # type: ignore[call-arg] + assert llm.max_tokens == 10 + assert llm.model_kwargs == {"foo": 3} + + # Test that if extra_kwargs are provided, they are added to it. + llm = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg] + assert llm.model_kwargs == {"foo": 3, "bar": 2} + + # Test that if provided twice it errors + with pytest.raises(ValueError): + ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]