Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-16 09:48:04 +00:00)
parent 17182406f3
commit 3490d70238
@@ -48,7 +48,7 @@
    "source": [
     "import getpass\n",
     "\n",
-    "mistral_api_key = getpass.getpass()"
+    "api_key = getpass.getpass()"
    ]
   },
   {
@@ -81,8 +81,8 @@
   },
   "outputs": [],
   "source": [
-    "# If mistral_api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n",
-    "chat = ChatMistralAI(mistral_api_key=mistral_api_key)"
+    "# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n",
+    "chat = ChatMistralAI(api_key=api_key)"
   ]
  },
  {
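Put together, the two updated notebook cells amount to the following sketch (assuming langchain_mistralai is installed); when api_key is not passed, the MISTRAL_API_KEY environment variable is used.

# Sketch of the updated notebook flow; not part of the diff itself.
import getpass

from langchain_mistralai import ChatMistralAI

# Prompt for the key interactively, as the docs notebook does.
api_key = getpass.getpass()

# Falls back to the MISTRAL_API_KEY environment variable if api_key is omitted.
chat = ChatMistralAI(api_key=api_key)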
@@ -45,7 +45,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "embedding = MistralAIEmbeddings(mistral_api_key=\"your-api-key\")"
+    "embedding = MistralAIEmbeddings(api_key=\"your-api-key\")"
   ]
  },
  {
@@ -54,8 +54,8 @@
    "<ChatModelTabs\n",
    " openaiParams={`model=\"gpt-3.5-turbo-0125\", api_key=\"...\"`}\n",
    " anthropicParams={`model=\"claude-3-sonnet-20240229\", anthropic_api_key=\"...\"`}\n",
-   " mistralParams={`model=\"mistral-large-latest\", mistral_api_key=\"...\"`}\n",
+   " mistralParams={`model=\"mistral-large-latest\", api_key=\"...\"`}\n",
    " fireworksParams={`model=\"accounts/fireworks/models/mixtral-8x7b-instruct\", api_key=\"...\"`}\n",
    " googleParams={`model=\"gemini-pro\", google_api_key=\"...\"`}\n",
    " togetherParams={`, together_api_key=\"...\"`}\n",
    " customVarName=\"chat\"\n",
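For orientation, the mistralParams entry above feeds the ChatModelTabs docs component, which renders roughly the following snippet for the Mistral tab (a sketch of the rendered output, not the component itself):

# Approximate snippet rendered for the Mistral tab after this change.
from langchain_mistralai import ChatMistralAI

chat = ChatMistralAI(model="mistral-large-latest", api_key="...")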
@@ -186,7 +186,7 @@ class ChatMistralAI(BaseChatModel):

     client: httpx.Client = Field(default=None)  #: :meta private:
     async_client: httpx.AsyncClient = Field(default=None)  #: :meta private:
-    mistral_api_key: Optional[SecretStr] = None
+    mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
     endpoint: str = "https://api.mistral.ai/v1"
     max_retries: int = 5
     timeout: int = 120
@@ -202,6 +202,12 @@ class ChatMistralAI(BaseChatModel):
     safe_mode: bool = False
     streaming: bool = False

+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+        arbitrary_types_allowed = True
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling the API."""
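The combination of Field(default=None, alias="api_key") and allow_population_by_field_name = True is what accepts both spellings at construction time. A minimal standalone sketch of that pydantic v1 behavior, using langchain_core.pydantic_v1 as the tests below do (Example is a hypothetical model, not part of the library):

# Minimal sketch of the alias mechanism; `Example` is illustrative only.
from typing import Optional

from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr


class Example(BaseModel):
    mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")

    class Config:
        # Without this, only the alias ("api_key") would be accepted at init time.
        allow_population_by_field_name = True


# Both the field name and its alias populate the same SecretStr field.
assert Example(api_key="secret").mistral_api_key.get_secret_value() == "secret"
assert Example(mistral_api_key="secret").mistral_api_key.get_secret_value() == "secret"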
@@ -29,15 +29,16 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
         .. code-block:: python

             from langchain_mistralai import MistralAIEmbeddings

             mistral = MistralAIEmbeddings(
                 model="mistral-embed",
-                mistral_api_key="my-api-key"
+                api_key="my-api-key"
             )
     """

     client: httpx.Client = Field(default=None)  #: :meta private:
     async_client: httpx.AsyncClient = Field(default=None)  #: :meta private:
-    mistral_api_key: Optional[SecretStr] = None
+    mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
     endpoint: str = "https://api.mistral.ai/v1/"
     max_retries: int = 5
     timeout: int = 120
@@ -49,6 +50,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
     class Config:
         extra = Extra.forbid
         arbitrary_types_allowed = True
+        allow_population_by_field_name = True

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
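Because the field remains a SecretStr, the key is masked in reprs and must be read with get_secret_value(), which is exactly what the updated unit tests below assert. A short sketch (assuming langchain_mistralai is installed):

# Sketch: the aliased parameter still lands in the SecretStr field `mistral_api_key`.
from langchain_mistralai import MistralAIEmbeddings

emb = MistralAIEmbeddings(model="mistral-embed", api_key="test")

print(emb.mistral_api_key)                     # masked, prints **********
print(emb.mistral_api_key.get_secret_value())  # prints test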
@@ -1,7 +1,7 @@
 """Test MistralAI Chat API wrapper."""

 import os
-from typing import Any, AsyncGenerator, Dict, Generator
+from typing import Any, AsyncGenerator, Dict, Generator, cast
 from unittest.mock import patch

 import pytest
@@ -13,6 +13,7 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
+from langchain_core.pydantic_v1 import SecretStr

 from langchain_mistralai.chat_models import (  # type: ignore[import]
     ChatMistralAI,
@@ -31,7 +32,11 @@ def test_mistralai_initialization() -> None:
     """Test ChatMistralAI initialization."""
     # Verify that ChatMistralAI can be initialized using a secret key provided
     # as a parameter rather than an environment variable.
-    ChatMistralAI(model="test", mistral_api_key="test")
+    for model in [
+        ChatMistralAI(model="test", mistral_api_key="test"),
+        ChatMistralAI(model="test", api_key="test"),
+    ]:
+        assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"


 @pytest.mark.parametrize(
@ -1,4 +1,7 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
|
||||
from langchain_mistralai import MistralAIEmbeddings
|
||||
|
||||
@ -6,5 +9,9 @@ os.environ["MISTRAL_API_KEY"] = "foo"
|
||||
|
||||
|
||||
def test_mistral_init() -> None:
|
||||
embeddings = MistralAIEmbeddings()
|
||||
assert embeddings.model == "mistral-embed"
|
||||
for model in [
|
||||
MistralAIEmbeddings(model="mistral-embed", mistral_api_key="test"),
|
||||
MistralAIEmbeddings(model="mistral-embed", api_key="test"),
|
||||
]:
|
||||
assert model.model == "mistral-embed"
|
||||
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
|
||||
|