From b28be5d4071af0aa9ece7df0a811008a44ed64ca Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Wed, 8 May 2024 22:20:51 -0700 Subject: [PATCH] Pass through Run ID Explicitly (#21469) --- libs/langchain/langchain/agents/agent.py | 25 ++++++++------ .../langchain/agents/agent_iterator.py | 6 ++++ libs/langchain/langchain/chains/base.py | 5 +++ .../unit_tests/agents/test_agent_iterator.py | 23 +++++++++++++ .../tests/unit_tests/chains/test_base.py | 34 +++++++++++++++++++ 5 files changed, 82 insertions(+), 11 deletions(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 54e5ce73215..85348db3059 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -1,4 +1,5 @@ """Chain that takes in an input and produces an action and action input.""" + from __future__ import annotations import asyncio @@ -346,11 +347,11 @@ class RunnableAgent(BaseSingleActionAgent): input_keys_arg: List[str] = [] return_keys_arg: List[str] = [] stream_runnable: bool = True - """Whether to stream from the runnable or not. + """Whether to stream from the runnable or not. - If True then underlying LLM is invoked in a streaming fashion to make it possible - to get access to the individual LLM tokens when using stream_log with the Agent - Executor. If False then LLM is invoked in a non-streaming fashion and + If True then underlying LLM is invoked in a streaming fashion to make it possible + to get access to the individual LLM tokens when using stream_log with the Agent + Executor. If False then LLM is invoked in a non-streaming fashion and individual LLM tokens will not be available in stream_log. """ @@ -455,11 +456,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent): input_keys_arg: List[str] = [] return_keys_arg: List[str] = [] stream_runnable: bool = True - """Whether to stream from the runnable or not. - - If True then underlying LLM is invoked in a streaming fashion to make it possible - to get access to the individual LLM tokens when using stream_log with the Agent - Executor. If False then LLM is invoked in a non-streaming fashion and + """Whether to stream from the runnable or not. + + If True then underlying LLM is invoked in a streaming fashion to make it possible + to get access to the individual LLM tokens when using stream_log with the Agent + Executor. If False then LLM is invoked in a non-streaming fashion and individual LLM tokens will not be available in stream_log. """ @@ -926,7 +927,7 @@ class AgentExecutor(Chain): max_iterations: Optional[int] = 15 """The maximum number of steps to take before ending the execution loop. - + Setting to 'None' could lead to an infinite loop.""" max_execution_time: Optional[float] = None """The maximum amount of wall clock time to spend in the execution @@ -938,7 +939,7 @@ class AgentExecutor(Chain): `"force"` returns a string saying that it stopped because it met a time or iteration limit. - + `"generate"` calls the agent's LLM Chain one final time to generate a final answer based on the previous steps. """ @@ -1565,6 +1566,7 @@ class AgentExecutor(Chain): tags=config.get("tags"), metadata=config.get("metadata"), run_name=config.get("run_name"), + run_id=config.get("run_id"), yield_actions=True, **kwargs, ) @@ -1586,6 +1588,7 @@ class AgentExecutor(Chain): tags=config.get("tags"), metadata=config.get("metadata"), run_name=config.get("run_name"), + run_id=config.get("run_id"), yield_actions=True, **kwargs, ) diff --git a/libs/langchain/langchain/agents/agent_iterator.py b/libs/langchain/langchain/agents/agent_iterator.py index 12c995f2e97..67a4def60a2 100644 --- a/libs/langchain/langchain/agents/agent_iterator.py +++ b/libs/langchain/langchain/agents/agent_iterator.py @@ -14,6 +14,7 @@ from typing import ( Tuple, Union, ) +from uuid import UUID from langchain_core.agents import ( AgentAction, @@ -54,6 +55,7 @@ class AgentExecutorIterator: tags: Optional[list[str]] = None, metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, + run_id: Optional[UUID] = None, include_run_info: bool = False, yield_actions: bool = False, ): @@ -67,6 +69,7 @@ class AgentExecutorIterator: self.tags = tags self.metadata = metadata self.run_name = run_name + self.run_id = run_id self.include_run_info = include_run_info self.yield_actions = yield_actions self.reset() @@ -76,6 +79,7 @@ class AgentExecutorIterator: tags: Optional[list[str]] metadata: Optional[Dict[str, Any]] run_name: Optional[str] + run_id: Optional[UUID] include_run_info: bool yield_actions: bool @@ -162,6 +166,7 @@ class AgentExecutorIterator: run_manager = callback_manager.on_chain_start( dumpd(self.agent_executor), self.inputs, + self.run_id, name=self.run_name, ) try: @@ -227,6 +232,7 @@ class AgentExecutorIterator: run_manager = await callback_manager.on_chain_start( dumpd(self.agent_executor), self.inputs, + self.run_id, name=self.run_name, ) try: diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 04b73d2744c..28b39a2293a 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -1,4 +1,5 @@ """Base interface that all chains should implement.""" + import inspect import json import logging @@ -127,6 +128,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): tags = config.get("tags") metadata = config.get("metadata") run_name = config.get("run_name") or self.get_name() + run_id = config.get("run_id") include_run_info = kwargs.get("include_run_info", False) return_only_outputs = kwargs.get("return_only_outputs", False) @@ -145,6 +147,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): run_manager = callback_manager.on_chain_start( dumpd(self), inputs, + run_id, name=run_name, ) try: @@ -178,6 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): tags = config.get("tags") metadata = config.get("metadata") run_name = config.get("run_name") or self.get_name() + run_id = config.get("run_id") include_run_info = kwargs.get("include_run_info", False) return_only_outputs = kwargs.get("return_only_outputs", False) @@ -195,6 +199,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): run_manager = await callback_manager.on_chain_start( dumpd(self), inputs, + run_id, name=run_name, ) try: diff --git a/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py b/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py index 4dd1317a47b..dd6881e8acd 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py @@ -3,6 +3,7 @@ from uuid import UUID import pytest from langchain_core.language_models import FakeListLLM from langchain_core.tools import Tool +from langchain_core.tracers.context import collect_runs from langchain.agents import ( AgentExecutor, @@ -251,6 +252,28 @@ def test_agent_iterator_properties_and_setters() -> None: assert isinstance(agent_iter.agent_executor, AgentExecutor) +def test_agent_iterator_manual_run_id() -> None: + """Test react chain iterator with manually specified run_id.""" + agent = _get_agent() + run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479") + with collect_runs() as cb: + agent_iter = agent.stream("when was langchain made", {"run_id": run_id}) + list(agent_iter) + run = cb.traced_runs[0] + assert run.id == run_id + + +async def test_manually_specify_rid_async() -> None: + agent = _get_agent() + run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479") + with collect_runs() as cb: + res = agent.astream("bar", {"run_id": run_id}) + async for _ in res: + pass + run = cb.traced_runs[0] + assert run.id == run_id + + def test_agent_iterator_reset() -> None: """Test reset functionality of AgentExecutorIterator.""" agent = _get_agent() diff --git a/libs/langchain/tests/unit_tests/chains/test_base.py b/libs/langchain/tests/unit_tests/chains/test_base.py index 2070180b63b..26dabe3a997 100644 --- a/libs/langchain/tests/unit_tests/chains/test_base.py +++ b/libs/langchain/tests/unit_tests/chains/test_base.py @@ -1,9 +1,12 @@ """Test logic on base chain class.""" + +import uuid from typing import Any, Dict, List, Optional import pytest from langchain_core.callbacks.manager import CallbackManagerForChainRun from langchain_core.memory import BaseMemory +from langchain_core.tracers.context import collect_runs from langchain.chains.base import Chain from langchain.schema import RUN_KEY @@ -180,6 +183,37 @@ def test_run_with_callback_and_input_error() -> None: assert handler.errors == 1 +def test_manually_specify_rid() -> None: + chain = FakeChain() + run_id = uuid.uuid4() + with collect_runs() as cb: + chain.invoke({"foo": "bar"}, {"run_id": run_id}) + run = cb.traced_runs[0] + assert run.id == run_id + + run_id2 = uuid.uuid4() + with collect_runs() as cb: + list(chain.stream({"foo": "bar"}, {"run_id": run_id2})) + run = cb.traced_runs[0] + assert run.id == run_id2 + + +async def test_manually_specify_rid_async() -> None: + chain = FakeChain() + run_id = uuid.uuid4() + with collect_runs() as cb: + await chain.ainvoke({"foo": "bar"}, {"run_id": run_id}) + run = cb.traced_runs[0] + assert run.id == run_id + run_id2 = uuid.uuid4() + with collect_runs() as cb: + res = chain.astream({"foo": "bar"}, {"run_id": run_id2}) + async for _ in res: + pass + run = cb.traced_runs[0] + assert run.id == run_id2 + + def test_run_with_callback_and_output_error() -> None: """Test callback manager catches run validation output error.""" handler = FakeCallbackHandler()