mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-08 00:28:47 +00:00
mistral[patch]: support model_kwargs (#29838)
- **Description:** Frequency_penalty added as a client parameter - **Issue:** #29803 --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
924d9b1b33
commit
8293142fa0
libs/partners/mistralai
@ -68,9 +68,10 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import secret_from_env
|
||||
from langchain_core.utils import get_pydantic_field_names, secret_from_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_core.utils.utils import _build_model_kwargs
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@ -392,12 +393,22 @@ class ChatMistralAI(BaseChatModel):
|
||||
random_seed: Optional[int] = None
|
||||
safe_mode: Optional[bool] = None
|
||||
streaming: bool = False
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any invocation parameters not explicitly specified."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling the API."""
|
||||
@ -408,6 +419,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
"top_p": self.top_p,
|
||||
"random_seed": self.random_seed,
|
||||
"safe_prompt": self.safe_mode,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
filtered = {k: v for k, v in defaults.items() if v is not None}
|
||||
return filtered
|
||||
|
@ -20,6 +20,10 @@
|
||||
'type': 'secret',
|
||||
}),
|
||||
'model': 'mistral-small',
|
||||
'model_kwargs': dict({
|
||||
'stop': list([
|
||||
]),
|
||||
}),
|
||||
'temperature': 0.0,
|
||||
'timeout': 60,
|
||||
'top_p': 1,
|
||||
|
@ -255,3 +255,18 @@ def test_tool_id_conversion() -> None:
|
||||
for input_id, expected_output in result_map.items():
|
||||
assert _convert_tool_call_id_to_mistral_compatible(input_id) == expected_output
|
||||
assert _is_valid_mistral_tool_call_id(expected_output)
|
||||
|
||||
|
||||
def test_extra_kwargs() -> None:
|
||||
# Check that foo is saved in extra_kwargs.
|
||||
llm = ChatMistralAI(model="my-model", foo=3, max_tokens=10) # type: ignore[call-arg]
|
||||
assert llm.max_tokens == 10
|
||||
assert llm.model_kwargs == {"foo": 3}
|
||||
|
||||
# Test that if extra_kwargs are provided, they are added to it.
|
||||
llm = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
|
||||
assert llm.model_kwargs == {"foo": 3, "bar": 2}
|
||||
|
||||
# Test that if provided twice it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
|
||||
|
Loading…
Reference in New Issue
Block a user