mistralai[patch]: standardize model params (#20163)

Related to #20085
This commit is contained in:
Bagatur 2024-04-08 11:48:38 -05:00 committed by GitHub
parent 17182406f3
commit 3490d70238
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 32 additions and 12 deletions

View File

@ -48,7 +48,7 @@
"source": [ "source": [
"import getpass\n", "import getpass\n",
"\n", "\n",
"mistral_api_key = getpass.getpass()" "api_key = getpass.getpass()"
] ]
}, },
{ {
@ -81,8 +81,8 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# If mistral_api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n", "# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n",
"chat = ChatMistralAI(mistral_api_key=mistral_api_key)" "chat = ChatMistralAI(api_key=api_key)"
] ]
}, },
{ {

View File

@ -45,7 +45,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"embedding = MistralAIEmbeddings(mistral_api_key=\"your-api-key\")" "embedding = MistralAIEmbeddings(api_key=\"your-api-key\")"
] ]
}, },
{ {

View File

@ -54,8 +54,8 @@
"<ChatModelTabs\n", "<ChatModelTabs\n",
" openaiParams={`model=\"gpt-3.5-turbo-0125\", api_key=\"...\"`}\n", " openaiParams={`model=\"gpt-3.5-turbo-0125\", api_key=\"...\"`}\n",
" anthropicParams={`model=\"claude-3-sonnet-20240229\", anthropic_api_key=\"...\"`}\n", " anthropicParams={`model=\"claude-3-sonnet-20240229\", anthropic_api_key=\"...\"`}\n",
" mistralParams={`model=\"mistral-large-latest\", api_key=\"...\"`}\n",
" fireworksParams={`model=\"accounts/fireworks/models/mixtral-8x7b-instruct\", api_key=\"...\"`}\n", " fireworksParams={`model=\"accounts/fireworks/models/mixtral-8x7b-instruct\", api_key=\"...\"`}\n",
" mistralParams={`model=\"mistral-large-latest\", mistral_api_key=\"...\"`}\n",
" googleParams={`model=\"gemini-pro\", google_api_key=\"...\"`}\n", " googleParams={`model=\"gemini-pro\", google_api_key=\"...\"`}\n",
" togetherParams={`, together_api_key=\"...\"`}\n", " togetherParams={`, together_api_key=\"...\"`}\n",
" customVarName=\"chat\"\n", " customVarName=\"chat\"\n",

View File

@ -186,7 +186,7 @@ class ChatMistralAI(BaseChatModel):
client: httpx.Client = Field(default=None) #: :meta private: client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private: async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = None mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
endpoint: str = "https://api.mistral.ai/v1" endpoint: str = "https://api.mistral.ai/v1"
max_retries: int = 5 max_retries: int = 5
timeout: int = 120 timeout: int = 120
@ -202,6 +202,12 @@ class ChatMistralAI(BaseChatModel):
safe_mode: bool = False safe_mode: bool = False
streaming: bool = False streaming: bool = False
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling the API.""" """Get the default parameters for calling the API."""

View File

@ -29,15 +29,16 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
.. code-block:: python .. code-block:: python
from langchain_mistralai import MistralAIEmbeddings from langchain_mistralai import MistralAIEmbeddings
mistral = MistralAIEmbeddings( mistral = MistralAIEmbeddings(
model="mistral-embed", model="mistral-embed",
mistral_api_key="my-api-key" api_key="my-api-key"
) )
""" """
client: httpx.Client = Field(default=None) #: :meta private: client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private: async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = None mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
endpoint: str = "https://api.mistral.ai/v1/" endpoint: str = "https://api.mistral.ai/v1/"
max_retries: int = 5 max_retries: int = 5
timeout: int = 120 timeout: int = 120
@ -49,6 +50,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
class Config: class Config:
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
allow_population_by_field_name = True
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:

View File

@ -1,7 +1,7 @@
"""Test MistralAI Chat API wrapper.""" """Test MistralAI Chat API wrapper."""
import os import os
from typing import Any, AsyncGenerator, Dict, Generator from typing import Any, AsyncGenerator, Dict, Generator, cast
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -13,6 +13,7 @@ from langchain_core.messages import (
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
from langchain_core.pydantic_v1 import SecretStr
from langchain_mistralai.chat_models import ( # type: ignore[import] from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI, ChatMistralAI,
@ -31,7 +32,11 @@ def test_mistralai_initialization() -> None:
"""Test ChatMistralAI initialization.""" """Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized using a secret key provided # Verify that ChatMistralAI can be initialized using a secret key provided
# as a parameter rather than an environment variable. # as a parameter rather than an environment variable.
ChatMistralAI(model="test", mistral_api_key="test") for model in [
ChatMistralAI(model="test", mistral_api_key="test"),
ChatMistralAI(model="test", api_key="test"),
]:
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -1,4 +1,7 @@
import os import os
from typing import cast
from langchain_core.pydantic_v1 import SecretStr
from langchain_mistralai import MistralAIEmbeddings from langchain_mistralai import MistralAIEmbeddings
@ -6,5 +9,9 @@ os.environ["MISTRAL_API_KEY"] = "foo"
def test_mistral_init() -> None: def test_mistral_init() -> None:
embeddings = MistralAIEmbeddings() for model in [
assert embeddings.model == "mistral-embed" MistralAIEmbeddings(model="mistral-embed", mistral_api_key="test"),
MistralAIEmbeddings(model="mistral-embed", api_key="test"),
]:
assert model.model == "mistral-embed"
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"