mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
mistralai: Add langchain-mistralai partner package (#14783)
Co-authored-by: Chad Phillips <chad@apartmentlines.com>
This commit is contained in:
parent
44cb899a93
commit
a5be9f9475
1
.github/workflows/_integration_test.yml
vendored
1
.github/workflows/_integration_test.yml
vendored
@ -41,6 +41,7 @@ jobs:
|
||||
shell: bash
|
||||
env:
|
||||
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
|
||||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
run: |
|
||||
make integration_tests
|
||||
|
||||
|
1
.github/workflows/_release.yml
vendored
1
.github/workflows/_release.yml
vendored
@ -153,6 +153,7 @@ jobs:
|
||||
if: ${{ startsWith(inputs.working-directory, 'libs/partners/') }}
|
||||
env:
|
||||
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
|
||||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
run: make integration_tests
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
|
||||
|
152
docs/docs/integrations/chat/mistralai.ipynb
Normal file
152
docs/docs/integrations/chat/mistralai.ipynb
Normal file
@ -0,0 +1,152 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "53fbf15f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"sidebar_label: MistralAI\n",
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bf733a38-db84-4363-89e2-de6735c37230",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ChatMistralAI\n",
|
||||
"\n",
|
||||
"This notebook covers how to get started with MistralAI chat models, via their [API](https://docs.mistral.ai/api/).\n",
|
||||
"\n",
|
||||
"A valid [API key](https://console.mistral.ai/users/api-keys/) is needed to communicate with the API."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.messages import HumanMessage\n",
|
||||
"from langchain_mistralai.chat_models import ChatMistralAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"mistral_api_key = os.environ.get(\"MISTRAL_API_KEY\")\n",
|
||||
"# If mistral_api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n",
|
||||
"chat = ChatMistralAI(mistral_api_key=mistral_api_key)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"Hello! I'm here to assist you. How can I help you today? If you have any questions or need information on a particular topic, feel free to ask. I'm ready to provide accurate and helpful answers to the best of my ability.\")"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = [HumanMessage(content=\"say a brief hello\")]\n",
|
||||
"chat.invoke(messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## `ChatMistralAI` also supports async and streaming functionality:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"Hello! I'm glad you're here. If you have any questions or need assistance with something related to programming or software development, feel free to ask. I'll do my best to help you out. Have a great day!\")"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"await chat.ainvoke(messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hello! I'm happy to assist you. Is there a specific question or topic you would like to discuss? I can provide information and answer questions on a wide variety of subjects."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for chunk in chat.stream(messages):\n",
|
||||
" print(chunk.content, end=\"\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
1
libs/partners/mistralai/.gitignore
vendored
Normal file
1
libs/partners/mistralai/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
__pycache__
|
21
libs/partners/mistralai/LICENSE
Normal file
21
libs/partners/mistralai/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
62
libs/partners/mistralai/Makefile
Normal file
62
libs/partners/mistralai/Makefile
Normal file
@ -0,0 +1,62 @@
|
||||
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
INTEGRATION_TEST_FILE ?= tests/integration_tests/
|
||||
|
||||
integration_test integration_tests: TEST_FILE=$(INTEGRATION_TEST_FILE)
|
||||
|
||||
test tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
integration_test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/mistral --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_mistralai
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
poetry run ruff .
|
||||
poetry run ruff format $(PYTHON_FILES) --diff
|
||||
poetry run ruff --select I $(PYTHON_FILES)
|
||||
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
check_imports: $(shell find langchain_mistralai -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'check_imports - check imports'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
1
libs/partners/mistralai/README.md
Normal file
1
libs/partners/mistralai/README.md
Normal file
@ -0,0 +1 @@
|
||||
# langchain-mistralai
|
3
libs/partners/mistralai/langchain_mistralai/__init__.py
Normal file
3
libs/partners/mistralai/langchain_mistralai/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from langchain_mistralai.chat_models import ChatMistralAI
|
||||
|
||||
__all__ = ["ChatMistralAI"]
|
390
libs/partners/mistralai/langchain_mistralai/chat_models.py
Normal file
390
libs/partners/mistralai/langchain_mistralai/chat_models.py
Normal file
@ -0,0 +1,390 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
|
||||
from mistralai.async_client import MistralAsyncClient # type: ignore[import]
|
||||
from mistralai.client import MistralClient # type: ignore[import]
|
||||
from mistralai.constants import ( # type: ignore[import]
|
||||
ENDPOINT as DEFAULT_MISTRAL_ENDPOINT,
|
||||
)
|
||||
from mistralai.exceptions import ( # type: ignore[import]
|
||||
MistralAPIException,
|
||||
MistralConnectionException,
|
||||
MistralException,
|
||||
)
|
||||
from mistralai.models.chat_completion import ( # type: ignore[import]
|
||||
ChatCompletionResponse as MistralChatCompletionResponse,
|
||||
)
|
||||
from mistralai.models.chat_completion import ( # type: ignore[import]
|
||||
ChatMessage as MistralChatMessage,
|
||||
)
|
||||
from mistralai.models.chat_completion import ( # type: ignore[import]
|
||||
DeltaMessage as MistralDeltaMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: ChatMistralAI,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""
|
||||
|
||||
errors = [
|
||||
MistralException,
|
||||
MistralAPIException,
|
||||
MistralConnectionException,
|
||||
]
|
||||
return create_base_retry_decorator(
|
||||
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||
)
|
||||
|
||||
|
||||
def _convert_mistral_chat_message_to_message(
|
||||
_message: MistralChatMessage,
|
||||
) -> BaseMessage:
|
||||
role = _message.role
|
||||
if role == "user":
|
||||
return HumanMessage(content=_message.content)
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_message.content)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_message.content)
|
||||
else:
|
||||
return ChatMessage(content=_message.content, role=role)
|
||||
|
||||
|
||||
async def acompletion_with_retry(
|
||||
llm: ChatMistralAI,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
client = MistralAsyncClient(
|
||||
api_key=llm.mistral_api_key,
|
||||
endpoint=llm.endpoint,
|
||||
max_retries=llm.max_retries,
|
||||
timeout=llm.timeout,
|
||||
max_concurrent_requests=llm.max_concurrent_requests,
|
||||
)
|
||||
stream = kwargs.pop("stream", False)
|
||||
if stream:
|
||||
return client.chat_stream(**kwargs)
|
||||
else:
|
||||
return await client.chat(**kwargs)
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_obj: MistralDeltaMessage, default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = getattr(_obj, "role")
|
||||
content = getattr(_obj, "content", "")
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content)
|
||||
|
||||
|
||||
def _convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> MistralChatMessage:
|
||||
if isinstance(message, ChatMessage):
|
||||
mistral_message = MistralChatMessage(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
mistral_message = MistralChatMessage(role="user", content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
mistral_message = MistralChatMessage(role="assistant", content=message.content)
|
||||
elif isinstance(message, SystemMessage):
|
||||
mistral_message = MistralChatMessage(role="system", content=message.content)
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return mistral_message
|
||||
|
||||
|
||||
class ChatMistralAI(BaseChatModel):
|
||||
"""A chat model that uses the MistralAI API."""
|
||||
|
||||
client: Any #: :meta private:
|
||||
mistral_api_key: Optional[str] = None
|
||||
endpoint: str = DEFAULT_MISTRAL_ENDPOINT
|
||||
max_retries: int = 5
|
||||
timeout: int = 120
|
||||
max_concurrent_requests: int = 64
|
||||
|
||||
model: str = "mistral-small"
|
||||
temperature: float = 0.7
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: float = 1
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||
random_seed: Optional[int] = None
|
||||
safe_mode: bool = False
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling the API."""
|
||||
defaults = {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"random_seed": self.random_seed,
|
||||
"safe_mode": self.safe_mode,
|
||||
}
|
||||
filtered = {k: v for k, v in defaults.items() if v is not None}
|
||||
return filtered
|
||||
|
||||
@property
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used for the client."""
|
||||
return self._default_params
|
||||
|
||||
def completion_with_retry(
|
||||
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
stream = kwargs.pop("stream", False)
|
||||
if stream:
|
||||
return self.client.chat_stream(**kwargs)
|
||||
else:
|
||||
return self.client.chat(**kwargs)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists, temperature, and top_p."""
|
||||
mistralai_spec = importlib.util.find_spec("mistralai")
|
||||
if mistralai_spec is None:
|
||||
raise MistralException(
|
||||
"Could not find mistralai python package. "
|
||||
"Please install it with `pip install mistralai`"
|
||||
)
|
||||
|
||||
values["mistral_api_key"] = get_from_dict_or_env(
|
||||
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
|
||||
)
|
||||
values["client"] = MistralClient(
|
||||
api_key=values["mistral_api_key"],
|
||||
endpoint=values["endpoint"],
|
||||
max_retries=values["max_retries"],
|
||||
timeout=values["timeout"],
|
||||
)
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else False
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_chat_result(
|
||||
self, response: MistralChatCompletionResponse
|
||||
) -> ChatResult:
|
||||
generations = []
|
||||
for res in response.choices:
|
||||
finish_reason = getattr(res, "finish_reason")
|
||||
if finish_reason:
|
||||
finish_reason = finish_reason.value
|
||||
gen = ChatGeneration(
|
||||
message=_convert_mistral_chat_message_to_message(res.message),
|
||||
generation_info={"finish_reason": finish_reason},
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = getattr(response, "usage")
|
||||
token_usage = vars(token_usage) if token_usage else {}
|
||||
llm_output = {"token_usage": token_usage, "model": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[MistralChatMessage], Dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class = AIMessageChunk
|
||||
for chunk in self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta.content:
|
||||
continue
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
default_chunk_class = chunk.__class__
|
||||
yield ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.content)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class = AIMessageChunk
|
||||
async for chunk in await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta.content:
|
||||
continue
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
default_chunk_class = chunk.__class__
|
||||
yield ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.content)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else False
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return self._default_params
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "mistralai-chat"
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"mistral_api_key": "MISTRAL_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "mistralai"]
|
1212
libs/partners/mistralai/poetry.lock
generated
Normal file
1212
libs/partners/mistralai/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
83
libs/partners/mistralai/pyproject.toml
Normal file
83
libs/partners/mistralai/pyproject.toml
Normal file
@ -0,0 +1,83 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-mistralai"
|
||||
version = "0.0.1"
|
||||
description = "An integration package connecting Mistral and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1"
|
||||
mistralai = "^0.0.8"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"tests/*",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
addopts = "--strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
17
libs/partners/mistralai/scripts/check_imports.py
Normal file
17
libs/partners/mistralai/scripts/check_imports.py
Normal file
@ -0,0 +1,17 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_faillure = True
|
||||
print(file)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
27
libs/partners/mistralai/scripts/check_pydantic.sh
Executable file
27
libs/partners/mistralai/scripts/check_pydantic.sh
Executable file
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
18
libs/partners/mistralai/scripts/lint_imports.sh
Executable file
18
libs/partners/mistralai/scripts/lint_imports.sh
Executable file
@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
0
libs/partners/mistralai/tests/__init__.py
Normal file
0
libs/partners/mistralai/tests/__init__.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Test ChatMistral chat model."""
|
||||
from langchain_mistralai.chat_models import ChatMistralAI
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test streaming tokens from ChatMistralAI."""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from ChatMistralAI."""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test streaming tokens from ChatMistralAI"""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatMistralAI"""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch tokens from ChatMistralAI"""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatMistralAI"""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test invoke tokens from ChatMistralAI"""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
65
libs/partners/mistralai/tests/unit_tests/test_chat_models.py
Normal file
65
libs/partners/mistralai/tests/unit_tests/test_chat_models.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Test MistralAI Chat API wrapper."""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
|
||||
from mistralai.models.chat_completion import ( # type: ignore[import]
|
||||
ChatMessage as MistralChatMessage,
|
||||
)
|
||||
|
||||
from langchain_mistralai.chat_models import ( # type: ignore[import]
|
||||
ChatMistralAI,
|
||||
_convert_message_to_mistral_chat_message,
|
||||
)
|
||||
|
||||
os.environ["MISTRAL_API_KEY"] = "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("mistralai")
|
||||
def test_mistralai_model_param() -> None:
|
||||
llm = ChatMistralAI(model="foo")
|
||||
assert llm.model == "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("mistralai")
|
||||
def test_mistralai_initialization() -> None:
|
||||
"""Test ChatMistralAI initialization."""
|
||||
# Verify that ChatMistralAI can be initialized using a secret key provided
|
||||
# as a parameter rather than an environment variable.
|
||||
ChatMistralAI(model="test", mistral_api_key="test")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("message", "expected"),
|
||||
[
|
||||
(
|
||||
SystemMessage(content="Hello"),
|
||||
MistralChatMessage(role="system", content="Hello"),
|
||||
),
|
||||
(
|
||||
HumanMessage(content="Hello"),
|
||||
MistralChatMessage(role="user", content="Hello"),
|
||||
),
|
||||
(
|
||||
AIMessage(content="Hello"),
|
||||
MistralChatMessage(role="assistant", content="Hello"),
|
||||
),
|
||||
(
|
||||
ChatMessage(role="assistant", content="Hello"),
|
||||
MistralChatMessage(role="assistant", content="Hello"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage, expected: MistralChatMessage
|
||||
) -> None:
|
||||
result = _convert_message_to_mistral_chat_message(message)
|
||||
assert result == expected
|
7
libs/partners/mistralai/tests/unit_tests/test_imports.py
Normal file
7
libs/partners/mistralai/tests/unit_tests/test_imports.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain_mistralai import __all__
|
||||
|
||||
EXPECTED_ALL = ["ChatMistralAI"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
Loading…
Reference in New Issue
Block a user