From fdeaff4149ed2e4bf034ac18bfff49f613a05e6e Mon Sep 17 00:00:00 2001 From: Maximilian Schulz <83698606+maxschulz-COL@users.noreply.github.com> Date: Tue, 3 Sep 2024 16:32:35 +0200 Subject: [PATCH] `langchain-mistralai` - make base URL possible to set via env variable for `ChatMistralAI` (#25956) Thank you for contributing to LangChain! **Description:** Similar to other packages (`langchain_openai`, `langchain_anthropic`) it would be beneficial if that `ChatMistralAI` model could fetch the API base URL from the environment. This PR allows this via the following order: - provided value - then whatever `MISTRAL_API_URL` is set to - then whatever `MISTRAL_BASE_URL` is set to - if `None`, then default is ` "https://api.mistral.com/v1"` - [x] **Add tests and docs**: Added unit tests, docs I feel are unnecessary, as this is just aligning with other packages that do the same? - [x] **Lint and test**: Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Chester Curme --- .../langchain_mistralai/chat_models.py | 14 ++++++-- .../tests/unit_tests/test_chat_models.py | 33 +++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 804bb64f340..06ee230d1f6 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -3,6 +3,7 @@ from __future__ import annotations import hashlib import json import logging +import os import re import uuid from operator import itemgetter @@ -364,7 +365,7 @@ class ChatMistralAI(BaseChatModel): alias="api_key", default_factory=secret_from_env("MISTRAL_API_KEY", default=None), ) - endpoint: str = "https://api.mistral.ai/v1" + endpoint: Optional[str] = Field(default=None, alias="base_url") max_retries: int = 5 timeout: int = 120 max_concurrent_requests: int = 64 @@ -472,10 +473,17 @@ class ChatMistralAI(BaseChatModel): def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists, temperature, and top_p.""" api_key_str = values["mistral_api_key"].get_secret_value() + # todo: handle retries + base_url_str = ( + values.get("endpoint") + or os.environ.get("MISTRAL_BASE_URL") + or "https://api.mistral.ai/v1" + ) + values["endpoint"] = base_url_str if not values.get("client"): values["client"] = httpx.Client( - base_url=values["endpoint"], + base_url=base_url_str, headers={ "Content-Type": "application/json", "Accept": "application/json", @@ -486,7 +494,7 @@ class ChatMistralAI(BaseChatModel): # todo: handle retries and max_concurrency if not values.get("async_client"): values["async_client"] = httpx.AsyncClient( - base_url=values["endpoint"], + base_url=base_url_str, headers={ "Content-Type": "application/json", "Accept": "application/json", diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 35b21af8a15..62b86333887 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -44,6 +44,39 @@ def test_mistralai_initialization() -> None: assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test" +@pytest.mark.parametrize( + "model,expected_url", + [ + (ChatMistralAI(model="test"), "https://api.mistral.ai/v1"), # type: ignore[call-arg, arg-type] + (ChatMistralAI(model="test", endpoint="baz"), "baz"), # type: ignore[call-arg, arg-type] + ], +) +def test_mistralai_initialization_baseurl( + model: ChatMistralAI, expected_url: str +) -> None: + """Test ChatMistralAI initialization.""" + # Verify that ChatMistralAI can be initialized providing endpoint, but also + # with default + + assert model.endpoint == expected_url + + +@pytest.mark.parametrize( + "env_var_name", + [ + ("MISTRAL_BASE_URL"), + ], +) +def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None: + """Test ChatMistralAI initialization.""" + # Verify that ChatMistralAI can be initialized using env variable + import os + + os.environ[env_var_name] = "boo" + model = ChatMistralAI(model="test") # type: ignore[call-arg] + assert model.endpoint == "boo" + + @pytest.mark.parametrize( ("message", "expected"), [