Pass through Run ID Explicitly (#21469)

This commit is contained in:
William FH 2024-05-08 22:20:51 -07:00 committed by GitHub
parent 83eecd54fe
commit b28be5d407
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 82 additions and 11 deletions

View File

@ -1,4 +1,5 @@
"""Chain that takes in an input and produces an action and action input.""" """Chain that takes in an input and produces an action and action input."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -346,11 +347,11 @@ class RunnableAgent(BaseSingleActionAgent):
input_keys_arg: List[str] = [] input_keys_arg: List[str] = []
return_keys_arg: List[str] = [] return_keys_arg: List[str] = []
stream_runnable: bool = True stream_runnable: bool = True
"""Whether to stream from the runnable or not. """Whether to stream from the runnable or not.
If True then underlying LLM is invoked in a streaming fashion to make it possible If True then underlying LLM is invoked in a streaming fashion to make it possible
to get access to the individual LLM tokens when using stream_log with the Agent to get access to the individual LLM tokens when using stream_log with the Agent
Executor. If False then LLM is invoked in a non-streaming fashion and Executor. If False then LLM is invoked in a non-streaming fashion and
individual LLM tokens will not be available in stream_log. individual LLM tokens will not be available in stream_log.
""" """
@ -455,11 +456,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
input_keys_arg: List[str] = [] input_keys_arg: List[str] = []
return_keys_arg: List[str] = [] return_keys_arg: List[str] = []
stream_runnable: bool = True stream_runnable: bool = True
"""Whether to stream from the runnable or not. """Whether to stream from the runnable or not.
If True then underlying LLM is invoked in a streaming fashion to make it possible If True then underlying LLM is invoked in a streaming fashion to make it possible
to get access to the individual LLM tokens when using stream_log with the Agent to get access to the individual LLM tokens when using stream_log with the Agent
Executor. If False then LLM is invoked in a non-streaming fashion and Executor. If False then LLM is invoked in a non-streaming fashion and
individual LLM tokens will not be available in stream_log. individual LLM tokens will not be available in stream_log.
""" """
@ -926,7 +927,7 @@ class AgentExecutor(Chain):
max_iterations: Optional[int] = 15 max_iterations: Optional[int] = 15
"""The maximum number of steps to take before ending the execution """The maximum number of steps to take before ending the execution
loop. loop.
Setting to 'None' could lead to an infinite loop.""" Setting to 'None' could lead to an infinite loop."""
max_execution_time: Optional[float] = None max_execution_time: Optional[float] = None
"""The maximum amount of wall clock time to spend in the execution """The maximum amount of wall clock time to spend in the execution
@ -938,7 +939,7 @@ class AgentExecutor(Chain):
`"force"` returns a string saying that it stopped because it met a `"force"` returns a string saying that it stopped because it met a
time or iteration limit. time or iteration limit.
`"generate"` calls the agent's LLM Chain one final time to generate `"generate"` calls the agent's LLM Chain one final time to generate
a final answer based on the previous steps. a final answer based on the previous steps.
""" """
@ -1565,6 +1566,7 @@ class AgentExecutor(Chain):
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"), run_name=config.get("run_name"),
run_id=config.get("run_id"),
yield_actions=True, yield_actions=True,
**kwargs, **kwargs,
) )
@ -1586,6 +1588,7 @@ class AgentExecutor(Chain):
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"), run_name=config.get("run_name"),
run_id=config.get("run_id"),
yield_actions=True, yield_actions=True,
**kwargs, **kwargs,
) )

View File

@ -14,6 +14,7 @@ from typing import (
Tuple, Tuple,
Union, Union,
) )
from uuid import UUID
from langchain_core.agents import ( from langchain_core.agents import (
AgentAction, AgentAction,
@ -54,6 +55,7 @@ class AgentExecutorIterator:
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[UUID] = None,
include_run_info: bool = False, include_run_info: bool = False,
yield_actions: bool = False, yield_actions: bool = False,
): ):
@ -67,6 +69,7 @@ class AgentExecutorIterator:
self.tags = tags self.tags = tags
self.metadata = metadata self.metadata = metadata
self.run_name = run_name self.run_name = run_name
self.run_id = run_id
self.include_run_info = include_run_info self.include_run_info = include_run_info
self.yield_actions = yield_actions self.yield_actions = yield_actions
self.reset() self.reset()
@ -76,6 +79,7 @@ class AgentExecutorIterator:
tags: Optional[list[str]] tags: Optional[list[str]]
metadata: Optional[Dict[str, Any]] metadata: Optional[Dict[str, Any]]
run_name: Optional[str] run_name: Optional[str]
run_id: Optional[UUID]
include_run_info: bool include_run_info: bool
yield_actions: bool yield_actions: bool
@ -162,6 +166,7 @@ class AgentExecutorIterator:
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self.agent_executor), dumpd(self.agent_executor),
self.inputs, self.inputs,
self.run_id,
name=self.run_name, name=self.run_name,
) )
try: try:
@ -227,6 +232,7 @@ class AgentExecutorIterator:
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self.agent_executor), dumpd(self.agent_executor),
self.inputs, self.inputs,
self.run_id,
name=self.run_name, name=self.run_name,
) )
try: try:

View File

@ -1,4 +1,5 @@
"""Base interface that all chains should implement.""" """Base interface that all chains should implement."""
import inspect import inspect
import json import json
import logging import logging
@ -127,6 +128,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
tags = config.get("tags") tags = config.get("tags")
metadata = config.get("metadata") metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name() run_name = config.get("run_name") or self.get_name()
run_id = config.get("run_id")
include_run_info = kwargs.get("include_run_info", False) include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False) return_only_outputs = kwargs.get("return_only_outputs", False)
@ -145,6 +147,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
inputs, inputs,
run_id,
name=run_name, name=run_name,
) )
try: try:
@ -178,6 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
tags = config.get("tags") tags = config.get("tags")
metadata = config.get("metadata") metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name() run_name = config.get("run_name") or self.get_name()
run_id = config.get("run_id")
include_run_info = kwargs.get("include_run_info", False) include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False) return_only_outputs = kwargs.get("return_only_outputs", False)
@ -195,6 +199,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
inputs, inputs,
run_id,
name=run_name, name=run_name,
) )
try: try:

View File

@ -3,6 +3,7 @@ from uuid import UUID
import pytest import pytest
from langchain_core.language_models import FakeListLLM from langchain_core.language_models import FakeListLLM
from langchain_core.tools import Tool from langchain_core.tools import Tool
from langchain_core.tracers.context import collect_runs
from langchain.agents import ( from langchain.agents import (
AgentExecutor, AgentExecutor,
@ -251,6 +252,28 @@ def test_agent_iterator_properties_and_setters() -> None:
assert isinstance(agent_iter.agent_executor, AgentExecutor) assert isinstance(agent_iter.agent_executor, AgentExecutor)
def test_agent_iterator_manual_run_id() -> None:
"""Test react chain iterator with manually specified run_id."""
agent = _get_agent()
run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
with collect_runs() as cb:
agent_iter = agent.stream("when was langchain made", {"run_id": run_id})
list(agent_iter)
run = cb.traced_runs[0]
assert run.id == run_id
async def test_manually_specify_rid_async() -> None:
agent = _get_agent()
run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
with collect_runs() as cb:
res = agent.astream("bar", {"run_id": run_id})
async for _ in res:
pass
run = cb.traced_runs[0]
assert run.id == run_id
def test_agent_iterator_reset() -> None: def test_agent_iterator_reset() -> None:
"""Test reset functionality of AgentExecutorIterator.""" """Test reset functionality of AgentExecutorIterator."""
agent = _get_agent() agent = _get_agent()

View File

@ -1,9 +1,12 @@
"""Test logic on base chain class.""" """Test logic on base chain class."""
import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import pytest import pytest
from langchain_core.callbacks.manager import CallbackManagerForChainRun from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.memory import BaseMemory from langchain_core.memory import BaseMemory
from langchain_core.tracers.context import collect_runs
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.schema import RUN_KEY from langchain.schema import RUN_KEY
@ -180,6 +183,37 @@ def test_run_with_callback_and_input_error() -> None:
assert handler.errors == 1 assert handler.errors == 1
def test_manually_specify_rid() -> None:
chain = FakeChain()
run_id = uuid.uuid4()
with collect_runs() as cb:
chain.invoke({"foo": "bar"}, {"run_id": run_id})
run = cb.traced_runs[0]
assert run.id == run_id
run_id2 = uuid.uuid4()
with collect_runs() as cb:
list(chain.stream({"foo": "bar"}, {"run_id": run_id2}))
run = cb.traced_runs[0]
assert run.id == run_id2
async def test_manually_specify_rid_async() -> None:
chain = FakeChain()
run_id = uuid.uuid4()
with collect_runs() as cb:
await chain.ainvoke({"foo": "bar"}, {"run_id": run_id})
run = cb.traced_runs[0]
assert run.id == run_id
run_id2 = uuid.uuid4()
with collect_runs() as cb:
res = chain.astream({"foo": "bar"}, {"run_id": run_id2})
async for _ in res:
pass
run = cb.traced_runs[0]
assert run.id == run_id2
def test_run_with_callback_and_output_error() -> None: def test_run_with_callback_and_output_error() -> None:
"""Test callback manager catches run validation output error.""" """Test callback manager catches run validation output error."""
handler = FakeCallbackHandler() handler = FakeCallbackHandler()