langchain_mistralai[patch]: Invoke callback prior to yielding token (#16986)
- **Description:** Invoke the callback prior to yielding the token in the `stream` and `astream` methods of `ChatMistralAI`.
- **Issue:** https://github.com/langchain-ai/langchain/issues/16913
This commit is contained in:
- parent: 267e71606e
- commit: 0826d87ecd
In `ChatMistralAI._stream`, the chunk is now yielded only after `on_llm_new_token` has been called:

```diff
@@ -317,9 +317,9 @@ class ChatMistralAI(BaseChatModel):
                 continue
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
             if run_manager:
                 run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
+            yield ChatGenerationChunk(message=chunk)
 
     async def _astream(
         self,
```
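The practical effect is that any registered handler has already received a token by the time the caller sees the corresponding chunk. A minimal sketch of that contract follows; it is illustrative rather than part of the commit, the handler class is hypothetical, and running it against the real API requires a Mistral API key.

```python
from typing import Any

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_mistralai import ChatMistralAI


class LastTokenHandler(BaseCallbackHandler):
    """Records the most recent token passed to on_llm_new_token."""

    last_token: str = ""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.last_token = token


handler = LastTokenHandler()
chat = ChatMistralAI(callbacks=[handler])

# Because the callback now fires before the yield, the handler is already
# up to date for every chunk the consumer receives.
for chunk in chat.stream("Hello"):
    assert handler.last_token == chunk.content
```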
The same reordering is applied in `_astream`, where the callback is awaited before the yield:

```diff
@@ -342,9 +342,9 @@ class ChatMistralAI(BaseChatModel):
                 continue
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
             if run_manager:
                 await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
+            yield ChatGenerationChunk(message=chunk)
 
     async def _agenerate(
         self,
```
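On the async path the callback is awaited before the chunk is yielded, so async handlers get the same guarantee. A minimal sketch under the same caveats (hypothetical handler name, a real API key needed to actually run it):

```python
import asyncio
from typing import Any

from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_mistralai import ChatMistralAI


class AsyncLastTokenHandler(AsyncCallbackHandler):
    """Async handler that records the most recent streamed token."""

    last_token: str = ""

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.last_token = token


async def main() -> None:
    handler = AsyncLastTokenHandler()
    chat = ChatMistralAI(callbacks=[handler])
    # _astream awaits the callback before yielding, so the handler has seen
    # the token by the time the chunk reaches this loop.
    async for chunk in chat.astream("Hello"):
        assert handler.last_token == chunk.content


asyncio.run(main())
```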
The unit tests for the MistralAI chat wrapper gain imports for the streaming mocks and the callback handler:

```diff
@@ -1,7 +1,10 @@
 """Test MistralAI Chat API wrapper."""
 import os
+from typing import Any, AsyncGenerator, Generator
+from unittest.mock import patch
 
 import pytest
+from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
```
Additional imports of the Mistral streaming response types used by the mocks:

```diff
@@ -12,6 +15,11 @@ from langchain_core.messages import (
 
 # TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
 from mistralai.models.chat_completion import (  # type: ignore[import]
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    DeltaMessage,
+)
+from mistralai.models.chat_completion import (
     ChatMessage as MistralChatMessage,
 )
 
```
And the new streaming mocks, callback handler, and callback-ordering tests:

```diff
@@ -63,3 +71,50 @@ def test_convert_message_to_mistral_chat_message(
 ) -> None:
     result = _convert_message_to_mistral_chat_message(message)
     assert result == expected
+
+
+def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResponse:
+    return ChatCompletionStreamResponse(
+        id="abc123",
+        model="fake_model",
+        choices=[
+            ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=DeltaMessage(content=token),
+                finish_reason=None,
+            )
+        ],
+    )
+
+
+def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
+    for token in ["Hello", " how", " can", " I", " help", "?"]:
+        yield _make_completion_response_from_token(token)
+
+
+async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
+    for token in ["Hello", " how", " can", " I", " help", "?"]:
+        yield _make_completion_response_from_token(token)
+
+
+class MyCustomHandler(BaseCallbackHandler):
+    last_token: str = ""
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        self.last_token = token
+
+
+@patch("mistralai.client.MistralClient.chat_stream", new=mock_chat_stream)
+def test_stream_with_callback() -> None:
+    callback = MyCustomHandler()
+    chat = ChatMistralAI(callbacks=[callback])
+    for token in chat.stream("Hello"):
+        assert callback.last_token == token.content
+
+
+@patch("mistralai.async_client.MistralAsyncClient.chat_stream", new=mock_chat_astream)
+async def test_astream_with_callback() -> None:
+    callback = MyCustomHandler()
+    chat = ChatMistralAI(callbacks=[callback])
+    async for token in chat.astream("Hello"):
+        assert callback.last_token == token.content
```
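These tests replace the Mistral SDK's synchronous and asynchronous `chat_stream` with local generators, so they verify the callback ordering without any network calls. The same pattern could be extended to assert the full token sequence rather than only the last token; the sketch below is a hypothetical addition, placed in the same test module so it can reuse `mock_chat_stream` from the diff above.

```python
from typing import Any, List
from unittest.mock import patch

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_mistralai import ChatMistralAI


class CollectingHandler(BaseCallbackHandler):
    """Accumulates every token passed to on_llm_new_token, in order."""

    def __init__(self) -> None:
        self.tokens: List[str] = []

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.tokens.append(token)


@patch("mistralai.client.MistralClient.chat_stream", new=mock_chat_stream)
def test_stream_callback_sees_every_token() -> None:
    callback = CollectingHandler()
    chat = ChatMistralAI(callbacks=[callback])
    streamed = [chunk.content for chunk in chat.stream("Hello")]
    # Every chunk handed back to the caller was first reported to the callback.
    assert callback.tokens == streamed
```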