mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
mistralai[minor]: 0.1.0rc0, remove mistral sdk (#19420)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Test MistralAI Chat API wrapper."""
|
||||
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Generator
|
||||
from typing import Any, AsyncGenerator, Dict, Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -13,16 +14,6 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
|
||||
from mistralai.models.chat_completion import ( # type: ignore[import]
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
DeltaMessage,
|
||||
)
|
||||
from mistralai.models.chat_completion import (
|
||||
ChatMessage as MistralChatMessage,
|
||||
)
|
||||
|
||||
from langchain_mistralai.chat_models import ( # type: ignore[import]
|
||||
ChatMistralAI,
|
||||
_convert_message_to_mistral_chat_message,
|
||||
@@ -31,13 +22,11 @@ from langchain_mistralai.chat_models import ( # type: ignore[import]
|
||||
os.environ["MISTRAL_API_KEY"] = "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("mistralai")
|
||||
def test_mistralai_model_param() -> None:
|
||||
llm = ChatMistralAI(model="foo")
|
||||
assert llm.model == "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("mistralai")
|
||||
def test_mistralai_initialization() -> None:
|
||||
"""Test ChatMistralAI initialization."""
|
||||
# Verify that ChatMistralAI can be initialized using a secret key provided
|
||||
@@ -50,37 +39,37 @@ def test_mistralai_initialization() -> None:
|
||||
[
|
||||
(
|
||||
SystemMessage(content="Hello"),
|
||||
MistralChatMessage(role="system", content="Hello"),
|
||||
dict(role="system", content="Hello"),
|
||||
),
|
||||
(
|
||||
HumanMessage(content="Hello"),
|
||||
MistralChatMessage(role="user", content="Hello"),
|
||||
dict(role="user", content="Hello"),
|
||||
),
|
||||
(
|
||||
AIMessage(content="Hello"),
|
||||
MistralChatMessage(role="assistant", content="Hello"),
|
||||
dict(role="assistant", content="Hello", tool_calls=None),
|
||||
),
|
||||
(
|
||||
ChatMessage(role="assistant", content="Hello"),
|
||||
MistralChatMessage(role="assistant", content="Hello"),
|
||||
dict(role="assistant", content="Hello"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage, expected: MistralChatMessage
|
||||
message: BaseMessage, expected: Dict
|
||||
) -> None:
|
||||
result = _convert_message_to_mistral_chat_message(message)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResponse:
|
||||
return ChatCompletionStreamResponse(
|
||||
def _make_completion_response_from_token(token: str) -> Dict:
|
||||
return dict(
|
||||
id="abc123",
|
||||
model="fake_model",
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
dict(
|
||||
index=0,
|
||||
delta=DeltaMessage(content=token),
|
||||
delta=dict(content=token),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
@@ -88,13 +77,19 @@ def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResp
|
||||
|
||||
|
||||
def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
def it() -> Generator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
|
||||
return it()
|
||||
|
||||
|
||||
async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
async def it() -> AsyncGenerator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
|
||||
return it()
|
||||
|
||||
|
||||
class MyCustomHandler(BaseCallbackHandler):
|
||||
@@ -104,7 +99,10 @@ class MyCustomHandler(BaseCallbackHandler):
|
||||
self.last_token = token
|
||||
|
||||
|
||||
@patch("mistralai.client.MistralClient.chat_stream", new=mock_chat_stream)
|
||||
@patch(
|
||||
"langchain_mistralai.chat_models.ChatMistralAI.completion_with_retry",
|
||||
new=mock_chat_stream,
|
||||
)
|
||||
def test_stream_with_callback() -> None:
|
||||
callback = MyCustomHandler()
|
||||
chat = ChatMistralAI(callbacks=[callback])
|
||||
@@ -112,7 +110,7 @@ def test_stream_with_callback() -> None:
|
||||
assert callback.last_token == token.content
|
||||
|
||||
|
||||
@patch("mistralai.async_client.MistralAsyncClient.chat_stream", new=mock_chat_astream)
|
||||
@patch("langchain_mistralai.chat_models.acompletion_with_retry", new=mock_chat_astream)
|
||||
async def test_astream_with_callback() -> None:
|
||||
callback = MyCustomHandler()
|
||||
chat = ChatMistralAI(callbacks=[callback])
|
||||
|
||||
Reference in New Issue
Block a user