From 42341bc78739d8fb9172aa0bef6222eb62d4946e Mon Sep 17 00:00:00 2001 From: William De Vena <60664495+williamdevena@users.noreply.github.com> Date: Fri, 1 Mar 2024 20:46:18 +0100 Subject: [PATCH] infra: fake model invoke callback prior to yielding token (#18286) ## PR title core[patch]: Invoke callback prior to yielding ## PR message Description: Invoke on_llm_new_token callback prior to yielding token in _stream and _astream methods. Issue: https://github.com/langchain-ai/langchain/issues/16913 Dependencies: None Twitter handle: None --- libs/core/tests/unit_tests/fake/chat_model.py | 8 +++--- .../runnables/test_runnable_events.py | 26 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/libs/core/tests/unit_tests/fake/chat_model.py b/libs/core/tests/unit_tests/fake/chat_model.py index d0135d7f116..ec03c431e32 100644 --- a/libs/core/tests/unit_tests/fake/chat_model.py +++ b/libs/core/tests/unit_tests/fake/chat_model.py @@ -225,9 +225,9 @@ class GenericFakeChatModel(BaseChatModel): for token in content_chunks: chunk = ChatGenerationChunk(message=AIMessageChunk(content=token)) - yield chunk if run_manager: run_manager.on_llm_new_token(token, chunk=chunk) + yield chunk if message.additional_kwargs: for key, value in message.additional_kwargs.items(): @@ -247,12 +247,12 @@ class GenericFakeChatModel(BaseChatModel): }, ) ) - yield chunk if run_manager: run_manager.on_llm_new_token( "", chunk=chunk, # No token for function call ) + yield chunk else: chunk = ChatGenerationChunk( message=AIMessageChunk( @@ -260,24 +260,24 @@ class GenericFakeChatModel(BaseChatModel): additional_kwargs={"function_call": {fkey: fvalue}}, ) ) - yield chunk if run_manager: run_manager.on_llm_new_token( "", chunk=chunk, # No token for function call ) + yield chunk else: chunk = ChatGenerationChunk( message=AIMessageChunk( content="", additional_kwargs={key: value} ) ) - yield chunk if run_manager: run_manager.on_llm_new_token( "", chunk=chunk, # No token for function call ) + yield chunk async def _astream( self, diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events.py index a31e15e7ba4..54ccbbad058 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events.py @@ -398,22 +398,14 @@ async def test_event_stream_with_simple_chain() -> None: }, { "data": {"chunk": AIMessageChunk(content="hello")}, - "event": "on_chain_stream", - "metadata": {"foo": "bar"}, - "name": "my_chain", + "event": "on_chat_model_stream", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", "run_id": "", - "tags": ["my_chain"], + "tags": ["my_chain", "my_model", "seq:step:2"], }, { "data": {"chunk": AIMessageChunk(content="hello")}, - "event": "on_chat_model_stream", - "metadata": {"a": "b", "foo": "bar"}, - "name": "my_model", - "run_id": "", - "tags": ["my_chain", "my_model", "seq:step:2"], - }, - { - "data": {"chunk": AIMessageChunk(content=" ")}, "event": "on_chain_stream", "metadata": {"foo": "bar"}, "name": "my_chain", @@ -429,7 +421,7 @@ async def test_event_stream_with_simple_chain() -> None: "tags": ["my_chain", "my_model", "seq:step:2"], }, { - "data": {"chunk": AIMessageChunk(content="world!")}, + "data": {"chunk": AIMessageChunk(content=" ")}, "event": "on_chain_stream", "metadata": {"foo": "bar"}, "name": "my_chain", @@ -444,6 +436,14 @@ async def test_event_stream_with_simple_chain() -> None: "run_id": "", "tags": ["my_chain", "my_model", "seq:step:2"], }, + { + "data": {"chunk": AIMessageChunk(content="world!")}, + "event": "on_chain_stream", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, { "data": { "input": {