
mistral[patch]: support model_kwargs ()

- **Description:** `frequency_penalty` added as a client parameter, supported through the new generic `model_kwargs` field
- **Issue:** 

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
Mohammad Mohtashim 2025-02-21 04:47:39 +05:00 committed by GitHub
parent 924d9b1b33
commit 8293142fa0
3 changed files with 32 additions and 1 deletion
libs/partners/mistralai/langchain_mistralai
libs/partners/mistralai/tests/unit_tests
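
For orientation before the diffs, here is a minimal usage sketch of what the change enables. It is not part of the commit: the model name and the 0.5 value are placeholders, and `frequency_penalty` is simply an example of a provider parameter that has no dedicated field on `ChatMistralAI` in this revision.

```python
# Hedged usage sketch (not from the commit): provider parameters without a
# dedicated field can be passed via model_kwargs or as plain keyword arguments.
from langchain_mistralai import ChatMistralAI

# Explicitly through model_kwargs ...
llm = ChatMistralAI(
    model="mistral-small-latest",  # placeholder model name
    model_kwargs={"frequency_penalty": 0.5},
)

# ... or as an unrecognized keyword argument, which build_extra folds into model_kwargs.
llm_alt = ChatMistralAI(model="mistral-small-latest", frequency_penalty=0.5)  # type: ignore[call-arg]

assert llm.model_kwargs == llm_alt.model_kwargs == {"frequency_penalty": 0.5}
```

Either way the value reaches the request payload through `_default_params`, as the first diff below shows.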

View File

@@ -68,9 +68,10 @@ from langchain_core.output_parsers.openai_tools import (
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
-from langchain_core.utils import secret_from_env
+from langchain_core.utils import get_pydantic_field_names, secret_from_env
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.utils.utils import _build_model_kwargs
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -392,12 +393,22 @@ class ChatMistralAI(BaseChatModel):
     random_seed: Optional[int] = None
     safe_mode: Optional[bool] = None
     streaming: bool = False
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any invocation parameters not explicitly specified."""
 
     model_config = ConfigDict(
         populate_by_name=True,
         arbitrary_types_allowed=True,
     )
 
+    @model_validator(mode="before")
+    @classmethod
+    def build_extra(cls, values: Dict[str, Any]) -> Any:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = get_pydantic_field_names(cls)
+        values = _build_model_kwargs(values, all_required_field_names)
+        return values
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling the API."""
@@ -408,6 +419,7 @@ class ChatMistralAI(BaseChatModel):
             "top_p": self.top_p,
             "random_seed": self.random_seed,
             "safe_prompt": self.safe_mode,
+            **self.model_kwargs,
         }
         filtered = {k: v for k, v in defaults.items() if v is not None}
         return filtered
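
To make the new validator easier to follow, here is a rough approximation of what `_build_model_kwargs` does. This is an illustration written for this note, not the helper's actual source in `langchain_core` (which also handles warnings and field aliases); `_build_model_kwargs_sketch` is a made-up name.

```python
# Illustrative approximation only: unknown constructor kwargs are moved into
# model_kwargs, and a key supplied both ways raises, matching the unit test below.
from typing import Any, Dict, Set


def _build_model_kwargs_sketch(values: Dict[str, Any], field_names: Set[str]) -> Dict[str, Any]:
    extra = dict(values.get("model_kwargs", {}))
    for name in list(values):
        if name == "model_kwargs" or name in field_names:
            continue
        if name in extra:
            raise ValueError(f"Parameter {name} was supplied twice.")
        extra[name] = values.pop(name)
    return {**values, "model_kwargs": extra}
```

`_default_params` then splats `self.model_kwargs` into the request parameters, so anything captured this way is sent to the Mistral API alongside the explicitly modeled fields.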

View File

@@ -20,6 +20,10 @@
         'type': 'secret',
       }),
       'model': 'mistral-small',
+      'model_kwargs': dict({
+        'stop': list([
+        ]),
+      }),
       'temperature': 0.0,
       'timeout': 60,
       'top_p': 1,

View File

@@ -255,3 +255,18 @@ def test_tool_id_conversion() -> None:
     for input_id, expected_output in result_map.items():
         assert _convert_tool_call_id_to_mistral_compatible(input_id) == expected_output
         assert _is_valid_mistral_tool_call_id(expected_output)
+
+
+def test_extra_kwargs() -> None:
+    # Check that foo is saved in extra_kwargs.
+    llm = ChatMistralAI(model="my-model", foo=3, max_tokens=10)  # type: ignore[call-arg]
+    assert llm.max_tokens == 10
+    assert llm.model_kwargs == {"foo": 3}
+
+    # Test that if extra_kwargs are provided, they are added to it.
+    llm = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
+    assert llm.model_kwargs == {"foo": 3, "bar": 2}
+
+    # Test that if provided twice it errors
+    with pytest.raises(ValueError):
+        ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
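
A small companion check, not part of the commit, that follows from the validator above: declared fields such as `temperature` keep their dedicated attribute and are not swept into `model_kwargs`. The model name is a placeholder.

```python
# Hypothetical extra test in the same style as test_extra_kwargs above.
def test_declared_fields_not_moved() -> None:
    llm = ChatMistralAI(model="my-model", temperature=0.2)
    assert llm.temperature == 0.2
    assert llm.model_kwargs == {}
```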