Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
cd0ce3f4c6 mistralai[patch]: allow manual client construction 2024-04-29 13:15:00 -04:00
2 changed files with 31 additions and 22 deletions

View File

@@ -307,8 +307,8 @@ def _convert_message_to_mistral_chat_message(
class ChatMistralAI(BaseChatModel):
"""A chat model that uses the MistralAI API."""
client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
client: httpx.Client = Field(default=None)
async_client: httpx.AsyncClient = Field(default=None)
mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
endpoint: str = "https://api.mistral.ai/v1"
max_retries: int = 5
@@ -407,26 +407,28 @@ class ChatMistralAI(BaseChatModel):
)
)
api_key_str = values["mistral_api_key"].get_secret_value()
# todo: handle retries
values["client"] = httpx.Client(
base_url=values["endpoint"],
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
)
# todo: handle retries and max_concurrency
values["async_client"] = httpx.AsyncClient(
base_url=values["endpoint"],
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
)
if not values["client"]:
# todo: handle retries
values["client"] = httpx.Client(
base_url=values["endpoint"],
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
)
if not values["async_client"]:
# todo: handle retries and max_concurrency
values["async_client"] = httpx.AsyncClient(
base_url=values["endpoint"],
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
)
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")

View File

@@ -4,6 +4,7 @@ import os
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
from unittest.mock import patch
import httpx
import pytest
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import (
@@ -42,6 +43,12 @@ def test_mistralai_initialization() -> None:
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
def test_mistralai_init_client() -> None:
client = httpx.Client(base_url="foobar")
llm = ChatMistralAI(client=client)
assert llm.client == client
@pytest.mark.parametrize(
("message", "expected"),
[