Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-26 08:33:49 +00:00)

Pass through Run ID Explicitly (#21469)

This commit is contained in:
parent 83eecd54fe
commit b28be5d407
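The commit lets callers supply their own run ID through the standard config dict when invoking a `Chain` or an `AgentExecutor`: `config["run_id"]` is now forwarded to `callback_manager.on_chain_start()`, so the traced run keeps the caller-supplied UUID instead of a freshly generated one. A minimal sketch of the caller-facing behaviour; `EchoChain` is an illustrative toy subclass (not part of the repo), modelled on the `FakeChain` used in the tests below:

import uuid
from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.tracers.context import collect_runs

from langchain.chains.base import Chain


class EchoChain(Chain):
    """Toy chain used only to demonstrate run_id passthrough (illustrative)."""

    @property
    def input_keys(self) -> List[str]:
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        return ["bar"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        # Echo the input back; the chain logic is irrelevant to the run_id plumbing.
        return {"bar": inputs["foo"]}


run_id = uuid.uuid4()
with collect_runs() as cb:
    # "run_id" in the config dict is forwarded to on_chain_start(),
    # so the collected trace carries the caller-supplied UUID.
    EchoChain().invoke({"foo": "hello"}, {"run_id": run_id})
    assert cb.traced_runs[0].id == run_id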
Changes in the agent module (RunnableAgent, RunnableMultiActionAgent, AgentExecutor):

@@ -1,4 +1,5 @@
"""Chain that takes in an input and produces an action and action input."""

from __future__ import annotations

import asyncio

@@ -346,11 +347,11 @@ class RunnableAgent(BaseSingleActionAgent):
    input_keys_arg: List[str] = []
    return_keys_arg: List[str] = []
    stream_runnable: bool = True
    """Whether to stream from the runnable or not.

    If True then underlying LLM is invoked in a streaming fashion to make it possible
    to get access to the individual LLM tokens when using stream_log with the Agent
    Executor. If False then LLM is invoked in a non-streaming fashion and
    individual LLM tokens will not be available in stream_log.
    """

@@ -455,11 +456,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
    input_keys_arg: List[str] = []
    return_keys_arg: List[str] = []
    stream_runnable: bool = True
    """Whether to stream from the runnable or not.

    If True then underlying LLM is invoked in a streaming fashion to make it possible
    to get access to the individual LLM tokens when using stream_log with the Agent
    Executor. If False then LLM is invoked in a non-streaming fashion and
    individual LLM tokens will not be available in stream_log.
    """

@@ -926,7 +927,7 @@ class AgentExecutor(Chain):
    max_iterations: Optional[int] = 15
    """The maximum number of steps to take before ending the execution
    loop.

    Setting to 'None' could lead to an infinite loop."""
    max_execution_time: Optional[float] = None
    """The maximum amount of wall clock time to spend in the execution

@@ -938,7 +939,7 @@ class AgentExecutor(Chain):
    `"force"` returns a string saying that it stopped because it met a
    time or iteration limit.

    `"generate"` calls the agent's LLM Chain one final time to generate
    a final answer based on the previous steps.
    """

@@ -1565,6 +1566,7 @@ class AgentExecutor(Chain):
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            run_id=config.get("run_id"),
            yield_actions=True,
            **kwargs,
        )

@@ -1586,6 +1588,7 @@ class AgentExecutor(Chain):
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            run_id=config.get("run_id"),
            yield_actions=True,
            **kwargs,
        )
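In `AgentExecutor.stream()` and `astream()`, the `run_id` taken from the config is now handed to the `AgentExecutorIterator` shown in the next section, so agent callers get the same guarantee as plain chains. A brief hedged sketch of the caller-side effect, assuming `agent_executor` is an `AgentExecutor` already assembled elsewhere (the tests further down build one from a fake LLM and tools via a `_get_agent()` helper):

from uuid import UUID

from langchain_core.tracers.context import collect_runs

run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
with collect_runs() as cb:
    # Drain the stream; the iterator starts the chain run with the supplied UUID.
    for _ in agent_executor.stream("when was langchain made", {"run_id": run_id}):
        pass
    assert cb.traced_runs[0].id == run_id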
Changes in the AgentExecutorIterator module:

@@ -14,6 +14,7 @@ from typing import (
    Tuple,
    Union,
)
from uuid import UUID

from langchain_core.agents import (
    AgentAction,

@@ -54,6 +55,7 @@ class AgentExecutorIterator:
        tags: Optional[list[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        run_id: Optional[UUID] = None,
        include_run_info: bool = False,
        yield_actions: bool = False,
    ):

@@ -67,6 +69,7 @@ class AgentExecutorIterator:
        self.tags = tags
        self.metadata = metadata
        self.run_name = run_name
        self.run_id = run_id
        self.include_run_info = include_run_info
        self.yield_actions = yield_actions
        self.reset()

@@ -76,6 +79,7 @@ class AgentExecutorIterator:
    tags: Optional[list[str]]
    metadata: Optional[Dict[str, Any]]
    run_name: Optional[str]
    run_id: Optional[UUID]
    include_run_info: bool
    yield_actions: bool

@@ -162,6 +166,7 @@ class AgentExecutorIterator:
        run_manager = callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
            self.run_id,
            name=self.run_name,
        )
        try:

@@ -227,6 +232,7 @@ class AgentExecutorIterator:
        run_manager = await callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
            self.run_id,
            name=self.run_name,
        )
        try:
Changes in the Chain base class module:

@@ -1,4 +1,5 @@
"""Base interface that all chains should implement."""

import inspect
import json
import logging

@@ -127,6 +128,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
        tags = config.get("tags")
        metadata = config.get("metadata")
        run_name = config.get("run_name") or self.get_name()
        run_id = config.get("run_id")
        include_run_info = kwargs.get("include_run_info", False)
        return_only_outputs = kwargs.get("return_only_outputs", False)

@@ -145,6 +147,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            inputs,
            run_id,
            name=run_name,
        )
        try:

@@ -178,6 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
        tags = config.get("tags")
        metadata = config.get("metadata")
        run_name = config.get("run_name") or self.get_name()
        run_id = config.get("run_id")
        include_run_info = kwargs.get("include_run_info", False)
        return_only_outputs = kwargs.get("return_only_outputs", False)

@@ -195,6 +199,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            inputs,
            run_id,
            name=run_name,
        )
        try:
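The async entry points get the same treatment as the sync ones in the hunks above, so an explicitly supplied run ID is honoured regardless of how the chain is driven. A small async sketch, reusing the illustrative `EchoChain` defined earlier (an assumed toy class, not part of the repo):

import asyncio
import uuid

from langchain_core.tracers.context import collect_runs


async def main() -> None:
    run_id = uuid.uuid4()
    with collect_runs() as cb:
        # ainvoke reads "run_id" from the same config dict as the sync path.
        await EchoChain().ainvoke({"foo": "hello"}, {"run_id": run_id})
        assert cb.traced_runs[0].id == run_id


asyncio.run(main())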
Changes in the agent iterator unit tests:

@@ -3,6 +3,7 @@ from uuid import UUID
import pytest
from langchain_core.language_models import FakeListLLM
from langchain_core.tools import Tool
from langchain_core.tracers.context import collect_runs

from langchain.agents import (
    AgentExecutor,

@@ -251,6 +252,28 @@ def test_agent_iterator_properties_and_setters() -> None:
    assert isinstance(agent_iter.agent_executor, AgentExecutor)


def test_agent_iterator_manual_run_id() -> None:
    """Test react chain iterator with manually specified run_id."""
    agent = _get_agent()
    run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
    with collect_runs() as cb:
        agent_iter = agent.stream("when was langchain made", {"run_id": run_id})
        list(agent_iter)
        run = cb.traced_runs[0]
        assert run.id == run_id


async def test_manually_specify_rid_async() -> None:
    agent = _get_agent()
    run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
    with collect_runs() as cb:
        res = agent.astream("bar", {"run_id": run_id})
        async for _ in res:
            pass
        run = cb.traced_runs[0]
        assert run.id == run_id


def test_agent_iterator_reset() -> None:
    """Test reset functionality of AgentExecutorIterator."""
    agent = _get_agent()
Changes in the Chain base class unit tests:

@@ -1,9 +1,12 @@
"""Test logic on base chain class."""

import uuid
from typing import Any, Dict, List, Optional

import pytest
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.memory import BaseMemory
from langchain_core.tracers.context import collect_runs

from langchain.chains.base import Chain
from langchain.schema import RUN_KEY

@@ -180,6 +183,37 @@ def test_run_with_callback_and_input_error() -> None:
    assert handler.errors == 1


def test_manually_specify_rid() -> None:
    chain = FakeChain()
    run_id = uuid.uuid4()
    with collect_runs() as cb:
        chain.invoke({"foo": "bar"}, {"run_id": run_id})
        run = cb.traced_runs[0]
        assert run.id == run_id

    run_id2 = uuid.uuid4()
    with collect_runs() as cb:
        list(chain.stream({"foo": "bar"}, {"run_id": run_id2}))
        run = cb.traced_runs[0]
        assert run.id == run_id2


async def test_manually_specify_rid_async() -> None:
    chain = FakeChain()
    run_id = uuid.uuid4()
    with collect_runs() as cb:
        await chain.ainvoke({"foo": "bar"}, {"run_id": run_id})
        run = cb.traced_runs[0]
        assert run.id == run_id
    run_id2 = uuid.uuid4()
    with collect_runs() as cb:
        res = chain.astream({"foo": "bar"}, {"run_id": run_id2})
        async for _ in res:
            pass
        run = cb.traced_runs[0]
        assert run.id == run_id2


def test_run_with_callback_and_output_error() -> None:
    """Test callback manager catches run validation output error."""
    handler = FakeCallbackHandler()