mistralai: Add langchain-mistralai partner package (#14783)

Co-authored-by: Chad Phillips <chad@apartmentlines.com>
Bagatur authored 2023-12-19 10:34:19 -05:00, committed by GitHub
parent 44cb899a93
commit a5be9f9475
22 changed files with 2131 additions and 0 deletions


@@ -41,6 +41,7 @@ jobs:
shell: bash
env:
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
run: |
make integration_tests


@@ -153,6 +153,7 @@ jobs:
if: ${{ startsWith(inputs.working-directory, 'libs/partners/') }}
env:
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
run: make integration_tests
working-directory: ${{ inputs.working-directory }}


@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "raw",
"id": "53fbf15f",
"metadata": {},
"source": [
"---\n",
"sidebar_label: MistralAI\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "bf733a38-db84-4363-89e2-de6735c37230",
"metadata": {},
"source": [
"# ChatMistralAI\n",
"\n",
"This notebook covers how to get started with MistralAI chat models, via their [API](https://docs.mistral.ai/api/).\n",
"\n",
"A valid [API key](https://console.mistral.ai/users/api-keys/) is needed to communicate with the API."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain_core.messages import HumanMessage\n",
"from langchain_mistralai.chat_models import ChatMistralAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"\n",
"mistral_api_key = os.environ.get(\"MISTRAL_API_KEY\")\n",
"# If mistral_api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n",
"chat = ChatMistralAI(mistral_api_key=mistral_api_key)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"Hello! I'm here to assist you. How can I help you today? If you have any questions or need information on a particular topic, feel free to ask. I'm ready to provide accurate and helpful answers to the best of my ability.\")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [HumanMessage(content=\"say a brief hello\")]\n",
"chat.invoke(messages)"
]
},
{
"cell_type": "markdown",
"id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
"metadata": {},
"source": [
"## `ChatMistralAI` also supports async and streaming functionality:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"Hello! I'm glad you're here. If you have any questions or need assistance with something related to programming or software development, feel free to ask. I'll do my best to help you out. Have a great day!\")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await chat.ainvoke(messages)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello! I'm happy to assist you. Is there a specific question or topic you would like to discuss? I can provide information and answer questions on a wide variety of subjects."
]
}
],
"source": [
"for chunk in chat.stream(messages):\n",
" print(chunk.content, end=\"\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
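
The notebook above covers synchronous invoke, async invoke, and synchronous streaming. For completeness, a minimal async-streaming sketch follows; it is not part of the commit and assumes the langchain-mistralai package is installed and a valid MISTRAL_API_KEY is set in the environment.

import asyncio

from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI


async def main() -> None:
    # With no key passed, ChatMistralAI falls back to the MISTRAL_API_KEY env var.
    chat = ChatMistralAI()
    # astream() is inherited from BaseChatModel and drives the _astream() method
    # defined in chat_models.py later in this diff.
    async for chunk in chat.astream([HumanMessage(content="say a brief hello")]):
        print(chunk.content, end="", flush=True)


asyncio.run(main())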

libs/partners/mistralai/.gitignore (vendored, new file, 1 line)

@@ -0,0 +1 @@
__pycache__


@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,62 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
INTEGRATION_TEST_FILE ?= tests/integration_tests/
integration_test integration_tests: TEST_FILE=$(INTEGRATION_TEST_FILE)
test tests:
poetry run pytest $(TEST_FILE)
integration_test integration_tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/mistralai --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_mistralai
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
poetry run ruff .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff --select I $(PYTHON_FILES)
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_mistralai -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'


@@ -0,0 +1 @@
# langchain-mistralai


@@ -0,0 +1,3 @@
from langchain_mistralai.chat_models import ChatMistralAI
__all__ = ["ChatMistralAI"]
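
Because __init__.py re-exports the class, ChatMistralAI can be imported straight from the package root rather than from langchain_mistralai.chat_models. A minimal sketch (not part of the diff), assuming the package is installed and MISTRAL_API_KEY is set:

from langchain_mistralai import ChatMistralAI

# "mistral-small" and temperature=0.7 are the defaults declared in chat_models.py;
# they are passed explicitly here only for illustration.
chat = ChatMistralAI(model="mistral-small", temperature=0.7)
print(chat.invoke("say a brief hello").content)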


@@ -0,0 +1,390 @@
from __future__ import annotations
import importlib.util
import logging
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import get_from_dict_or_env
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
from mistralai.async_client import MistralAsyncClient # type: ignore[import]
from mistralai.client import MistralClient # type: ignore[import]
from mistralai.constants import ( # type: ignore[import]
ENDPOINT as DEFAULT_MISTRAL_ENDPOINT,
)
from mistralai.exceptions import ( # type: ignore[import]
MistralAPIException,
MistralConnectionException,
MistralException,
)
from mistralai.models.chat_completion import ( # type: ignore[import]
ChatCompletionResponse as MistralChatCompletionResponse,
)
from mistralai.models.chat_completion import ( # type: ignore[import]
ChatMessage as MistralChatMessage,
)
from mistralai.models.chat_completion import ( # type: ignore[import]
DeltaMessage as MistralDeltaMessage,
)
logger = logging.getLogger(__name__)
def _create_retry_decorator(
llm: ChatMistralAI,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""
errors = [
MistralException,
MistralAPIException,
MistralConnectionException,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
def _convert_mistral_chat_message_to_message(
_message: MistralChatMessage,
) -> BaseMessage:
role = _message.role
if role == "user":
return HumanMessage(content=_message.content)
elif role == "assistant":
return AIMessage(content=_message.content)
elif role == "system":
return SystemMessage(content=_message.content)
else:
return ChatMessage(content=_message.content, role=role)
async def acompletion_with_retry(
llm: ChatMistralAI,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
client = MistralAsyncClient(
api_key=llm.mistral_api_key,
endpoint=llm.endpoint,
max_retries=llm.max_retries,
timeout=llm.timeout,
max_concurrent_requests=llm.max_concurrent_requests,
)
stream = kwargs.pop("stream", False)
if stream:
return client.chat_stream(**kwargs)
else:
return await client.chat(**kwargs)
return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
_obj: MistralDeltaMessage, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = getattr(_obj, "role")
content = getattr(_obj, "content", "")
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> MistralChatMessage:
if isinstance(message, ChatMessage):
mistral_message = MistralChatMessage(role=message.role, content=message.content)
elif isinstance(message, HumanMessage):
mistral_message = MistralChatMessage(role="user", content=message.content)
elif isinstance(message, AIMessage):
mistral_message = MistralChatMessage(role="assistant", content=message.content)
elif isinstance(message, SystemMessage):
mistral_message = MistralChatMessage(role="system", content=message.content)
else:
raise ValueError(f"Got unknown type {message}")
return mistral_message
class ChatMistralAI(BaseChatModel):
"""A chat model that uses the MistralAI API."""
client: Any #: :meta private:
mistral_api_key: Optional[str] = None
endpoint: str = DEFAULT_MISTRAL_ENDPOINT
max_retries: int = 5
timeout: int = 120
max_concurrent_requests: int = 64
model: str = "mistral-small"
temperature: float = 0.7
max_tokens: Optional[int] = None
top_p: float = 1
"""Decode using nucleus sampling: consider the smallest set of tokens whose
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
random_seed: Optional[int] = None
safe_mode: bool = False
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling the API."""
defaults = {
"model": self.model,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"random_seed": self.random_seed,
"safe_mode": self.safe_mode,
}
filtered = {k: v for k, v in defaults.items() if v is not None}
return filtered
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the client."""
return self._default_params
def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
stream = kwargs.pop("stream", False)
if stream:
return self.client.chat_stream(**kwargs)
else:
return self.client.chat(**kwargs)
return _completion_with_retry(**kwargs)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, and top_p."""
mistralai_spec = importlib.util.find_spec("mistralai")
if mistralai_spec is None:
raise MistralException(
"Could not find mistralai python package. "
"Please install it with `pip install mistralai`"
)
values["mistral_api_key"] = get_from_dict_or_env(
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
)
values["client"] = MistralClient(
api_key=values["mistral_api_key"],
endpoint=values["endpoint"],
max_retries=values["max_retries"],
timeout=values["timeout"],
)
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
return values
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else False
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
def _create_chat_result(
self, response: MistralChatCompletionResponse
) -> ChatResult:
generations = []
for res in response.choices:
finish_reason = getattr(res, "finish_reason")
if finish_reason:
finish_reason = finish_reason.value
gen = ChatGeneration(
message=_convert_mistral_chat_message_to_message(res.message),
generation_info={"finish_reason": finish_reason},
)
generations.append(gen)
token_usage = getattr(response, "usage")
token_usage = vars(token_usage) if token_usage else {}
llm_output = {"token_usage": token_usage, "model": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[MistralChatMessage], Dict[str, Any]]:
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages]
return message_dicts, params
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0].delta
if not delta.content:
continue
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0].delta
if not delta.content:
continue
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else False
if should_stream:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "mistralai-chat"
@property
def lc_secrets(self) -> Dict[str, str]:
return {"mistral_api_key": "MISTRAL_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "mistralai"]

libs/partners/mistralai/poetry.lock (generated, new file, 1212 lines)

File diff suppressed because it is too large.


@@ -0,0 +1,83 @@
[tool.poetry]
name = "langchain-mistralai"
version = "0.0.1"
description = "An integration package connecting Mistral and LangChain"
authors = []
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1"
mistralai = "^0.0.8"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
pytest-asyncio = "^0.21.1"
langchain-core = {path = "../../core", develop = true}
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = {path = "../../core", develop = true}
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = {path = "../../core", develop = true}
[tool.ruff]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
]
[tool.mypy]
disallow_untyped_defs = "True"
[tool.coverage.run]
omit = [
"tests/*",
]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
addopts = "--strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"


@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)


@@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi


@@ -0,0 +1,18 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain, langchain_experimental, or langchain_community
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi


@@ -0,0 +1,63 @@
"""Test ChatMistral chat model."""
from langchain_mistralai.chat_models import ChatMistralAI
def test_stream() -> None:
"""Test streaming tokens from ChatMistralAI."""
llm = ChatMistralAI()
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_astream() -> None:
"""Test streaming tokens from ChatMistralAI."""
llm = ChatMistralAI()
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_abatch() -> None:
"""Test streaming tokens from ChatMistralAI"""
llm = ChatMistralAI()
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_abatch_tags() -> None:
"""Test batch tokens from ChatMistralAI"""
llm = ChatMistralAI()
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token.content, str)
def test_batch() -> None:
"""Test batch tokens from ChatMistralAI"""
llm = ChatMistralAI()
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_ainvoke() -> None:
"""Test invoke tokens from ChatMistralAI"""
llm = ChatMistralAI()
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
def test_invoke() -> None:
"""Test invoke tokens from ChatMistralAI"""
llm = ChatMistralAI()
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)


@@ -0,0 +1,7 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass


@@ -0,0 +1,65 @@
"""Test MistralAI Chat API wrapper."""
import os
import pytest
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
from mistralai.models.chat_completion import ( # type: ignore[import]
ChatMessage as MistralChatMessage,
)
from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI,
_convert_message_to_mistral_chat_message,
)
os.environ["MISTRAL_API_KEY"] = "foo"
@pytest.mark.requires("mistralai")
def test_mistralai_model_param() -> None:
llm = ChatMistralAI(model="foo")
assert llm.model == "foo"
@pytest.mark.requires("mistralai")
def test_mistralai_initialization() -> None:
"""Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized using a secret key provided
# as a parameter rather than an environment variable.
ChatMistralAI(model="test", mistral_api_key="test")
@pytest.mark.parametrize(
("message", "expected"),
[
(
SystemMessage(content="Hello"),
MistralChatMessage(role="system", content="Hello"),
),
(
HumanMessage(content="Hello"),
MistralChatMessage(role="user", content="Hello"),
),
(
AIMessage(content="Hello"),
MistralChatMessage(role="assistant", content="Hello"),
),
(
ChatMessage(role="assistant", content="Hello"),
MistralChatMessage(role="assistant", content="Hello"),
),
],
)
def test_convert_message_to_mistral_chat_message(
message: BaseMessage, expected: MistralChatMessage
) -> None:
result = _convert_message_to_mistral_chat_message(message)
assert result == expected


@@ -0,0 +1,7 @@
from langchain_mistralai import __all__
EXPECTED_ALL = ["ChatMistralAI"]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)