diff --git a/libs/langchain/langchain/callbacks/tracers/root_listeners.py b/libs/langchain/langchain/callbacks/tracers/root_listeners.py
index cb065bf8866..af489bfe9ab 100644
--- a/libs/langchain/langchain/callbacks/tracers/root_listeners.py
+++ b/libs/langchain/langchain/callbacks/tracers/root_listeners.py
@@ -1,20 +1,28 @@
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 from uuid import UUID
 
 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
+from langchain.schema.runnable.config import (
+    RunnableConfig,
+    call_func_with_variable_args,
+)
+
+Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
 
 
 class RootListenersTracer(BaseTracer):
     def __init__(
         self,
         *,
-        on_start: Optional[Callable[[Run], None]],
-        on_end: Optional[Callable[[Run], None]],
-        on_error: Optional[Callable[[Run], None]],
+        config: RunnableConfig,
+        on_start: Optional[Listener],
+        on_end: Optional[Listener],
+        on_error: Optional[Listener],
     ) -> None:
         super().__init__()
 
+        self.config = config
         self._arg_on_start = on_start
         self._arg_on_end = on_end
         self._arg_on_error = on_error
@@ -32,7 +40,7 @@ class RootListenersTracer(BaseTracer):
         self.root_id = run.id
 
         if self._arg_on_start is not None:
-            self._arg_on_start(run)
+            call_func_with_variable_args(self._arg_on_start, run, self.config)
 
     def _on_run_update(self, run: Run) -> None:
         if run.id != self.root_id:
@@ -40,7 +48,7 @@
 
         if run.error is None:
             if self._arg_on_end is not None:
-                self._arg_on_end(run)
+                call_func_with_variable_args(self._arg_on_end, run, self.config)
         else:
             if self._arg_on_error is not None:
-                self._arg_on_error(run)
+                call_func_with_variable_args(self._arg_on_error, run, self.config)
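Reviewer note: a rough, self-contained sketch (not the library source) of the dispatch `call_func_with_variable_args` performs here. The config is only forwarded when the listener's signature declares a `config` parameter, which is what keeps existing one-argument listeners working unchanged:

```python
# Simplified sketch for review purposes only; the real helper lives in
# langchain.schema.runnable.config and handles more cases (e.g. run managers).
import inspect
from typing import Any, Callable


def _dispatch_listener(listener: Callable[..., Any], run: Any, config: dict) -> Any:
    # Forward the config only if the listener declares a `config` parameter;
    # plain single-argument listeners keep their old calling convention.
    if "config" in inspect.signature(listener).parameters:
        return listener(run, config=config)
    return listener(run)
```

```diff
diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py
```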
index 39c99d14637..643bdb1b45d 100644
--- a/libs/langchain/langchain/schema/runnable/base.py
+++ b/libs/langchain/langchain/schema/runnable/base.py
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
         CallbackManagerForChainRun,
     )
     from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
-    from langchain.callbacks.tracers.schemas import Run
+    from langchain.callbacks.tracers.root_listeners import Listener
     from langchain.schema.runnable.fallbacks import (
         RunnableWithFallbacks as RunnableWithFallbacksT,
     )
@@ -591,9 +591,9 @@ class Runnable(Generic[Input, Output], ABC):
     def with_listeners(
         self,
         *,
-        on_start: Optional[Callable[[Run], None]] = None,
-        on_end: Optional[Callable[[Run], None]] = None,
-        on_error: Optional[Callable[[Run], None]] = None,
+        on_start: Optional[Listener] = None,
+        on_end: Optional[Listener] = None,
+        on_error: Optional[Listener] = None,
     ) -> Runnable[Input, Output]:
         """
         Bind lifecycle listeners to a Runnable, returning a new Runnable.
@@ -611,10 +611,13 @@ class Runnable(Generic[Input, Output], ABC):
         return RunnableBinding(
             bound=self,
             config_factories=[
-                lambda: {
+                lambda config: {
                     "callbacks": [
                         RootListenersTracer(
-                            on_start=on_start, on_end=on_end, on_error=on_error
+                            config=config,
+                            on_start=on_start,
+                            on_end=on_end,
+                            on_error=on_error,
                         )
                     ],
                 }
@@ -2391,9 +2394,9 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
     def with_listeners(
         self,
         *,
-        on_start: Optional[Callable[[Run], None]] = None,
-        on_end: Optional[Callable[[Run], None]] = None,
-        on_error: Optional[Callable[[Run], None]] = None,
+        on_start: Optional[Listener] = None,
+        on_end: Optional[Listener] = None,
+        on_error: Optional[Listener] = None,
     ) -> RunnableEach[Input, Output]:
         """
         Bind lifecycle listeners to a Runnable, returning a new Runnable.
@@ -2456,7 +2459,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
 
     config: RunnableConfig = Field(default_factory=dict)
 
-    config_factories: List[Callable[[], RunnableConfig]] = Field(default_factory=list)
+    config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field(
+        default_factory=list
+    )
 
     # Union[Type[Input], BaseModel] + things like List[str]
     custom_input_type: Optional[Any] = None
@@ -2472,7 +2477,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
         bound: Runnable[Input, Output],
         kwargs: Optional[Mapping[str, Any]] = None,
         config: Optional[RunnableConfig] = None,
-        config_factories: Optional[List[Callable[[], RunnableConfig]]] = None,
+        config_factories: Optional[
+            List[Callable[[RunnableConfig], RunnableConfig]]
+        ] = None,
         custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
         custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
         **other_kwargs: Any,
@@ -2570,9 +2577,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
     def with_listeners(
         self,
         *,
-        on_start: Optional[Callable[[Run], None]] = None,
-        on_end: Optional[Callable[[Run], None]] = None,
-        on_error: Optional[Callable[[Run], None]] = None,
+        on_start: Optional[Listener] = None,
+        on_end: Optional[Listener] = None,
+        on_error: Optional[Listener] = None,
     ) -> Runnable[Input, Output]:
         """
         Bind lifecycle listeners to a Runnable, returning a new Runnable.
@@ -2592,10 +2599,13 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
             kwargs=self.kwargs,
             config=self.config,
             config_factories=[
-                lambda: {
+                lambda config: {
                     "callbacks": [
                         RootListenersTracer(
-                            on_start=on_start, on_end=on_end, on_error=on_error
+                            config=config,
+                            on_start=on_start,
+                            on_end=on_end,
+                            on_error=on_error,
                         )
                     ],
                 }
@@ -2629,9 +2639,8 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
         )
 
     def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
-        return merge_configs(
-            self.config, *(f() for f in self.config_factories), *configs
-        )
+        config = merge_configs(self.config, *configs)
+        return merge_configs(config, *(f(config) for f in self.config_factories))
 
     def invoke(
         self,
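For illustration, a minimal usage sketch of what this enables, grounded in the diff and the new tests (the listener bodies and tag values here are hypothetical; assumes the langchain version in this patch). Listeners bound via `with_listeners` may now declare a second `config` parameter and receive the merged per-invocation `RunnableConfig`, because `_merge_configs` now merges the bound config with the caller's configs first and hands that result to each config factory:

```python
from langchain.callbacks.tracers.schemas import Run
from langchain.schema.runnable import RunnableLambda
from langchain.schema.runnable.config import RunnableConfig


def on_start(run: Run) -> None:
    # Legacy one-argument listeners keep working unchanged.
    print("started:", run.name)


def on_end(run: Run, config: RunnableConfig) -> None:
    # New two-argument listeners also receive the per-invocation config.
    print("ended:", run.name, "tags:", config.get("tags"))


chain = RunnableLambda(lambda x: x + 1).with_listeners(
    on_start=on_start, on_end=on_end
)
chain.invoke(1, {"tags": ["demo"]})  # on_end should see tags=['demo']
```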
diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr
index 7c23fabf107..85dacbb7a03 100644
--- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr
+++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr
@@ -1185,6 +1185,119 @@
     Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}], 'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo, bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo, bar', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo, bar'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='CommaSeparatedListOutputParser', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='parser', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': AIMessage(content='foo, bar')}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:3'], execution_order=None, child_execution_order=None, child_runs=[])]),
   ])
 # ---
+# name: test_prompt_with_chat_model_async
+  '''
+  ChatPromptTemplate(input_variables=['question'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a nice assistant.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='{question}'))])
+  | FakeListChatModel(responses=['foo'])
+  '''
+# ---
+# name: test_prompt_with_chat_model_async.1
+  '''
+  {
+    "lc": 1,
+    "type": "constructor",
+    "id": [
+      "langchain",
+      "schema",
+      "runnable",
+      "RunnableSequence"
+    ],
+    "kwargs": {
+      "first": {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+          "langchain",
+          "prompts",
+          "chat",
+          "ChatPromptTemplate"
+        ],
+        "kwargs": {
+          "messages": [
+            {
+              "lc": 1,
+              "type": "constructor",
+              "id": [
+                "langchain",
+                "prompts",
+                "chat",
+                "SystemMessagePromptTemplate"
+              ],
+              "kwargs": {
+                "prompt": {
+                  "lc": 1,
+                  "type": "constructor",
+                  "id": [
+                    "langchain",
+                    "prompts",
+                    "prompt",
+                    "PromptTemplate"
+                  ],
+                  "kwargs": {
+                    "input_variables": [],
+                    "template": "You are a nice assistant.",
+                    "template_format": "f-string",
+                    "partial_variables": {}
+                  }
+                }
+              }
+            },
+            {
+              "lc": 1,
+              "type": "constructor",
+              "id": [
+                "langchain",
+                "prompts",
+                "chat",
+                "HumanMessagePromptTemplate"
+              ],
+              "kwargs": {
+                "prompt": {
+                  "lc": 1,
+                  "type": "constructor",
+                  "id": [
+                    "langchain",
+                    "prompts",
+                    "prompt",
+                    "PromptTemplate"
+                  ],
+                  "kwargs": {
+                    "input_variables": [
+                      "question"
+                    ],
+                    "template": "{question}",
+                    "template_format": "f-string",
+                    "partial_variables": {}
+                  }
+                }
+              }
+            }
+          ],
+          "input_variables": [
+            "question"
+          ]
+        }
+      },
+      "last": {
+        "lc": 1,
+        "type": "not_implemented",
+        "id": [
+          "langchain",
+          "chat_models",
+          "fake",
+          "FakeListChatModel"
+        ],
+        "repr": "FakeListChatModel(responses=['foo'])"
+      }
+    }
+  }
+  '''
+# ---
+# name: test_prompt_with_chat_model_async.2
+  list([
+    Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo'])"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': AIMessage(content='foo')}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]),
+  ])
+# ---
 # name: test_prompt_with_llm
   '''
   {
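(Reviewer note: the `.ambr` entries above are syrupy "amber" snapshots. Each `# name:` block stores the expected value for one `assert ... == snapshot` in the new async test, in assertion order: the chain `repr`, the `dumps(chain, pretty=True)` serialization, and `tracer.runs`. They are regenerated with pytest's `--snapshot-update` flag rather than edited by hand.)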
diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
index 1d2890e4fe8..56af1734cd3 100644
--- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
+++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
@@ -1712,8 +1712,41 @@ def test_with_listeners(mocker: MockerFixture) -> None:
 
 
 @pytest.mark.asyncio
+async def test_with_listeners_async(mocker: MockerFixture) -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    chat = FakeListChatModel(responses=["foo"])
+
+    chain = prompt | chat
+
+    mock_start = mocker.Mock()
+    mock_end = mocker.Mock()
+
+    await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
+        {"question": "Who are you?"}
+    )
+
+    assert mock_start.call_count == 1
+    assert mock_start.call_args[0][0].name == "RunnableSequence"
+    assert mock_end.call_count == 1
+
+    mock_start.reset_mock()
+    mock_end.reset_mock()
+
+    async with atrace_as_chain_group("hello") as manager:
+        await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
+            {"question": "Who are you?"}, {"callbacks": manager}
+        )
+
+    assert mock_start.call_count == 1
+    assert mock_start.call_args[0][0].name == "RunnableSequence"
+    assert mock_end.call_count == 1
+
+
 @freeze_time("2023-01-01")
-async def test_prompt_with_chat_model(
+def test_prompt_with_chat_model(
     mocker: MockerFixture, snapshot: SnapshotAssertion
 ) -> None:
     prompt = (
@@ -1816,6 +1849,114 @@ async def test_prompt_with_chat_model(
     )
 
 
+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_prompt_with_chat_model_async(
+    mocker: MockerFixture, snapshot: SnapshotAssertion
+) -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    chat = FakeListChatModel(responses=["foo"])
+
+    chain = prompt | chat
+
+    assert repr(chain) == snapshot
+    assert isinstance(chain, RunnableSequence)
+    assert chain.first == prompt
+    assert chain.middle == []
+    assert chain.last == chat
+    assert dumps(chain, pretty=True) == snapshot
+
+    # Test invoke
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    chat_spy = mocker.spy(chat.__class__, "ainvoke")
+    tracer = FakeTracer()
+    assert await chain.ainvoke(
+        {"question": "What is your name?"}, dict(callbacks=[tracer])
+    ) == AIMessage(content="foo")
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert chat_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+
+    assert tracer.runs == snapshot
+
+    mocker.stop(prompt_spy)
+    mocker.stop(chat_spy)
+
+    # Test batch
+    prompt_spy = mocker.spy(prompt.__class__, "abatch")
+    chat_spy = mocker.spy(chat.__class__, "abatch")
+    tracer = FakeTracer()
+    assert await chain.abatch(
+        [
+            {"question": "What is your name?"},
+            {"question": "What is your favorite color?"},
+        ],
+        dict(callbacks=[tracer]),
+    ) == [
+        AIMessage(content="foo"),
+        AIMessage(content="foo"),
+    ]
+    assert prompt_spy.call_args.args[1] == [
+        {"question": "What is your name?"},
+        {"question": "What is your favorite color?"},
+    ]
+    assert chat_spy.call_args.args[1] == [
+        ChatPromptValue(
+            messages=[
+                SystemMessage(content="You are a nice assistant."),
+                HumanMessage(content="What is your name?"),
+            ]
+        ),
+        ChatPromptValue(
+            messages=[
+                SystemMessage(content="You are a nice assistant."),
+                HumanMessage(content="What is your favorite color?"),
+            ]
+        ),
+    ]
+    assert (
+        len(
+            [
+                r
+                for r in tracer.runs
+                if r.parent_run_id is None and len(r.child_runs) == 2
+            ]
+        )
+        == 2
+    ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"
+    mocker.stop(prompt_spy)
+    mocker.stop(chat_spy)
+
+    # Test stream
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    chat_spy = mocker.spy(chat.__class__, "astream")
+    tracer = FakeTracer()
+    assert [
+        a
+        async for a in chain.astream(
+            {"question": "What is your name?"}, dict(callbacks=[tracer])
+        )
+    ] == [
+        AIMessageChunk(content="f"),
+        AIMessageChunk(content="o"),
+        AIMessageChunk(content="o"),
+    ]
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert chat_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+
+
 @pytest.mark.asyncio
 @freeze_time("2023-01-01")
 async def test_prompt_with_llm(