Pass through Run ID Explicitly (#21469)

This commit is contained in:
William FH 2024-05-08 22:20:51 -07:00 committed by GitHub
parent 83eecd54fe
commit b28be5d407
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 82 additions and 11 deletions

View File

@ -1,4 +1,5 @@
"""Chain that takes in an input and produces an action and action input.""" """Chain that takes in an input and produces an action and action input."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -1565,6 +1566,7 @@ class AgentExecutor(Chain):
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"), run_name=config.get("run_name"),
run_id=config.get("run_id"),
yield_actions=True, yield_actions=True,
**kwargs, **kwargs,
) )
@ -1586,6 +1588,7 @@ class AgentExecutor(Chain):
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"), run_name=config.get("run_name"),
run_id=config.get("run_id"),
yield_actions=True, yield_actions=True,
**kwargs, **kwargs,
) )

View File

@ -14,6 +14,7 @@ from typing import (
Tuple, Tuple,
Union, Union,
) )
from uuid import UUID
from langchain_core.agents import ( from langchain_core.agents import (
AgentAction, AgentAction,
@ -54,6 +55,7 @@ class AgentExecutorIterator:
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[UUID] = None,
include_run_info: bool = False, include_run_info: bool = False,
yield_actions: bool = False, yield_actions: bool = False,
): ):
@ -67,6 +69,7 @@ class AgentExecutorIterator:
self.tags = tags self.tags = tags
self.metadata = metadata self.metadata = metadata
self.run_name = run_name self.run_name = run_name
self.run_id = run_id
self.include_run_info = include_run_info self.include_run_info = include_run_info
self.yield_actions = yield_actions self.yield_actions = yield_actions
self.reset() self.reset()
@ -76,6 +79,7 @@ class AgentExecutorIterator:
tags: Optional[list[str]] tags: Optional[list[str]]
metadata: Optional[Dict[str, Any]] metadata: Optional[Dict[str, Any]]
run_name: Optional[str] run_name: Optional[str]
run_id: Optional[UUID]
include_run_info: bool include_run_info: bool
yield_actions: bool yield_actions: bool
@ -162,6 +166,7 @@ class AgentExecutorIterator:
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self.agent_executor), dumpd(self.agent_executor),
self.inputs, self.inputs,
self.run_id,
name=self.run_name, name=self.run_name,
) )
try: try:
@ -227,6 +232,7 @@ class AgentExecutorIterator:
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self.agent_executor), dumpd(self.agent_executor),
self.inputs, self.inputs,
self.run_id,
name=self.run_name, name=self.run_name,
) )
try: try:

View File

@ -1,4 +1,5 @@
"""Base interface that all chains should implement.""" """Base interface that all chains should implement."""
import inspect import inspect
import json import json
import logging import logging
@ -127,6 +128,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
tags = config.get("tags") tags = config.get("tags")
metadata = config.get("metadata") metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name() run_name = config.get("run_name") or self.get_name()
run_id = config.get("run_id")
include_run_info = kwargs.get("include_run_info", False) include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False) return_only_outputs = kwargs.get("return_only_outputs", False)
@ -145,6 +147,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
inputs, inputs,
run_id,
name=run_name, name=run_name,
) )
try: try:
@ -178,6 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
tags = config.get("tags") tags = config.get("tags")
metadata = config.get("metadata") metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name() run_name = config.get("run_name") or self.get_name()
run_id = config.get("run_id")
include_run_info = kwargs.get("include_run_info", False) include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False) return_only_outputs = kwargs.get("return_only_outputs", False)
@ -195,6 +199,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
inputs, inputs,
run_id,
name=run_name, name=run_name,
) )
try: try:

View File

@ -3,6 +3,7 @@ from uuid import UUID
import pytest import pytest
from langchain_core.language_models import FakeListLLM from langchain_core.language_models import FakeListLLM
from langchain_core.tools import Tool from langchain_core.tools import Tool
from langchain_core.tracers.context import collect_runs
from langchain.agents import ( from langchain.agents import (
AgentExecutor, AgentExecutor,
@ -251,6 +252,28 @@ def test_agent_iterator_properties_and_setters() -> None:
assert isinstance(agent_iter.agent_executor, AgentExecutor) assert isinstance(agent_iter.agent_executor, AgentExecutor)
def test_agent_iterator_manual_run_id() -> None:
"""Test react chain iterator with manually specified run_id."""
agent = _get_agent()
run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
with collect_runs() as cb:
agent_iter = agent.stream("when was langchain made", {"run_id": run_id})
list(agent_iter)
run = cb.traced_runs[0]
assert run.id == run_id
async def test_manually_specify_rid_async() -> None:
agent = _get_agent()
run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
with collect_runs() as cb:
res = agent.astream("bar", {"run_id": run_id})
async for _ in res:
pass
run = cb.traced_runs[0]
assert run.id == run_id
def test_agent_iterator_reset() -> None: def test_agent_iterator_reset() -> None:
"""Test reset functionality of AgentExecutorIterator.""" """Test reset functionality of AgentExecutorIterator."""
agent = _get_agent() agent = _get_agent()

View File

@ -1,9 +1,12 @@
"""Test logic on base chain class.""" """Test logic on base chain class."""
import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import pytest import pytest
from langchain_core.callbacks.manager import CallbackManagerForChainRun from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.memory import BaseMemory from langchain_core.memory import BaseMemory
from langchain_core.tracers.context import collect_runs
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.schema import RUN_KEY from langchain.schema import RUN_KEY
@ -180,6 +183,37 @@ def test_run_with_callback_and_input_error() -> None:
assert handler.errors == 1 assert handler.errors == 1
def test_manually_specify_rid() -> None:
chain = FakeChain()
run_id = uuid.uuid4()
with collect_runs() as cb:
chain.invoke({"foo": "bar"}, {"run_id": run_id})
run = cb.traced_runs[0]
assert run.id == run_id
run_id2 = uuid.uuid4()
with collect_runs() as cb:
list(chain.stream({"foo": "bar"}, {"run_id": run_id2}))
run = cb.traced_runs[0]
assert run.id == run_id2
async def test_manually_specify_rid_async() -> None:
chain = FakeChain()
run_id = uuid.uuid4()
with collect_runs() as cb:
await chain.ainvoke({"foo": "bar"}, {"run_id": run_id})
run = cb.traced_runs[0]
assert run.id == run_id
run_id2 = uuid.uuid4()
with collect_runs() as cb:
res = chain.astream({"foo": "bar"}, {"run_id": run_id2})
async for _ in res:
pass
run = cb.traced_runs[0]
assert run.id == run_id2
def test_run_with_callback_and_output_error() -> None: def test_run_with_callback_and_output_error() -> None:
"""Test callback manager catches run validation output error.""" """Test callback manager catches run validation output error."""
handler = FakeCallbackHandler() handler = FakeCallbackHandler()