mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
Track RunnableAssign as a separate run trace (#13972)
This change addresses runs being sent to callbacks / tracers in the incorrect order, which was due to the nature of threading.
---------
Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
0f255bb6c4
commit
eb67f07e32
@ -5,6 +5,7 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
@ -31,11 +32,18 @@ from langchain_core.runnables.config import (
|
|||||||
acall_func_with_variable_args,
|
acall_func_with_variable_args,
|
||||||
call_func_with_variable_args,
|
call_func_with_variable_args,
|
||||||
get_executor_for_config,
|
get_executor_for_config,
|
||||||
|
patch_config,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
|
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
|
||||||
from langchain_core.utils.aiter import atee, py_anext
|
from langchain_core.utils.aiter import atee, py_anext
|
||||||
from langchain_core.utils.iter import safetee
|
from langchain_core.utils.iter import safetee
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def identity(x: Other) -> Other:
|
def identity(x: Other) -> Other:
|
||||||
"""An identity function"""
|
"""An identity function"""
|
||||||
@ -345,18 +353,52 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||||
return self.mapper.config_specs
|
return self.mapper.config_specs
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
input: Dict[str, Any],
|
||||||
|
run_manager: CallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
assert isinstance(
|
||||||
|
input, dict
|
||||||
|
), "The input to RunnablePassthrough.assign() must be a dict."
|
||||||
|
|
||||||
|
return {
|
||||||
|
**input,
|
||||||
|
**self.mapper.invoke(
|
||||||
|
input,
|
||||||
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Dict[str, Any],
|
input: Dict[str, Any],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return self._call_with_config(self._invoke, input, config, **kwargs)
|
||||||
|
|
||||||
|
async def _ainvoke(
|
||||||
|
self,
|
||||||
|
input: Dict[str, Any],
|
||||||
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
input, dict
|
input, dict
|
||||||
), "The input to RunnablePassthrough.assign() must be a dict."
|
), "The input to RunnablePassthrough.assign() must be a dict."
|
||||||
|
|
||||||
return {
|
return {
|
||||||
**input,
|
**input,
|
||||||
**self.mapper.invoke(input, config, **kwargs),
|
**await self.mapper.ainvoke(
|
||||||
|
input,
|
||||||
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
@ -365,26 +407,30 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
assert isinstance(
|
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
||||||
input, dict
|
|
||||||
), "The input to RunnablePassthrough.assign() must be a dict."
|
|
||||||
return {
|
|
||||||
**input,
|
|
||||||
**await self.mapper.ainvoke(input, config, **kwargs),
|
|
||||||
}
|
|
||||||
|
|
||||||
def transform(
|
def _transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Dict[str, Any]],
|
input: Iterator[Dict[str, Any]],
|
||||||
config: Optional[RunnableConfig] = None,
|
run_manager: CallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[Dict[str, Any]]:
|
) -> Iterator[Dict[str, Any]]:
|
||||||
# collect mapper keys
|
# collect mapper keys
|
||||||
mapper_keys = set(self.mapper.steps.keys())
|
mapper_keys = set(self.mapper.steps.keys())
|
||||||
# create two streams, one for the map and one for the passthrough
|
# create two streams, one for the map and one for the passthrough
|
||||||
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
|
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
|
||||||
|
|
||||||
# create map output stream
|
# create map output stream
|
||||||
map_output = self.mapper.transform(for_map, config, **kwargs)
|
map_output = self.mapper.transform(
|
||||||
|
for_map,
|
||||||
|
patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# get executor to start map output stream in background
|
# get executor to start map output stream in background
|
||||||
with get_executor_for_config(config or {}) as executor:
|
with get_executor_for_config(config or {}) as executor:
|
||||||
# start map output stream
|
# start map output stream
|
||||||
@ -409,10 +455,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
for chunk in map_output:
|
for chunk in map_output:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def atransform(
|
def transform(
|
||||||
|
self,
|
||||||
|
input: Iterator[Dict[str, Any]],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any | None,
|
||||||
|
) -> Iterator[Dict[str, Any]]:
|
||||||
|
yield from self._transform_stream_with_config(
|
||||||
|
input, self._transform, config, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Dict[str, Any]],
|
input: AsyncIterator[Dict[str, Any]],
|
||||||
config: Optional[RunnableConfig] = None,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Dict[str, Any]]:
|
) -> AsyncIterator[Dict[str, Any]]:
|
||||||
# collect mapper keys
|
# collect mapper keys
|
||||||
@ -420,7 +477,14 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
# create two streams, one for the map and one for the passthrough
|
# create two streams, one for the map and one for the passthrough
|
||||||
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
|
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
|
||||||
# create map output stream
|
# create map output stream
|
||||||
map_output = self.mapper.atransform(for_map, config, **kwargs)
|
map_output = self.mapper.atransform(
|
||||||
|
for_map,
|
||||||
|
patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
# start map output stream
|
# start map output stream
|
||||||
first_map_chunk_task: asyncio.Task = asyncio.create_task(
|
first_map_chunk_task: asyncio.Task = asyncio.create_task(
|
||||||
py_anext(map_output, None), # type: ignore[arg-type]
|
py_anext(map_output, None), # type: ignore[arg-type]
|
||||||
@ -441,6 +505,17 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
async for chunk in map_output:
|
async for chunk in map_output:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
async def atransform(
|
||||||
|
self,
|
||||||
|
input: AsyncIterator[Dict[str, Any]],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[Dict[str, Any]]:
|
||||||
|
async for chunk in self._atransform_stream_with_config(
|
||||||
|
input, self._atransform, config, **kwargs
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
input: Dict[str, Any],
|
input: Dict[str, Any],
|
||||||
|
@ -4146,3 +4146,44 @@ async def test_ainvoke_on_returned_runnable() -> None:
|
|||||||
return idchain
|
return idchain
|
||||||
|
|
||||||
assert await RunnableLambda(func).ainvoke({})
|
assert await RunnableLambda(func).ainvoke({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||||
|
def idchain_sync(__input: dict) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
|
||||||
|
|
||||||
|
assert tracer.runs[0].name == "RunnableAssign"
|
||||||
|
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert tracer.runs[0].name == "RunnableAssign"
|
||||||
|
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
||||||
|
def idchain_sync(__input: dict) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
|
||||||
|
|
||||||
|
assert tracer.runs[0].name == "RunnableAssign"
|
||||||
|
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert tracer.runs[0].name == "RunnableAssign"
|
||||||
|
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||||
|
Loading…
Reference in New Issue
Block a user