mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
langchain-mistralai
- make base URL possible to set via env variable for ChatMistralAI
(#25956)
Thank you for contributing to LangChain! **Description:** Similar to other packages (`langchain_openai`, `langchain_anthropic`) it would be beneficial if that `ChatMistralAI` model could fetch the API base URL from the environment. This PR allows this via the following order: - provided value - then whatever `MISTRAL_API_URL` is set to - then whatever `MISTRAL_BASE_URL` is set to - if `None`, then default is ` "https://api.mistral.com/v1"` - [x] **Add tests and docs**: Added unit tests, docs I feel are unnecessary, as this is just aligning with other packages that do the same? - [x] **Lint and test**: Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
c7154a4045
commit
fdeaff4149
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@ -364,7 +365,7 @@ class ChatMistralAI(BaseChatModel):
|
|||||||
alias="api_key",
|
alias="api_key",
|
||||||
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
|
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
|
||||||
)
|
)
|
||||||
endpoint: str = "https://api.mistral.ai/v1"
|
endpoint: Optional[str] = Field(default=None, alias="base_url")
|
||||||
max_retries: int = 5
|
max_retries: int = 5
|
||||||
timeout: int = 120
|
timeout: int = 120
|
||||||
max_concurrent_requests: int = 64
|
max_concurrent_requests: int = 64
|
||||||
@ -472,10 +473,17 @@ class ChatMistralAI(BaseChatModel):
|
|||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate api key, python package exists, temperature, and top_p."""
|
"""Validate api key, python package exists, temperature, and top_p."""
|
||||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
api_key_str = values["mistral_api_key"].get_secret_value()
|
||||||
|
|
||||||
# todo: handle retries
|
# todo: handle retries
|
||||||
|
base_url_str = (
|
||||||
|
values.get("endpoint")
|
||||||
|
or os.environ.get("MISTRAL_BASE_URL")
|
||||||
|
or "https://api.mistral.ai/v1"
|
||||||
|
)
|
||||||
|
values["endpoint"] = base_url_str
|
||||||
if not values.get("client"):
|
if not values.get("client"):
|
||||||
values["client"] = httpx.Client(
|
values["client"] = httpx.Client(
|
||||||
base_url=values["endpoint"],
|
base_url=base_url_str,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
@ -486,7 +494,7 @@ class ChatMistralAI(BaseChatModel):
|
|||||||
# todo: handle retries and max_concurrency
|
# todo: handle retries and max_concurrency
|
||||||
if not values.get("async_client"):
|
if not values.get("async_client"):
|
||||||
values["async_client"] = httpx.AsyncClient(
|
values["async_client"] = httpx.AsyncClient(
|
||||||
base_url=values["endpoint"],
|
base_url=base_url_str,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
|
@ -44,6 +44,39 @@ def test_mistralai_initialization() -> None:
|
|||||||
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
|
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model,expected_url",
|
||||||
|
[
|
||||||
|
(ChatMistralAI(model="test"), "https://api.mistral.ai/v1"), # type: ignore[call-arg, arg-type]
|
||||||
|
(ChatMistralAI(model="test", endpoint="baz"), "baz"), # type: ignore[call-arg, arg-type]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_mistralai_initialization_baseurl(
|
||||||
|
model: ChatMistralAI, expected_url: str
|
||||||
|
) -> None:
|
||||||
|
"""Test ChatMistralAI initialization."""
|
||||||
|
# Verify that ChatMistralAI can be initialized providing endpoint, but also
|
||||||
|
# with default
|
||||||
|
|
||||||
|
assert model.endpoint == expected_url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env_var_name",
|
||||||
|
[
|
||||||
|
("MISTRAL_BASE_URL"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
|
||||||
|
"""Test ChatMistralAI initialization."""
|
||||||
|
# Verify that ChatMistralAI can be initialized using env variable
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ[env_var_name] = "boo"
|
||||||
|
model = ChatMistralAI(model="test") # type: ignore[call-arg]
|
||||||
|
assert model.endpoint == "boo"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("message", "expected"),
|
("message", "expected"),
|
||||||
[
|
[
|
||||||
|
Loading…
Reference in New Issue
Block a user