diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py
index 459953679bb..8fbdf976a11 100644
--- a/libs/langchain/langchain/agents/agent.py
+++ b/libs/langchain/langchain/agents/agent.py
@@ -371,7 +371,7 @@ class RunnableAgent(BaseSingleActionAgent):
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
+        """Based on past history and current inputs, decide what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
@@ -383,8 +383,19 @@ class RunnableAgent(BaseSingleActionAgent):
             Action specifying what tool to use.
         """
         inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
-        output = self.runnable.invoke(inputs, config={"callbacks": callbacks})
-        return output
+        # Use streaming to make sure that the underlying LLM is invoked in a streaming
+        # fashion to make it possible to get access to the individual LLM tokens
+        # when using stream_log with the Agent Executor.
+        # Because the response from the plan is not a generator, we need to
+        # accumulate the output into final output and return that.
+        final_output: Any = None
+        for chunk in self.runnable.stream(inputs, config={"callbacks": callbacks}):
+            if final_output is None:
+                final_output = chunk
+            else:
+                final_output += chunk
+
+        return final_output
 
     async def aplan(
         self,
@@ -395,20 +406,32 @@ class RunnableAgent(BaseSingleActionAgent):
         AgentAction,
         AgentFinish,
     ]:
-        """Given input, decided what to do.
+        """Based on past history and current inputs, decide what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
                 along with observations
             callbacks: Callbacks to run.
-            **kwargs: User inputs.
+            **kwargs: User inputs
 
         Returns:
             Action specifying what tool to use.
         """
         inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
-        output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
-        return output
+        final_output: Any = None
+        # Use streaming to make sure that the underlying LLM is invoked in a streaming
+        # fashion to make it possible to get access to the individual LLM tokens
+        # when using stream_log with the Agent Executor.
+        # Because the response from the plan is not a generator, we need to
+        # accumulate the output into final output and return that.
+        async for chunk in self.runnable.astream(
+            inputs, config={"callbacks": callbacks}
+        ):
+            if final_output is None:
+                final_output = chunk
+            else:
+                final_output += chunk
+        return final_output
 
 
 class RunnableMultiActionAgent(BaseMultiActionAgent):
@@ -447,7 +470,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
         List[AgentAction],
         AgentFinish,
     ]:
-        """Given input, decided what to do.
+        """Based on past history and current inputs, decide what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
@@ -459,8 +482,19 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
             Action specifying what tool to use.
         """
         inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
-        output = self.runnable.invoke(inputs, config={"callbacks": callbacks})
-        return output
+        # Use streaming to make sure that the underlying LLM is invoked in a streaming
+        # fashion to make it possible to get access to the individual LLM tokens
+        # when using stream_log with the Agent Executor.
+        # Because the response from the plan is not a generator, we need to
+        # accumulate the output into final output and return that.
+        final_output: Any = None
+        for chunk in self.runnable.stream(inputs, config={"callbacks": callbacks}):
+            if final_output is None:
+                final_output = chunk
+            else:
+                final_output += chunk
+
+        return final_output
 
     async def aplan(
         self,
@@ -471,7 +505,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
         List[AgentAction],
         AgentFinish,
     ]:
-        """Given input, decided what to do.
+        """Based on past history and current inputs, decide what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
@@ -483,8 +517,21 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
             Action specifying what tool to use.
         """
         inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
-        output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
-        return output
+        # Use streaming to make sure that the underlying LLM is invoked in a streaming
+        # fashion to make it possible to get access to the individual LLM tokens
+        # when using stream_log with the Agent Executor.
+        # Because the response from the plan is not a generator, we need to
+        # accumulate the output into final output and return that.
+        final_output: Any = None
+        async for chunk in self.runnable.astream(
+            inputs, config={"callbacks": callbacks}
+        ):
+            if final_output is None:
+                final_output = chunk
+            else:
+                final_output += chunk
+
+        return final_output
 
 
 @deprecated(
diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py
index ca7fa9b1b55..7baee92e072 100644
--- a/libs/langchain/tests/unit_tests/agents/test_agent.py
+++ b/libs/langchain/tests/unit_tests/agents/test_agent.py
@@ -1,16 +1,37 @@
 """Unit tests for agents."""
+import json
+from itertools import cycle
+from typing import Any, Dict, List, Optional, Union, cast
 
-from typing import Any, Dict, List, Optional
-
-from langchain_core.agents import AgentAction, AgentStep
+from langchain_core.agents import (
+    AgentAction,
+    AgentActionMessageLog,
+    AgentFinish,
+    AgentStep,
+)
 from langchain_core.language_models.llms import LLM
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    FunctionMessage,
+    HumanMessage,
+)
+from langchain_core.prompts import MessagesPlaceholder
 from langchain_core.runnables.utils import add
 from langchain_core.tools import Tool
+from langchain_core.tracers import RunLog, RunLogPatch
 
-from langchain.agents import AgentExecutor, AgentType, initialize_agent
+from langchain.agents import (
+    AgentExecutor,
+    AgentType,
+    create_openai_functions_agent,
+    initialize_agent,
+)
 from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.prompts import ChatPromptTemplate
+from langchain.tools import tool
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
 
 
 class FakeListLLM(LLM):
@@ -414,3 +435,356 @@ def test_agent_invalid_tool() -> None:
     resp = agent("when was langchain made")
     resp["intermediate_steps"][0][1] == "Foo is not a valid tool, try one of [Search]."
+
+
+async def test_runnable_agent() -> None:
+    """Simple test to verify that an agent built with LCEL works."""
+
+    # The fake model keeps responding with the same "hello world!" message.
+    infinite_cycle = cycle([AIMessage(content="hello world!")])
+    # When streaming, GenericFakeChatModel breaks the AIMessage into chunks based on spaces.
+    model = GenericFakeChatModel(messages=infinite_cycle)
+
+    template = ChatPromptTemplate.from_messages(
+        [("system", "You are Cat Agent 007"), ("human", "{question}")]
+    )
+
+    def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
+        """A parser."""
+        return AgentFinish(return_values={"foo": "meow"}, log="hard-coded-message")
+
+    agent = template | model | fake_parse
+    executor = AgentExecutor(agent=agent, tools=[])
+
+    # Invoke
+    result = executor.invoke({"question": "hello"})
+    assert result == {"foo": "meow", "question": "hello"}
+
+    # ainvoke
+    result = await executor.ainvoke({"question": "hello"})
+    assert result == {"foo": "meow", "question": "hello"}
+
+    # Batch
+    result = executor.batch(  # type: ignore[assignment]
+        [{"question": "hello"}, {"question": "hello"}]
+    )
+    assert result == [
+        {"foo": "meow", "question": "hello"},
+        {"foo": "meow", "question": "hello"},
+    ]
+
+    # abatch
+    result = await executor.abatch(  # type: ignore[assignment]
+        [{"question": "hello"}, {"question": "hello"}]
+    )
+    assert result == [
+        {"foo": "meow", "question": "hello"},
+        {"foo": "meow", "question": "hello"},
+    ]
+
+    # Stream
+    results = list(executor.stream({"question": "hello"}))
+    assert results == [
+        {"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]}
+    ]
+
+    # astream
+    results = [r async for r in executor.astream({"question": "hello"})]
+    assert results == [
+        {
+            "foo": "meow",
+            "messages": [
+                AIMessage(content="hard-coded-message"),
+            ],
+        }
+    ]
+
+    # stream log
+    results: List[RunLogPatch] = [  # type: ignore[no-redef]
+        r async for r in executor.astream_log({"question": "hello"})
+    ]
+    # Let's stream just the llm tokens.
+ messages = [] + for log_record in results: + for op in log_record.ops: # type: ignore[attr-defined] + if op["op"] == "add" and isinstance(op["value"], AIMessageChunk): + messages.append(op["value"]) + + assert messages != [] + + # Aggregate state + run_log = None + + for result in results: + if run_log is None: + run_log = result + else: + # `+` is defined for RunLogPatch + run_log = run_log + result # type: ignore[union-attr] + + assert isinstance(run_log, RunLog) + + assert run_log.state["final_output"] == { + "foo": "meow", + "messages": [AIMessage(content="hard-coded-message")], + } + + +async def test_runnable_agent_with_function_calls() -> None: + """Test agent with intermediate agent actions.""" + # Will alternate between responding with hello and goodbye + infinite_cycle = cycle( + [AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")] + ) + model = GenericFakeChatModel(messages=infinite_cycle) + + template = ChatPromptTemplate.from_messages( + [("system", "You are Cat Agent 007"), ("human", "{question}")] + ) + + parser_responses = cycle( + [ + AgentAction( + tool="find_pet", + tool_input={ + "pet": "cat", + }, + log="find_pet()", + ), + AgentFinish( + return_values={"foo": "meow"}, + log="hard-coded-message", + ), + ], + ) + + def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: + """A parser.""" + return cast(Union[AgentFinish, AgentAction], next(parser_responses)) + + @tool + def find_pet(pet: str) -> str: + """Find the given pet.""" + if pet != "cat": + raise ValueError("Only cats allowed") + return "Spying from under the bed." + + agent = template | model | fake_parse + executor = AgentExecutor(agent=agent, tools=[find_pet]) + + # Invoke + result = executor.invoke({"question": "hello"}) + assert result == {"foo": "meow", "question": "hello"} + + # ainvoke + result = await executor.ainvoke({"question": "hello"}) + assert result == {"foo": "meow", "question": "hello"} + + # astream + results = [r async for r in executor.astream({"question": "hello"})] + assert results == [ + { + "actions": [ + AgentAction( + tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()" + ) + ], + "messages": [AIMessage(content="find_pet()")], + }, + { + "messages": [HumanMessage(content="Spying from under the bed.")], + "steps": [ + AgentStep( + action=AgentAction( + tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()" + ), + observation="Spying from under the bed.", + ) + ], + }, + {"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]}, + ] + + # astream log + + messages = [] + async for patch in executor.astream_log({"question": "hello"}): + for op in patch.ops: + if op["op"] != "add": + continue + + value = op["value"] + + if not isinstance(value, AIMessageChunk): + continue + + if value.content == "": # Then it's a function invocation message + continue + + messages.append(value.content) + + assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"] + + +def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage: + """Create an AIMessage that represents a function invocation. + + Args: + name: Name of the function to invoke. + kwargs: Keyword arguments to pass to the function. + + Returns: + AIMessage that represents a request to invoke a function. 
+ """ + return AIMessage( + content="", + additional_kwargs={ + "function_call": { + "name": name, + "arguments": json.dumps(kwargs), + } + }, + ) + + +async def test_openai_agent_with_streaming() -> None: + """Test openai agent with streaming.""" + infinite_cycle = cycle( + [ + _make_func_invocation("find_pet", pet="cat"), + AIMessage(content="The cat is spying from under the bed."), + ] + ) + + model = GenericFakeChatModel(messages=infinite_cycle) + + @tool + def find_pet(pet: str) -> str: + """Find the given pet.""" + if pet != "cat": + raise ValueError("Only cats allowed") + return "Spying from under the bed." + + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful AI bot. Your name is kitty power meow."), + ("human", "{question}"), + MessagesPlaceholder( + variable_name="agent_scratchpad", + ), + ] + ) + + # type error due to base tool type below -- would need to be adjusted on tool + # decorator. + agent = create_openai_functions_agent( + model, + [find_pet], # type: ignore[list-item] + template, + ) + executor = AgentExecutor(agent=agent, tools=[find_pet]) + + # Invoke + result = executor.invoke({"question": "hello"}) + assert result == { + "output": "The cat is spying from under the bed.", + "question": "hello", + } + + # astream + chunks = [chunk async for chunk in executor.astream({"question": "hello"})] + assert chunks == [ + { + "actions": [ + AgentActionMessageLog( + tool="find_pet", + tool_input={"pet": "cat"}, + log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", + message_log=[ + AIMessageChunk( + content="", + additional_kwargs={ + "function_call": { + "name": "find_pet", + "arguments": '{"pet": "cat"}', + } + }, + ) + ], + ) + ], + "messages": [ + AIMessageChunk( + content="", + additional_kwargs={ + "function_call": { + "name": "find_pet", + "arguments": '{"pet": "cat"}', + } + }, + ) + ], + }, + { + "messages": [ + FunctionMessage(content="Spying from under the bed.", name="find_pet") + ], + "steps": [ + AgentStep( + action=AgentActionMessageLog( + tool="find_pet", + tool_input={"pet": "cat"}, + log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", + message_log=[ + AIMessageChunk( + content="", + additional_kwargs={ + "function_call": { + "name": "find_pet", + "arguments": '{"pet": "cat"}', + } + }, + ) + ], + ), + observation="Spying from under the bed.", + ) + ], + }, + { + "messages": [AIMessage(content="The cat is spying from under the bed.")], + "output": "The cat is spying from under the bed.", + }, + ] + # + # # astream_log + log_patches = [ + log_patch async for log_patch in executor.astream_log({"question": "hello"}) + ] + + messages = [] + + for log_patch in log_patches: + for op in log_patch.ops: + if op["op"] == "add" and isinstance(op["value"], AIMessageChunk): + value = op["value"] + if value.content: # Filter out function call messages + messages.append(value.content) + + assert messages == [ + "The", + " ", + "cat", + " ", + "is", + " ", + "spying", + " ", + "from", + " ", + "under", + " ", + "the", + " ", + "bed.", + ] diff --git a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py index c709473f4df..3a746720d82 100644 --- a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py @@ -1,9 +1,15 @@ """Fake Chat Model wrapper for testing purposes.""" -from typing import Any, Dict, List, Optional +import re +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast -from 
langchain_core.language_models.chat_models import SimpleChatModel -from langchain_core.messages import AIMessage, BaseMessage -from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import run_in_executor from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -42,3 +48,151 @@ class FakeChatModel(SimpleChatModel): @property def _identifying_params(self) -> Dict[str, Any]: return {"key": "fake"} + + +class GenericFakeChatModel(BaseChatModel): + """A generic fake chat model that can be used to test the chat model interface. + + * Chat model should be usable in both sync and async tests + * Invokes on_llm_new_token to allow for testing of callback related code for new + tokens. + * Includes logic to break messages into message chunk to facilitate testing of + streaming. + """ + + messages: Iterator[AIMessage] + """Get an iterator over messages. + + This can be expanded to accept other types like Callables / dicts / strings + to make the interface more generic if needed. + + Note: if you want to pass a list, you can use `iter` to convert it to an iterator. + + Please note that streaming is not implemented yet. We should try to implement it + in the future by delegating to invoke and then breaking the resulting output + into message chunks. + """ + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + message = next(self.messages) + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """Stream the output of the model.""" + chat_result = self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + if not isinstance(chat_result, ChatResult): + raise ValueError( + f"Expected generate to return a ChatResult, " + f"but got {type(chat_result)} instead." + ) + + message = chat_result.generations[0].message + + if not isinstance(message, AIMessage): + raise ValueError( + f"Expected invoke to return an AIMessage, " + f"but got {type(message)} instead." + ) + + content = message.content + + if content: + # Use a regular expression to split on whitespace with a capture group + # so that we can preserve the whitespace in the output. 
+ assert isinstance(content, str) + content_chunks = cast(List[str], re.split(r"(\s)", content)) + + for token in content_chunks: + chunk = ChatGenerationChunk(message=AIMessageChunk(content=token)) + yield chunk + if run_manager: + run_manager.on_llm_new_token(token, chunk=chunk) + + if message.additional_kwargs: + for key, value in message.additional_kwargs.items(): + # We should further break down the additional kwargs into chunks + # Special case for function call + if key == "function_call": + for fkey, fvalue in value.items(): + if isinstance(fvalue, str): + # Break function call by `,` + fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue)) + for fvalue_chunk in fvalue_chunks: + chunk = ChatGenerationChunk( + message=AIMessageChunk( + content="", + additional_kwargs={ + "function_call": {fkey: fvalue_chunk} + }, + ) + ) + yield chunk + if run_manager: + run_manager.on_llm_new_token( + "", + chunk=chunk, # No token for function call + ) + else: + chunk = ChatGenerationChunk( + message=AIMessageChunk( + content="", + additional_kwargs={"function_call": {fkey: fvalue}}, + ) + ) + yield chunk + if run_manager: + run_manager.on_llm_new_token( + "", + chunk=chunk, # No token for function call + ) + else: + chunk = ChatGenerationChunk( + message=AIMessageChunk( + content="", additional_kwargs={key: value} + ) + ) + yield chunk + if run_manager: + run_manager.on_llm_new_token( + "", + chunk=chunk, # No token for function call + ) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + """Stream the output of the model.""" + result = await run_in_executor( + None, + self._stream, + messages, + stop=stop, + run_manager=run_manager.get_sync() if run_manager else None, + **kwargs, + ) + for chunk in result: + yield chunk + + @property + def _llm_type(self) -> str: + return "generic-fake-chat-model" diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py new file mode 100644 index 00000000000..db658736e89 --- /dev/null +++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py @@ -0,0 +1,185 @@ +"""Tests for verifying that testing utility code works as expected.""" +from itertools import cycle +from typing import Any, Dict, List, Optional, Union +from uuid import UUID + +from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage +from langchain_core.outputs import ChatGenerationChunk, GenerationChunk + +from langchain.callbacks.base import AsyncCallbackHandler +from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel + + +def test_generic_fake_chat_model_invoke() -> None: + # Will alternate between responding with hello and goodbye + infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) + model = GenericFakeChatModel(messages=infinite_cycle) + response = model.invoke("meow") + assert response == AIMessage(content="hello") + response = model.invoke("kitty") + assert response == AIMessage(content="goodbye") + response = model.invoke("meow") + assert response == AIMessage(content="hello") + + +async def test_generic_fake_chat_model_ainvoke() -> None: + # Will alternate between responding with hello and goodbye + infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) + model = GenericFakeChatModel(messages=infinite_cycle) + response = await 
model.ainvoke("meow") + assert response == AIMessage(content="hello") + response = await model.ainvoke("kitty") + assert response == AIMessage(content="goodbye") + response = await model.ainvoke("meow") + assert response == AIMessage(content="hello") + + +async def test_generic_fake_chat_model_stream() -> None: + """Test streaming.""" + infinite_cycle = cycle( + [ + AIMessage(content="hello goodbye"), + ] + ) + model = GenericFakeChatModel(messages=infinite_cycle) + chunks = [chunk async for chunk in model.astream("meow")] + assert chunks == [ + AIMessageChunk(content="hello"), + AIMessageChunk(content=" "), + AIMessageChunk(content="goodbye"), + ] + + chunks = [chunk for chunk in model.stream("meow")] + assert chunks == [ + AIMessageChunk(content="hello"), + AIMessageChunk(content=" "), + AIMessageChunk(content="goodbye"), + ] + + # Test streaming of additional kwargs. + # Relying on insertion order of the additional kwargs dict + message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24}) + model = GenericFakeChatModel(messages=cycle([message])) + chunks = [chunk async for chunk in model.astream("meow")] + assert chunks == [ + AIMessageChunk(content="", additional_kwargs={"foo": 42}), + AIMessageChunk(content="", additional_kwargs={"bar": 24}), + ] + + message = AIMessage( + content="", + additional_kwargs={ + "function_call": { + "name": "move_file", + "arguments": '{\n "source_path": "foo",\n "' + 'destination_path": "bar"\n}', + } + }, + ) + model = GenericFakeChatModel(messages=cycle([message])) + chunks = [chunk async for chunk in model.astream("meow")] + + assert chunks == [ + AIMessageChunk( + content="", additional_kwargs={"function_call": {"name": "move_file"}} + ), + AIMessageChunk( + content="", + additional_kwargs={ + "function_call": {"arguments": '{\n "source_path": "foo"'} + }, + ), + AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": ","}} + ), + AIMessageChunk( + content="", + additional_kwargs={ + "function_call": {"arguments": '\n "destination_path": "bar"\n}'} + }, + ), + ] + + accumulate_chunks = None + for chunk in chunks: + if accumulate_chunks is None: + accumulate_chunks = chunk + else: + accumulate_chunks += chunk + + assert accumulate_chunks == AIMessageChunk( + content="", + additional_kwargs={ + "function_call": { + "name": "move_file", + "arguments": '{\n "source_path": "foo",\n "' + 'destination_path": "bar"\n}', + } + }, + ) + + +async def test_generic_fake_chat_model_astream_log() -> None: + """Test streaming.""" + infinite_cycle = cycle([AIMessage(content="hello goodbye")]) + model = GenericFakeChatModel(messages=infinite_cycle) + log_patches = [ + log_patch async for log_patch in model.astream_log("meow", diff=False) + ] + final = log_patches[-1] + assert final.state["streamed_output"] == [ + AIMessageChunk(content="hello"), + AIMessageChunk(content=" "), + AIMessageChunk(content="goodbye"), + ] + + +async def test_callback_handlers() -> None: + """Verify that model is implemented correctly with handlers working.""" + + class MyCustomAsyncHandler(AsyncCallbackHandler): + def __init__(self, store: List[str]) -> None: + self.store = store + + async def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + # Do nothing + # Required to implement since this is an abstract method + pass + + async def 
on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + self.store.append(token) + + infinite_cycle = cycle( + [ + AIMessage(content="hello goodbye"), + ] + ) + model = GenericFakeChatModel(messages=infinite_cycle) + tokens: List[str] = [] + # New model + results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) + assert results == [ + AIMessageChunk(content="hello"), + AIMessageChunk(content=" "), + AIMessageChunk(content="goodbye"), + ] + assert tokens == ["hello", " ", "goodbye"]
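
Note (illustrative sketch, not part of the patch): the new `plan`/`aplan` implementations above rely on streamed chunks supporting `+`, so the loop can fold a stream back into a single `AgentAction`/`AgentFinish` while `astream_log` still surfaces individual LLM tokens. A minimal standalone sketch of that accumulation pattern, using `AIMessageChunk` as the chunk type (the same behaviour the GenericFakeChatModel tests assert against):

    # Sketch only: the accumulate-streamed-chunks pattern used by RunnableAgent.plan.
    # Assumes AIMessageChunk's `+` concatenates content, as the tests above rely on.
    from langchain_core.messages import AIMessageChunk

    streamed_chunks = [
        AIMessageChunk(content="hello"),
        AIMessageChunk(content=" "),
        AIMessageChunk(content="goodbye"),
    ]

    final_output = None
    for chunk in streamed_chunks:
        if final_output is None:
            final_output = chunk  # the first chunk seeds the accumulator
        else:
            final_output += chunk  # later chunks are merged in with `+=`

    assert final_output == AIMessageChunk(content="hello goodbye")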