Pass through Run ID Explicitly (#21469)
Commit: b28be5d407
Parent: 83eecd54fe
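For orientation, a hedged usage sketch (not part of the diff): with this change a caller can pin the id of the top-level traced run by passing "run_id" in the config dict, the same pattern the new tests below exercise. LLMChain and FakeListLLM are used here only to obtain a concrete Chain instance to invoke.

import uuid

from langchain.chains import LLMChain
from langchain_core.language_models import FakeListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.tracers.context import collect_runs

chain = LLMChain(
    llm=FakeListLLM(responses=["hello"]),
    prompt=PromptTemplate.from_template("{question}"),
)

run_id = uuid.uuid4()
with collect_runs() as cb:
    # The second argument to invoke() is the RunnableConfig; after this commit
    # Chain.invoke forwards config["run_id"] to the callback manager's on_chain_start.
    chain.invoke({"question": "hi"}, {"run_id": run_id})
    assert cb.traced_runs[0].id == run_id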
@@ -1,4 +1,5 @@
 """Chain that takes in an input and produces an action and action input."""
+
 from __future__ import annotations
 
 import asyncio
@@ -346,11 +347,11 @@ class RunnableAgent(BaseSingleActionAgent):
     input_keys_arg: List[str] = []
     return_keys_arg: List[str] = []
     stream_runnable: bool = True
     """Whether to stream from the runnable or not.
 
     If True then underlying LLM is invoked in a streaming fashion to make it possible
     to get access to the individual LLM tokens when using stream_log with the Agent
     Executor. If False then LLM is invoked in a non-streaming fashion and
     individual LLM tokens will not be available in stream_log.
     """
 
@@ -455,11 +456,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
     input_keys_arg: List[str] = []
     return_keys_arg: List[str] = []
     stream_runnable: bool = True
     """Whether to stream from the runnable or not.
 
     If True then underlying LLM is invoked in a streaming fashion to make it possible
     to get access to the individual LLM tokens when using stream_log with the Agent
     Executor. If False then LLM is invoked in a non-streaming fashion and
     individual LLM tokens will not be available in stream_log.
     """
 
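The stream_runnable flag documented in the two hunks above can be set when constructing a RunnableAgent directly. A minimal sketch, assuming a RunnableLambda that immediately returns an AgentFinish as a stand-in for a real prompt | llm | parser runnable:

from langchain_core.agents import AgentFinish
from langchain_core.runnables import RunnableLambda

from langchain.agents.agent import RunnableAgent

agent = RunnableAgent(
    runnable=RunnableLambda(
        lambda inputs: AgentFinish(return_values={"output": "done"}, log="")
    ),
    # With stream_runnable=False the runnable is invoked rather than streamed,
    # so per-token output will not appear in stream_log / astream_log.
    stream_runnable=False,
)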
@@ -926,7 +927,7 @@ class AgentExecutor(Chain):
     max_iterations: Optional[int] = 15
     """The maximum number of steps to take before ending the execution
     loop.
 
     Setting to 'None' could lead to an infinite loop."""
     max_execution_time: Optional[float] = None
     """The maximum amount of wall clock time to spend in the execution

@@ -938,7 +939,7 @@ class AgentExecutor(Chain):
 
     `"force"` returns a string saying that it stopped because it met a
     time or iteration limit.
 
     `"generate"` calls the agent's LLM Chain one final time to generate
     a final answer based on the previous steps.
     """
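A hedged sketch of the limits described above, using the same trivially-finishing RunnableLambda agent as a stand-in (the limits never actually trigger here; the point is only where the knobs live):

from langchain_core.agents import AgentFinish
from langchain_core.runnables import RunnableLambda

from langchain.agents import AgentExecutor

executor = AgentExecutor(
    agent=RunnableLambda(
        lambda inputs: AgentFinish(return_values={"output": "done"}, log="")
    ),
    tools=[],
    max_iterations=3,               # stop after at most 3 agent steps
    max_execution_time=10.0,        # or after roughly 10 seconds of wall clock time
    early_stopping_method="force",  # at the limit, return a canned "stopped" answer
)
print(executor.invoke({"input": "hello"}))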
@@ -1565,6 +1566,7 @@ class AgentExecutor(Chain):
             tags=config.get("tags"),
             metadata=config.get("metadata"),
             run_name=config.get("run_name"),
+            run_id=config.get("run_id"),
             yield_actions=True,
             **kwargs,
         )

@@ -1586,6 +1588,7 @@ class AgentExecutor(Chain):
             tags=config.get("tags"),
             metadata=config.get("metadata"),
             run_name=config.get("run_name"),
+            run_id=config.get("run_id"),
             yield_actions=True,
             **kwargs,
         )
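The two hunks above thread config["run_id"] into the AgentExecutorIterator used by stream() and astream(). A hedged end-to-end sketch (again with a stand-in agent) mirroring the new test_agent_iterator_manual_run_id added further down:

import uuid

from langchain_core.agents import AgentFinish
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.context import collect_runs

from langchain.agents import AgentExecutor

executor = AgentExecutor(
    agent=RunnableLambda(
        lambda inputs: AgentFinish(return_values={"output": "done"}, log="")
    ),
    tools=[],
)

run_id = uuid.uuid4()
with collect_runs() as cb:
    # stream() builds an AgentExecutorIterator; with this commit the iterator
    # receives config["run_id"] and uses it for the top-level chain run.
    list(executor.stream({"input": "hi"}, {"run_id": run_id}))
    assert cb.traced_runs[0].id == run_id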
@@ -14,6 +14,7 @@ from typing import (
     Tuple,
     Union,
 )
+from uuid import UUID
 
 from langchain_core.agents import (
     AgentAction,

@@ -54,6 +55,7 @@ class AgentExecutorIterator:
         tags: Optional[list[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         run_name: Optional[str] = None,
+        run_id: Optional[UUID] = None,
         include_run_info: bool = False,
         yield_actions: bool = False,
     ):

@@ -67,6 +69,7 @@ class AgentExecutorIterator:
         self.tags = tags
         self.metadata = metadata
         self.run_name = run_name
+        self.run_id = run_id
         self.include_run_info = include_run_info
         self.yield_actions = yield_actions
         self.reset()

@@ -76,6 +79,7 @@ class AgentExecutorIterator:
     tags: Optional[list[str]]
     metadata: Optional[Dict[str, Any]]
     run_name: Optional[str]
+    run_id: Optional[UUID]
     include_run_info: bool
     yield_actions: bool
 

@@ -162,6 +166,7 @@ class AgentExecutorIterator:
         run_manager = callback_manager.on_chain_start(
             dumpd(self.agent_executor),
             self.inputs,
+            self.run_id,
             name=self.run_name,
         )
         try:

@@ -227,6 +232,7 @@ class AgentExecutorIterator:
         run_manager = await callback_manager.on_chain_start(
             dumpd(self.agent_executor),
             self.inputs,
+            self.run_id,
             name=self.run_name,
         )
         try:
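Because the iterator now hands self.run_id to on_chain_start (as the third positional argument in the hunks above), callback handlers observe the caller-chosen id for the top-level chain run. A hedged sketch of a handler that records it; the wiring at the end is left as a comment since it assumes an existing AgentExecutor:

from typing import Any, Dict
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler


class RunIdRecorder(BaseCallbackHandler):
    """Collects the run_id of every chain start this handler observes."""

    def __init__(self) -> None:
        self.seen: list[UUID] = []

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        self.seen.append(run_id)


# Hypothetical wiring, assuming `executor` is an AgentExecutor and my_run_id a UUID:
# handler = RunIdRecorder()
# list(executor.stream({"input": "hi"}, {"run_id": my_run_id, "callbacks": [handler]}))
# assert handler.seen[0] == my_run_id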
@@ -1,4 +1,5 @@
 """Base interface that all chains should implement."""
+
 import inspect
 import json
 import logging
@@ -127,6 +128,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         tags = config.get("tags")
         metadata = config.get("metadata")
         run_name = config.get("run_name") or self.get_name()
+        run_id = config.get("run_id")
         include_run_info = kwargs.get("include_run_info", False)
         return_only_outputs = kwargs.get("return_only_outputs", False)
 

@@ -145,6 +147,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         run_manager = callback_manager.on_chain_start(
             dumpd(self),
             inputs,
+            run_id,
             name=run_name,
         )
         try:

@@ -178,6 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         tags = config.get("tags")
         metadata = config.get("metadata")
         run_name = config.get("run_name") or self.get_name()
+        run_id = config.get("run_id")
         include_run_info = kwargs.get("include_run_info", False)
         return_only_outputs = kwargs.get("return_only_outputs", False)
 

@@ -195,6 +199,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         run_manager = await callback_manager.on_chain_start(
             dumpd(self),
             inputs,
+            run_id,
             name=run_name,
         )
         try:
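The async path mirrors the sync one. A hedged sketch of ainvoke with a pinned run id, modeled on the new test_manually_specify_rid_async below (LLMChain and FakeListLLM again stand in for a real chain):

import asyncio
import uuid

from langchain.chains import LLMChain
from langchain_core.language_models import FakeListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.tracers.context import collect_runs


async def main() -> None:
    chain = LLMChain(
        llm=FakeListLLM(responses=["hello"]),
        prompt=PromptTemplate.from_template("{question}"),
    )
    run_id = uuid.uuid4()
    with collect_runs() as cb:
        # Chain.ainvoke now reads config["run_id"] and passes it to the async
        # callback manager's on_chain_start, just like the sync path.
        await chain.ainvoke({"question": "hi"}, {"run_id": run_id})
        assert cb.traced_runs[0].id == run_id


asyncio.run(main())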
@@ -3,6 +3,7 @@ from uuid import UUID
 import pytest
 from langchain_core.language_models import FakeListLLM
 from langchain_core.tools import Tool
+from langchain_core.tracers.context import collect_runs
 
 from langchain.agents import (
     AgentExecutor,

@@ -251,6 +252,28 @@ def test_agent_iterator_properties_and_setters() -> None:
     assert isinstance(agent_iter.agent_executor, AgentExecutor)
 
 
+def test_agent_iterator_manual_run_id() -> None:
+    """Test react chain iterator with manually specified run_id."""
+    agent = _get_agent()
+    run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
+    with collect_runs() as cb:
+        agent_iter = agent.stream("when was langchain made", {"run_id": run_id})
+        list(agent_iter)
+        run = cb.traced_runs[0]
+        assert run.id == run_id
+
+
+async def test_manually_specify_rid_async() -> None:
+    agent = _get_agent()
+    run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
+    with collect_runs() as cb:
+        res = agent.astream("bar", {"run_id": run_id})
+        async for _ in res:
+            pass
+        run = cb.traced_runs[0]
+        assert run.id == run_id
+
+
 def test_agent_iterator_reset() -> None:
     """Test reset functionality of AgentExecutorIterator."""
     agent = _get_agent()
@@ -1,9 +1,12 @@
 """Test logic on base chain class."""
+
+import uuid
 from typing import Any, Dict, List, Optional
 
 import pytest
 from langchain_core.callbacks.manager import CallbackManagerForChainRun
 from langchain_core.memory import BaseMemory
+from langchain_core.tracers.context import collect_runs
 
 from langchain.chains.base import Chain
 from langchain.schema import RUN_KEY

@@ -180,6 +183,37 @@ def test_run_with_callback_and_input_error() -> None:
     assert handler.errors == 1
 
 
+def test_manually_specify_rid() -> None:
+    chain = FakeChain()
+    run_id = uuid.uuid4()
+    with collect_runs() as cb:
+        chain.invoke({"foo": "bar"}, {"run_id": run_id})
+        run = cb.traced_runs[0]
+        assert run.id == run_id
+
+    run_id2 = uuid.uuid4()
+    with collect_runs() as cb:
+        list(chain.stream({"foo": "bar"}, {"run_id": run_id2}))
+        run = cb.traced_runs[0]
+        assert run.id == run_id2
+
+
+async def test_manually_specify_rid_async() -> None:
+    chain = FakeChain()
+    run_id = uuid.uuid4()
+    with collect_runs() as cb:
+        await chain.ainvoke({"foo": "bar"}, {"run_id": run_id})
+        run = cb.traced_runs[0]
+        assert run.id == run_id
+    run_id2 = uuid.uuid4()
+    with collect_runs() as cb:
+        res = chain.astream({"foo": "bar"}, {"run_id": run_id2})
+        async for _ in res:
+            pass
+        run = cb.traced_runs[0]
+        assert run.id == run_id2
+
+
 def test_run_with_callback_and_output_error() -> None:
     """Test callback manager catches run validation output error."""
     handler = FakeCallbackHandler()