From 5b1de2ae93d49258762c7bc87a996d30cbe56d5e Mon Sep 17 00:00:00 2001
From: Yannick Stephan
Date: Mon, 1 Jul 2024 22:53:09 +0200
Subject: [PATCH] mistralai: Fixed streaming in MistralAI with ainvoke and
 callbacks (#22000)

# Fix streaming in MistralAI with ainvoke

- [x] **PR title**
- [x] **PR message**
- [x] **Add tests and docs**:
  1. [x] Added a test for the fixed integration.
  2. [x] An example notebook showing its use. It lives in the
     `docs/docs/integrations` directory.
- [x] **Lint and test**: Ran `make format`, `make lint` and `make test` from
  the root of the package(s) I've modified.

Hello

* I identified an issue in the mistralai package where the streaming callback
  (see `on_llm_new_token`) was not firing when the `streaming` parameter was
  set to `True` and the model was called with `ainvoke`.
* The root cause is that the async generation path ignored the `streaming`
  attribute (I think this was an oversight).
* To resolve the issue, I made the async generation path fall back to the
  `streaming` attribute when `stream` is not explicitly passed.
* Now the callback receives streamed tokens as expected when the `streaming`
  parameter is set to `True`.

## How to reproduce

```python
from langchain_mistralai.chat_models import ChatMistralAI

chain = ChatMistralAI(streaming=True)
# Attach a callback handler, then:
await chain.ainvoke(...)
# Observe on_llm_new_token: the callback now receives tokens as they stream
# in; before the fix, the output arrived in grouped format.
```

Co-authored-by: Erick Friis
---
 libs/partners/mistralai/langchain_mistralai/chat_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index 5892aeeeb08..4952f088d50 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -586,7 +586,7 @@ class ChatMistralAI(BaseChatModel):
         stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        should_stream = stream if stream is not None else False
+        should_stream = stream if stream is not None else self.streaming
         if should_stream:
             stream_iter = self._astream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
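
For reviewers who want to verify the fixed behavior end to end, here is a minimal, self-contained sketch. It is illustrative only: the `TokenLogger` handler name and the prompt are assumptions, and it presumes `langchain_core`'s `BaseCallbackHandler` API and a configured `MISTRAL_API_KEY`.

```python
import asyncio

from langchain_core.callbacks import BaseCallbackHandler
from langchain_mistralai.chat_models import ChatMistralAI


class TokenLogger(BaseCallbackHandler):
    """Illustrative handler: prints each token as it streams in."""

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # With the fix, this fires once per token during ainvoke;
        # before the fix, the output arrived in grouped format.
        print(token, end="", flush=True)


async def main() -> None:
    llm = ChatMistralAI(streaming=True, callbacks=[TokenLogger()])
    await llm.ainvoke("Say hello in one short sentence.")


asyncio.run(main())
```

Running this before and after the patch should make the difference visible: without the fix the callback sees the output grouped, with it the tokens print incrementally.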