mistralai[minor]: 0.1.0rc0, remove mistral sdk (#19420)

This commit is contained in:
Erick Friis
2024-03-21 18:24:58 -07:00
committed by GitHub
parent e980c14d6a
commit 53ac1ebbbc
5 changed files with 237 additions and 448 deletions

View File

@@ -1,6 +1,7 @@
"""Test MistralAI Chat API wrapper."""
import os
from typing import Any, AsyncGenerator, Generator
from typing import Any, AsyncGenerator, Dict, Generator
from unittest.mock import patch
import pytest
@@ -13,16 +14,6 @@ from langchain_core.messages import (
SystemMessage,
)
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
from mistralai.models.chat_completion import ( # type: ignore[import]
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
)
from mistralai.models.chat_completion import (
ChatMessage as MistralChatMessage,
)
from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI,
_convert_message_to_mistral_chat_message,
@@ -31,13 +22,11 @@ from langchain_mistralai.chat_models import ( # type: ignore[import]
os.environ["MISTRAL_API_KEY"] = "foo"
@pytest.mark.requires("mistralai")
def test_mistralai_model_param() -> None:
llm = ChatMistralAI(model="foo")
assert llm.model == "foo"
@pytest.mark.requires("mistralai")
def test_mistralai_initialization() -> None:
"""Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized using a secret key provided
@@ -50,37 +39,37 @@ def test_mistralai_initialization() -> None:
[
(
SystemMessage(content="Hello"),
MistralChatMessage(role="system", content="Hello"),
dict(role="system", content="Hello"),
),
(
HumanMessage(content="Hello"),
MistralChatMessage(role="user", content="Hello"),
dict(role="user", content="Hello"),
),
(
AIMessage(content="Hello"),
MistralChatMessage(role="assistant", content="Hello"),
dict(role="assistant", content="Hello", tool_calls=None),
),
(
ChatMessage(role="assistant", content="Hello"),
MistralChatMessage(role="assistant", content="Hello"),
dict(role="assistant", content="Hello"),
),
],
)
def test_convert_message_to_mistral_chat_message(
message: BaseMessage, expected: MistralChatMessage
message: BaseMessage, expected: Dict
) -> None:
result = _convert_message_to_mistral_chat_message(message)
assert result == expected
def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResponse:
return ChatCompletionStreamResponse(
def _make_completion_response_from_token(token: str) -> Dict:
return dict(
id="abc123",
model="fake_model",
choices=[
ChatCompletionResponseStreamChoice(
dict(
index=0,
delta=DeltaMessage(content=token),
delta=dict(content=token),
finish_reason=None,
)
],
@@ -88,13 +77,19 @@ def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResp
def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
for token in ["Hello", " how", " can", " I", " help", "?"]:
yield _make_completion_response_from_token(token)
def it() -> Generator:
for token in ["Hello", " how", " can", " I", " help", "?"]:
yield _make_completion_response_from_token(token)
return it()
async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
for token in ["Hello", " how", " can", " I", " help", "?"]:
yield _make_completion_response_from_token(token)
async def it() -> AsyncGenerator:
for token in ["Hello", " how", " can", " I", " help", "?"]:
yield _make_completion_response_from_token(token)
return it()
class MyCustomHandler(BaseCallbackHandler):
@@ -104,7 +99,10 @@ class MyCustomHandler(BaseCallbackHandler):
self.last_token = token
@patch("mistralai.client.MistralClient.chat_stream", new=mock_chat_stream)
@patch(
"langchain_mistralai.chat_models.ChatMistralAI.completion_with_retry",
new=mock_chat_stream,
)
def test_stream_with_callback() -> None:
callback = MyCustomHandler()
chat = ChatMistralAI(callbacks=[callback])
@@ -112,7 +110,7 @@ def test_stream_with_callback() -> None:
assert callback.last_token == token.content
@patch("mistralai.async_client.MistralAsyncClient.chat_stream", new=mock_chat_astream)
@patch("langchain_mistralai.chat_models.acompletion_with_retry", new=mock_chat_astream)
async def test_astream_with_callback() -> None:
callback = MyCustomHandler()
chat = ChatMistralAI(callbacks=[callback])