langchain-mistralai - make base URL possible to set via env variable for ChatMistralAI (#25956)

Thank you for contributing to LangChain!


**Description:** 

Similar to other packages (`langchain_openai`, `langchain_anthropic`) it
would be beneficial if that `ChatMistralAI` model could fetch the API
base URL from the environment.

This PR allows this via the following order:
- provided value
- then whatever `MISTRAL_API_URL` is set to
- then whatever `MISTRAL_BASE_URL` is set to
- if `None`, then default is ` "https://api.mistral.com/v1"`


- [x] **Add tests and docs**:

Added unit tests, docs I feel are unnecessary, as this is just aligning
with other packages that do the same?


- [x] **Lint and test**: 

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Maximilian Schulz 2024-09-03 16:32:35 +02:00 committed by GitHub
parent c7154a4045
commit fdeaff4149
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 3 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import hashlib import hashlib
import json import json
import logging import logging
import os
import re import re
import uuid import uuid
from operator import itemgetter from operator import itemgetter
@ -364,7 +365,7 @@ class ChatMistralAI(BaseChatModel):
alias="api_key", alias="api_key",
default_factory=secret_from_env("MISTRAL_API_KEY", default=None), default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
) )
endpoint: str = "https://api.mistral.ai/v1" endpoint: Optional[str] = Field(default=None, alias="base_url")
max_retries: int = 5 max_retries: int = 5
timeout: int = 120 timeout: int = 120
max_concurrent_requests: int = 64 max_concurrent_requests: int = 64
@ -472,10 +473,17 @@ class ChatMistralAI(BaseChatModel):
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, and top_p.""" """Validate api key, python package exists, temperature, and top_p."""
api_key_str = values["mistral_api_key"].get_secret_value() api_key_str = values["mistral_api_key"].get_secret_value()
# todo: handle retries # todo: handle retries
base_url_str = (
values.get("endpoint")
or os.environ.get("MISTRAL_BASE_URL")
or "https://api.mistral.ai/v1"
)
values["endpoint"] = base_url_str
if not values.get("client"): if not values.get("client"):
values["client"] = httpx.Client( values["client"] = httpx.Client(
base_url=values["endpoint"], base_url=base_url_str,
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"Accept": "application/json", "Accept": "application/json",
@ -486,7 +494,7 @@ class ChatMistralAI(BaseChatModel):
# todo: handle retries and max_concurrency # todo: handle retries and max_concurrency
if not values.get("async_client"): if not values.get("async_client"):
values["async_client"] = httpx.AsyncClient( values["async_client"] = httpx.AsyncClient(
base_url=values["endpoint"], base_url=base_url_str,
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"Accept": "application/json", "Accept": "application/json",

View File

@ -44,6 +44,39 @@ def test_mistralai_initialization() -> None:
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test" assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
@pytest.mark.parametrize(
"model,expected_url",
[
(ChatMistralAI(model="test"), "https://api.mistral.ai/v1"), # type: ignore[call-arg, arg-type]
(ChatMistralAI(model="test", endpoint="baz"), "baz"), # type: ignore[call-arg, arg-type]
],
)
def test_mistralai_initialization_baseurl(
model: ChatMistralAI, expected_url: str
) -> None:
"""Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized providing endpoint, but also
# with default
assert model.endpoint == expected_url
@pytest.mark.parametrize(
"env_var_name",
[
("MISTRAL_BASE_URL"),
],
)
def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
"""Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized using env variable
import os
os.environ[env_var_name] = "boo"
model = ChatMistralAI(model="test") # type: ignore[call-arg]
assert model.endpoint == "boo"
@pytest.mark.parametrize( @pytest.mark.parametrize(
("message", "expected"), ("message", "expected"),
[ [