mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
Unset context after step (#30378)
While we are already careful to copy before setting the config, if other objects hold a reference to the config or context, it wouldn't be cleared.
This commit is contained in:
parent
37190881d3
commit
4130e6476b
@ -17,7 +17,6 @@ from collections.abc import (
|
|||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
from concurrent.futures import FIRST_COMPLETED, wait
|
from concurrent.futures import FIRST_COMPLETED, wait
|
||||||
from contextvars import copy_context
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from itertools import groupby, tee
|
from itertools import groupby, tee
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@ -47,7 +46,6 @@ from langchain_core.load.serializable import (
|
|||||||
)
|
)
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
_set_config_context,
|
|
||||||
acall_func_with_variable_args,
|
acall_func_with_variable_args,
|
||||||
call_func_with_variable_args,
|
call_func_with_variable_args,
|
||||||
ensure_config,
|
ensure_config,
|
||||||
@ -58,6 +56,7 @@ from langchain_core.runnables.config import (
|
|||||||
merge_configs,
|
merge_configs,
|
||||||
patch_config,
|
patch_config,
|
||||||
run_in_executor,
|
run_in_executor,
|
||||||
|
set_config_context,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.graph import Graph
|
from langchain_core.runnables.graph import Graph
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
@ -1920,19 +1919,18 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
output = cast(
|
||||||
output = cast(
|
Output,
|
||||||
Output,
|
context.run(
|
||||||
context.run(
|
call_func_with_variable_args, # type: ignore[arg-type]
|
||||||
call_func_with_variable_args, # type: ignore[arg-type]
|
func, # type: ignore[arg-type]
|
||||||
func, # type: ignore[arg-type]
|
input, # type: ignore[arg-type]
|
||||||
input, # type: ignore[arg-type]
|
config,
|
||||||
config,
|
run_manager,
|
||||||
run_manager,
|
**kwargs,
|
||||||
**kwargs,
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
@ -1970,15 +1968,14 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
coro = acall_func_with_variable_args(
|
||||||
coro = acall_func_with_variable_args(
|
func, input, config, run_manager, **kwargs
|
||||||
func, input, config, run_manager, **kwargs
|
)
|
||||||
)
|
if asyncio_accepts_context():
|
||||||
if asyncio_accepts_context():
|
output: Output = await asyncio.create_task(coro, context=context) # type: ignore
|
||||||
output: Output = await asyncio.create_task(coro, context=context) # type: ignore
|
else:
|
||||||
else:
|
output = await coro
|
||||||
output = await coro
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
@ -2182,49 +2179,50 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
kwargs["config"] = child_config
|
kwargs["config"] = child_config
|
||||||
if accepts_run_manager(transformer):
|
if accepts_run_manager(transformer):
|
||||||
kwargs["run_manager"] = run_manager
|
kwargs["run_manager"] = run_manager
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
||||||
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
if stream_handler := next(
|
||||||
if stream_handler := next(
|
(
|
||||||
(
|
cast(_StreamingCallbackHandler, h)
|
||||||
cast(_StreamingCallbackHandler, h)
|
for h in run_manager.handlers
|
||||||
for h in run_manager.handlers
|
# instance check OK here, it's a mixin
|
||||||
# instance check OK here, it's a mixin
|
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
|
||||||
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
|
),
|
||||||
),
|
None,
|
||||||
None,
|
):
|
||||||
):
|
# populates streamed_output in astream_log() output if needed
|
||||||
# populates streamed_output in astream_log() output if needed
|
iterator = stream_handler.tap_output_iter(
|
||||||
iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator)
|
run_manager.run_id, iterator
|
||||||
try:
|
)
|
||||||
while True:
|
try:
|
||||||
chunk: Output = context.run(next, iterator) # type: ignore
|
while True:
|
||||||
yield chunk
|
chunk: Output = context.run(next, iterator) # type: ignore
|
||||||
if final_output_supported:
|
yield chunk
|
||||||
if final_output is None:
|
if final_output_supported:
|
||||||
|
if final_output is None:
|
||||||
|
final_output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
final_output = final_output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
final_output = chunk
|
||||||
|
final_output_supported = False
|
||||||
|
else:
|
||||||
final_output = chunk
|
final_output = chunk
|
||||||
|
except (StopIteration, GeneratorExit):
|
||||||
|
pass
|
||||||
|
for ichunk in input_for_tracing:
|
||||||
|
if final_input_supported:
|
||||||
|
if final_input is None:
|
||||||
|
final_input = ichunk
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
final_output = final_output + chunk # type: ignore
|
final_input = final_input + ichunk # type: ignore
|
||||||
except TypeError:
|
except TypeError:
|
||||||
final_output = chunk
|
final_input = ichunk
|
||||||
final_output_supported = False
|
final_input_supported = False
|
||||||
else:
|
else:
|
||||||
final_output = chunk
|
|
||||||
except (StopIteration, GeneratorExit):
|
|
||||||
pass
|
|
||||||
for ichunk in input_for_tracing:
|
|
||||||
if final_input_supported:
|
|
||||||
if final_input is None:
|
|
||||||
final_input = ichunk
|
final_input = ichunk
|
||||||
else:
|
|
||||||
try:
|
|
||||||
final_input = final_input + ichunk # type: ignore
|
|
||||||
except TypeError:
|
|
||||||
final_input = ichunk
|
|
||||||
final_input_supported = False
|
|
||||||
else:
|
|
||||||
final_input = ichunk
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e, inputs=final_input)
|
run_manager.on_chain_error(e, inputs=final_input)
|
||||||
raise
|
raise
|
||||||
@ -2283,60 +2281,59 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
kwargs["config"] = child_config
|
kwargs["config"] = child_config
|
||||||
if accepts_run_manager(transformer):
|
if accepts_run_manager(transformer):
|
||||||
kwargs["run_manager"] = run_manager
|
kwargs["run_manager"] = run_manager
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
||||||
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
if stream_handler := next(
|
if stream_handler := next(
|
||||||
(
|
(
|
||||||
cast(_StreamingCallbackHandler, h)
|
cast(_StreamingCallbackHandler, h)
|
||||||
for h in run_manager.handlers
|
for h in run_manager.handlers
|
||||||
# instance check OK here, it's a mixin
|
# instance check OK here, it's a mixin
|
||||||
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
|
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
):
|
):
|
||||||
# populates streamed_output in astream_log() output if needed
|
# populates streamed_output in astream_log() output if needed
|
||||||
iterator = stream_handler.tap_output_aiter(
|
iterator = stream_handler.tap_output_aiter(
|
||||||
run_manager.run_id, iterator_
|
run_manager.run_id, iterator_
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
iterator = iterator_
|
iterator = iterator_
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if asyncio_accepts_context():
|
if asyncio_accepts_context():
|
||||||
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
|
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
|
||||||
py_anext(iterator), # type: ignore[arg-type]
|
py_anext(iterator), # type: ignore[arg-type]
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chunk = cast(Output, await py_anext(iterator))
|
chunk = cast(Output, await py_anext(iterator))
|
||||||
yield chunk
|
yield chunk
|
||||||
if final_output_supported:
|
if final_output_supported:
|
||||||
if final_output is None:
|
if final_output is None:
|
||||||
|
final_output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
final_output = final_output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
final_output = chunk
|
||||||
|
final_output_supported = False
|
||||||
|
else:
|
||||||
final_output = chunk
|
final_output = chunk
|
||||||
|
except StopAsyncIteration:
|
||||||
|
pass
|
||||||
|
async for ichunk in input_for_tracing:
|
||||||
|
if final_input_supported:
|
||||||
|
if final_input is None:
|
||||||
|
final_input = ichunk
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
final_output = final_output + chunk # type: ignore
|
final_input = final_input + ichunk # type: ignore[operator]
|
||||||
except TypeError:
|
except TypeError:
|
||||||
final_output = chunk
|
final_input = ichunk
|
||||||
final_output_supported = False
|
final_input_supported = False
|
||||||
else:
|
else:
|
||||||
final_output = chunk
|
|
||||||
except StopAsyncIteration:
|
|
||||||
pass
|
|
||||||
async for ichunk in input_for_tracing:
|
|
||||||
if final_input_supported:
|
|
||||||
if final_input is None:
|
|
||||||
final_input = ichunk
|
final_input = ichunk
|
||||||
else:
|
|
||||||
try:
|
|
||||||
final_input = final_input + ichunk # type: ignore[operator]
|
|
||||||
except TypeError:
|
|
||||||
final_input = ichunk
|
|
||||||
final_input_supported = False
|
|
||||||
else:
|
|
||||||
final_input = ichunk
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e, inputs=final_input)
|
await run_manager.on_chain_error(e, inputs=final_input)
|
||||||
raise
|
raise
|
||||||
@ -3021,12 +3018,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
config = patch_config(
|
config = patch_config(
|
||||||
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
||||||
)
|
)
|
||||||
context = copy_context()
|
with set_config_context(config) as context:
|
||||||
context.run(_set_config_context, config)
|
if i == 0:
|
||||||
if i == 0:
|
input = context.run(step.invoke, input, config, **kwargs)
|
||||||
input = context.run(step.invoke, input, config, **kwargs)
|
else:
|
||||||
else:
|
input = context.run(step.invoke, input, config)
|
||||||
input = context.run(step.invoke, input, config)
|
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
@ -3061,17 +3057,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
config = patch_config(
|
config = patch_config(
|
||||||
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
||||||
)
|
)
|
||||||
context = copy_context()
|
with set_config_context(config) as context:
|
||||||
context.run(_set_config_context, config)
|
if i == 0:
|
||||||
if i == 0:
|
part = functools.partial(step.ainvoke, input, config, **kwargs)
|
||||||
part = functools.partial(step.ainvoke, input, config, **kwargs)
|
else:
|
||||||
else:
|
part = functools.partial(step.ainvoke, input, config)
|
||||||
part = functools.partial(step.ainvoke, input, config)
|
if asyncio_accepts_context():
|
||||||
if asyncio_accepts_context():
|
input = await asyncio.create_task(part(), context=context) # type: ignore
|
||||||
input = await asyncio.create_task(part(), context=context) # type: ignore
|
else:
|
||||||
else:
|
input = await asyncio.create_task(part())
|
||||||
input = await asyncio.create_task(part())
|
# finish the root run
|
||||||
# finish the root run
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
@ -3713,13 +3708,12 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
callbacks=run_manager.get_child(f"map:key:{key}"),
|
callbacks=run_manager.get_child(f"map:key:{key}"),
|
||||||
)
|
)
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
return context.run(
|
||||||
return context.run(
|
step.invoke,
|
||||||
step.invoke,
|
input,
|
||||||
input,
|
child_config,
|
||||||
child_config,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# gather results from all steps
|
# gather results from all steps
|
||||||
try:
|
try:
|
||||||
@ -3764,14 +3758,13 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
config,
|
config,
|
||||||
callbacks=run_manager.get_child(f"map:key:{key}"),
|
callbacks=run_manager.get_child(f"map:key:{key}"),
|
||||||
)
|
)
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
if asyncio_accepts_context():
|
||||||
if asyncio_accepts_context():
|
return await asyncio.create_task( # type: ignore
|
||||||
return await asyncio.create_task( # type: ignore
|
step.ainvoke(input, child_config), context=context
|
||||||
step.ainvoke(input, child_config), context=context
|
)
|
||||||
)
|
else:
|
||||||
else:
|
return await asyncio.create_task(step.ainvoke(input, child_config))
|
||||||
return await asyncio.create_task(step.ainvoke(input, child_config))
|
|
||||||
|
|
||||||
# gather results from all steps
|
# gather results from all steps
|
||||||
try:
|
try:
|
||||||
|
@ -6,7 +6,7 @@ import warnings
|
|||||||
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
|
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
|
||||||
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar, copy_context
|
from contextvars import Context, ContextVar, Token, copy_context
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
|
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
|
||||||
|
|
||||||
@ -115,7 +115,10 @@ var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _set_config_context(config: RunnableConfig) -> None:
|
# This is imported and used in langgraph, so don't break.
|
||||||
|
def _set_config_context(
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> tuple[Token[Optional[RunnableConfig]], Optional[dict[str, Any]]]:
|
||||||
"""Set the child Runnable config + tracing context.
|
"""Set the child Runnable config + tracing context.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -123,7 +126,8 @@ def _set_config_context(config: RunnableConfig) -> None:
|
|||||||
"""
|
"""
|
||||||
from langchain_core.tracers.langchain import LangChainTracer
|
from langchain_core.tracers.langchain import LangChainTracer
|
||||||
|
|
||||||
var_child_runnable_config.set(config)
|
config_token = var_child_runnable_config.set(config)
|
||||||
|
current_context = None
|
||||||
if (
|
if (
|
||||||
(callbacks := config.get("callbacks"))
|
(callbacks := config.get("callbacks"))
|
||||||
and (
|
and (
|
||||||
@ -141,9 +145,32 @@ def _set_config_context(config: RunnableConfig) -> None:
|
|||||||
)
|
)
|
||||||
and (run := tracer.run_map.get(str(parent_run_id)))
|
and (run := tracer.run_map.get(str(parent_run_id)))
|
||||||
):
|
):
|
||||||
from langsmith.run_helpers import _set_tracing_context
|
from langsmith.run_helpers import _set_tracing_context, get_tracing_context
|
||||||
|
|
||||||
|
current_context = get_tracing_context()
|
||||||
_set_tracing_context({"parent": run})
|
_set_tracing_context({"parent": run})
|
||||||
|
return config_token, current_context
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_config_context(
|
||||||
|
config: RunnableConfig, ctx: Optional[Context] = None
|
||||||
|
) -> Generator[Context, None, None]:
|
||||||
|
"""Set the child Runnable config + tracing context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (RunnableConfig): The config to set.
|
||||||
|
"""
|
||||||
|
from langsmith.run_helpers import _set_tracing_context
|
||||||
|
|
||||||
|
ctx = ctx if ctx is not None else copy_context()
|
||||||
|
config_token, current_context = ctx.run(_set_config_context, config)
|
||||||
|
try:
|
||||||
|
yield ctx
|
||||||
|
finally:
|
||||||
|
ctx.run(var_child_runnable_config.reset, config_token)
|
||||||
|
if current_context is not None:
|
||||||
|
ctx.run(_set_tracing_context, current_context)
|
||||||
|
|
||||||
|
|
||||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||||
|
@ -2,7 +2,6 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import typing
|
import typing
|
||||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||||
from contextvars import copy_context
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@ -18,12 +17,12 @@ from typing_extensions import override
|
|||||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
_set_config_context,
|
|
||||||
ensure_config,
|
ensure_config,
|
||||||
get_async_callback_manager_for_config,
|
get_async_callback_manager_for_config,
|
||||||
get_callback_manager_for_config,
|
get_callback_manager_for_config,
|
||||||
get_config_list,
|
get_config_list,
|
||||||
patch_config,
|
patch_config,
|
||||||
|
set_config_context,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
ConfigurableFieldSpec,
|
ConfigurableFieldSpec,
|
||||||
@ -174,14 +173,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
if self.exception_key and last_error is not None:
|
if self.exception_key and last_error is not None:
|
||||||
input[self.exception_key] = last_error
|
input[self.exception_key] = last_error
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
output = context.run(
|
||||||
output = context.run(
|
runnable.invoke,
|
||||||
runnable.invoke,
|
input,
|
||||||
input,
|
config,
|
||||||
config,
|
**kwargs,
|
||||||
**kwargs,
|
)
|
||||||
)
|
|
||||||
except self.exceptions_to_handle as e:
|
except self.exceptions_to_handle as e:
|
||||||
if first_error is None:
|
if first_error is None:
|
||||||
first_error = e
|
first_error = e
|
||||||
@ -228,13 +226,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
if self.exception_key and last_error is not None:
|
if self.exception_key and last_error is not None:
|
||||||
input[self.exception_key] = last_error
|
input[self.exception_key] = last_error
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
coro = context.run(runnable.ainvoke, input, config, **kwargs)
|
||||||
coro = runnable.ainvoke(input, child_config, **kwargs)
|
if asyncio_accepts_context():
|
||||||
if asyncio_accepts_context():
|
output = await asyncio.create_task(coro, context=context) # type: ignore
|
||||||
output = await asyncio.create_task(coro, context=context) # type: ignore
|
else:
|
||||||
else:
|
output = await coro
|
||||||
output = await coro
|
|
||||||
except self.exceptions_to_handle as e:
|
except self.exceptions_to_handle as e:
|
||||||
if first_error is None:
|
if first_error is None:
|
||||||
first_error = e
|
first_error = e
|
||||||
@ -475,14 +472,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
if self.exception_key and last_error is not None:
|
if self.exception_key and last_error is not None:
|
||||||
input[self.exception_key] = last_error
|
input[self.exception_key] = last_error
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
stream = context.run(
|
||||||
stream = context.run(
|
runnable.stream,
|
||||||
runnable.stream,
|
input,
|
||||||
input,
|
**kwargs,
|
||||||
**kwargs,
|
)
|
||||||
)
|
chunk: Output = context.run(next, stream) # type: ignore
|
||||||
chunk: Output = context.run(next, stream) # type: ignore
|
|
||||||
except self.exceptions_to_handle as e:
|
except self.exceptions_to_handle as e:
|
||||||
first_error = e if first_error is None else first_error
|
first_error = e if first_error is None else first_error
|
||||||
last_error = e
|
last_error = e
|
||||||
@ -539,20 +535,19 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
if self.exception_key and last_error is not None:
|
if self.exception_key and last_error is not None:
|
||||||
input[self.exception_key] = last_error
|
input[self.exception_key] = last_error
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
stream = runnable.astream(
|
||||||
stream = runnable.astream(
|
input,
|
||||||
input,
|
child_config,
|
||||||
child_config,
|
**kwargs,
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if asyncio_accepts_context():
|
|
||||||
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
|
|
||||||
py_anext(stream), # type: ignore[arg-type]
|
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
else:
|
if asyncio_accepts_context():
|
||||||
chunk = cast(Output, await py_anext(stream))
|
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
|
||||||
|
py_anext(stream), # type: ignore[arg-type]
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chunk = cast(Output, await py_anext(stream))
|
||||||
except self.exceptions_to_handle as e:
|
except self.exceptions_to_handle as e:
|
||||||
first_error = e if first_error is None else first_error
|
first_error = e if first_error is None else first_error
|
||||||
last_error = e
|
last_error = e
|
||||||
|
@ -6,7 +6,6 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextvars import copy_context
|
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@ -52,7 +51,7 @@ from langchain_core.runnables import (
|
|||||||
patch_config,
|
patch_config,
|
||||||
run_in_executor,
|
run_in_executor,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.config import _set_config_context
|
from langchain_core.runnables.config import set_config_context
|
||||||
from langchain_core.runnables.utils import asyncio_accepts_context
|
from langchain_core.runnables.utils import asyncio_accepts_context
|
||||||
from langchain_core.utils.function_calling import (
|
from langchain_core.utils.function_calling import (
|
||||||
_parse_google_docstring,
|
_parse_google_docstring,
|
||||||
@ -722,14 +721,15 @@ class ChildTool(BaseTool):
|
|||||||
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
||||||
try:
|
try:
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
tool_input, tool_call_id
|
||||||
if signature(self._run).parameters.get("run_manager"):
|
)
|
||||||
tool_kwargs = tool_kwargs | {"run_manager": run_manager}
|
if signature(self._run).parameters.get("run_manager"):
|
||||||
if config_param := _get_runnable_config_param(self._run):
|
tool_kwargs = tool_kwargs | {"run_manager": run_manager}
|
||||||
tool_kwargs = tool_kwargs | {config_param: config}
|
if config_param := _get_runnable_config_param(self._run):
|
||||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
tool_kwargs = tool_kwargs | {config_param: config}
|
||||||
|
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||||
if self.response_format == "content_and_artifact":
|
if self.response_format == "content_and_artifact":
|
||||||
if not isinstance(response, tuple) or len(response) != 2:
|
if not isinstance(response, tuple) or len(response) != 2:
|
||||||
msg = (
|
msg = (
|
||||||
@ -832,21 +832,20 @@ class ChildTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
with set_config_context(child_config) as context:
|
||||||
context.run(_set_config_context, child_config)
|
func_to_check = (
|
||||||
func_to_check = (
|
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
||||||
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
)
|
||||||
)
|
if signature(func_to_check).parameters.get("run_manager"):
|
||||||
if signature(func_to_check).parameters.get("run_manager"):
|
tool_kwargs["run_manager"] = run_manager
|
||||||
tool_kwargs["run_manager"] = run_manager
|
if config_param := _get_runnable_config_param(func_to_check):
|
||||||
if config_param := _get_runnable_config_param(func_to_check):
|
tool_kwargs[config_param] = config
|
||||||
tool_kwargs[config_param] = config
|
|
||||||
|
|
||||||
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
coro = self._arun(*tool_args, **tool_kwargs)
|
||||||
if asyncio_accepts_context():
|
if asyncio_accepts_context():
|
||||||
response = await asyncio.create_task(coro, context=context) # type: ignore
|
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||||
else:
|
else:
|
||||||
response = await coro
|
response = await coro
|
||||||
if self.response_format == "content_and_artifact":
|
if self.response_format == "content_and_artifact":
|
||||||
if not isinstance(response, tuple) or len(response) != 2:
|
if not isinstance(response, tuple) or len(response) != 2:
|
||||||
msg = (
|
msg = (
|
||||||
|
@ -89,11 +89,11 @@ def tracing_v2_enabled(
|
|||||||
tags=tags,
|
tags=tags,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
token = tracing_v2_callback_var.set(cb)
|
||||||
try:
|
try:
|
||||||
tracing_v2_callback_var.set(cb)
|
|
||||||
yield cb
|
yield cb
|
||||||
finally:
|
finally:
|
||||||
tracing_v2_callback_var.set(None)
|
tracing_v2_callback_var.reset(token)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@ -109,9 +109,11 @@ def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
|
|||||||
run_id = runs_cb.traced_runs[0].id
|
run_id = runs_cb.traced_runs[0].id
|
||||||
"""
|
"""
|
||||||
cb = RunCollectorCallbackHandler()
|
cb = RunCollectorCallbackHandler()
|
||||||
run_collector_var.set(cb)
|
token = run_collector_var.set(cb)
|
||||||
yield cb
|
try:
|
||||||
run_collector_var.set(None)
|
yield cb
|
||||||
|
finally:
|
||||||
|
run_collector_var.reset(token)
|
||||||
|
|
||||||
|
|
||||||
def _get_trace_callbacks(
|
def _get_trace_callbacks(
|
||||||
|
Loading…
Reference in New Issue
Block a user