1
0
mirror of https://github.com/hwchase17/langchain.git synced 2025-05-04 22:58:42 +00:00

Merge remote-tracking branch 'upstream/master' into integration-update

This commit is contained in:
pulvedu 2025-03-19 14:47:27 -04:00
commit e5f6239aed
5 changed files with 227 additions and 211 deletions
libs/core/langchain_core

View File

@ -17,7 +17,6 @@ from collections.abc import (
Sequence,
)
from concurrent.futures import FIRST_COMPLETED, wait
from contextvars import copy_context
from functools import wraps
from itertools import groupby, tee
from operator import itemgetter
@ -47,7 +46,6 @@ from langchain_core.load.serializable import (
)
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
acall_func_with_variable_args,
call_func_with_variable_args,
ensure_config,
@ -58,6 +56,7 @@ from langchain_core.runnables.config import (
merge_configs,
patch_config,
run_in_executor,
set_config_context,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
@ -1920,19 +1919,18 @@ class Runnable(Generic[Input, Output], ABC):
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
output = cast(
Output,
context.run(
call_func_with_variable_args, # type: ignore[arg-type]
func, # type: ignore[arg-type]
input, # type: ignore[arg-type]
config,
run_manager,
**kwargs,
),
)
with set_config_context(child_config) as context:
output = cast(
Output,
context.run(
call_func_with_variable_args, # type: ignore[arg-type]
func, # type: ignore[arg-type]
input, # type: ignore[arg-type]
config,
run_manager,
**kwargs,
),
)
except BaseException as e:
run_manager.on_chain_error(e)
raise
@ -1970,15 +1968,14 @@ class Runnable(Generic[Input, Output], ABC):
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
coro = acall_func_with_variable_args(
func, input, config, run_manager, **kwargs
)
if asyncio_accepts_context():
output: Output = await asyncio.create_task(coro, context=context) # type: ignore
else:
output = await coro
with set_config_context(child_config) as context:
coro = acall_func_with_variable_args(
func, input, config, run_manager, **kwargs
)
if asyncio_accepts_context():
output: Output = await asyncio.create_task(coro, context=context) # type: ignore
else:
output = await coro
except BaseException as e:
await run_manager.on_chain_error(e)
raise
@ -2182,49 +2179,50 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["config"] = child_config
if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager
context = copy_context()
context.run(_set_config_context, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
# instance check OK here, it's a mixin
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
),
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator)
try:
while True:
chunk: Output = context.run(next, iterator) # type: ignore
yield chunk
if final_output_supported:
if final_output is None:
with set_config_context(child_config) as context:
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
# instance check OK here, it's a mixin
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
),
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_iter(
run_manager.run_id, iterator
)
try:
while True:
chunk: Output = context.run(next, iterator) # type: ignore
yield chunk
if final_output_supported:
if final_output is None:
final_output = chunk
else:
try:
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = chunk
final_output_supported = False
else:
final_output = chunk
except (StopIteration, GeneratorExit):
pass
for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_output = final_output + chunk # type: ignore
final_input = final_input + ichunk # type: ignore
except TypeError:
final_output = chunk
final_output_supported = False
final_input = ichunk
final_input_supported = False
else:
final_output = chunk
except (StopIteration, GeneratorExit):
pass
for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input = final_input + ichunk # type: ignore
except TypeError:
final_input = ichunk
final_input_supported = False
else:
final_input = ichunk
except BaseException as e:
run_manager.on_chain_error(e, inputs=final_input)
raise
@ -2283,60 +2281,59 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["config"] = child_config
if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager
context = copy_context()
context.run(_set_config_context, child_config)
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
with set_config_context(child_config) as context:
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
# instance check OK here, it's a mixin
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
),
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_aiter(
run_manager.run_id, iterator_
)
else:
iterator = iterator_
try:
while True:
if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(iterator), # type: ignore[arg-type]
context=context,
)
else:
chunk = cast(Output, await py_anext(iterator))
yield chunk
if final_output_supported:
if final_output is None:
if stream_handler := next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
# instance check OK here, it's a mixin
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
),
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_aiter(
run_manager.run_id, iterator_
)
else:
iterator = iterator_
try:
while True:
if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(iterator), # type: ignore[arg-type]
context=context,
)
else:
chunk = cast(Output, await py_anext(iterator))
yield chunk
if final_output_supported:
if final_output is None:
final_output = chunk
else:
try:
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = chunk
final_output_supported = False
else:
final_output = chunk
except StopAsyncIteration:
pass
async for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_output = final_output + chunk # type: ignore
final_input = final_input + ichunk # type: ignore[operator]
except TypeError:
final_output = chunk
final_output_supported = False
final_input = ichunk
final_input_supported = False
else:
final_output = chunk
except StopAsyncIteration:
pass
async for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input = final_input + ichunk # type: ignore[operator]
except TypeError:
final_input = ichunk
final_input_supported = False
else:
final_input = ichunk
except BaseException as e:
await run_manager.on_chain_error(e, inputs=final_input)
raise
@ -3021,12 +3018,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
)
context = copy_context()
context.run(_set_config_context, config)
if i == 0:
input = context.run(step.invoke, input, config, **kwargs)
else:
input = context.run(step.invoke, input, config)
with set_config_context(config) as context:
if i == 0:
input = context.run(step.invoke, input, config, **kwargs)
else:
input = context.run(step.invoke, input, config)
# finish the root run
except BaseException as e:
run_manager.on_chain_error(e)
@ -3061,17 +3057,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
)
context = copy_context()
context.run(_set_config_context, config)
if i == 0:
part = functools.partial(step.ainvoke, input, config, **kwargs)
else:
part = functools.partial(step.ainvoke, input, config)
if asyncio_accepts_context():
input = await asyncio.create_task(part(), context=context) # type: ignore
else:
input = await asyncio.create_task(part())
# finish the root run
with set_config_context(config) as context:
if i == 0:
part = functools.partial(step.ainvoke, input, config, **kwargs)
else:
part = functools.partial(step.ainvoke, input, config)
if asyncio_accepts_context():
input = await asyncio.create_task(part(), context=context) # type: ignore
else:
input = await asyncio.create_task(part())
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)
raise
@ -3713,13 +3708,12 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
# mark each step as a child run
callbacks=run_manager.get_child(f"map:key:{key}"),
)
context = copy_context()
context.run(_set_config_context, child_config)
return context.run(
step.invoke,
input,
child_config,
)
with set_config_context(child_config) as context:
return context.run(
step.invoke,
input,
child_config,
)
# gather results from all steps
try:
@ -3764,14 +3758,13 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
config,
callbacks=run_manager.get_child(f"map:key:{key}"),
)
context = copy_context()
context.run(_set_config_context, child_config)
if asyncio_accepts_context():
return await asyncio.create_task( # type: ignore
step.ainvoke(input, child_config), context=context
)
else:
return await asyncio.create_task(step.ainvoke(input, child_config))
with set_config_context(child_config) as context:
if asyncio_accepts_context():
return await asyncio.create_task( # type: ignore
step.ainvoke(input, child_config), context=context
)
else:
return await asyncio.create_task(step.ainvoke(input, child_config))
# gather results from all steps
try:

View File

@ -6,7 +6,7 @@ import warnings
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import ContextVar, copy_context
from contextvars import Context, ContextVar, Token, copy_context
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
@ -115,7 +115,10 @@ var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar(
)
def _set_config_context(config: RunnableConfig) -> None:
# This is imported and used in langgraph, so don't break.
def _set_config_context(
config: RunnableConfig,
) -> tuple[Token[Optional[RunnableConfig]], Optional[dict[str, Any]]]:
"""Set the child Runnable config + tracing context.
Args:
@ -123,7 +126,8 @@ def _set_config_context(config: RunnableConfig) -> None:
"""
from langchain_core.tracers.langchain import LangChainTracer
var_child_runnable_config.set(config)
config_token = var_child_runnable_config.set(config)
current_context = None
if (
(callbacks := config.get("callbacks"))
and (
@ -141,9 +145,32 @@ def _set_config_context(config: RunnableConfig) -> None:
)
and (run := tracer.run_map.get(str(parent_run_id)))
):
from langsmith.run_helpers import _set_tracing_context
from langsmith.run_helpers import _set_tracing_context, get_tracing_context
current_context = get_tracing_context()
_set_tracing_context({"parent": run})
return config_token, current_context
@contextmanager
def set_config_context(
config: RunnableConfig, ctx: Optional[Context] = None
) -> Generator[Context, None, None]:
"""Set the child Runnable config + tracing context.
Args:
config (RunnableConfig): The config to set.
"""
from langsmith.run_helpers import _set_tracing_context
ctx = ctx if ctx is not None else copy_context()
config_token, current_context = ctx.run(_set_config_context, config)
try:
yield ctx
finally:
ctx.run(var_child_runnable_config.reset, config_token)
if current_context is not None:
ctx.run(_set_tracing_context, current_context)
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:

View File

@ -2,7 +2,6 @@ import asyncio
import inspect
import typing
from collections.abc import AsyncIterator, Iterator, Sequence
from contextvars import copy_context
from functools import wraps
from typing import (
TYPE_CHECKING,
@ -18,12 +17,12 @@ from typing_extensions import override
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
get_config_list,
patch_config,
set_config_context,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
@ -174,14 +173,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
output = context.run(
runnable.invoke,
input,
config,
**kwargs,
)
with set_config_context(child_config) as context:
output = context.run(
runnable.invoke,
input,
config,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
@ -228,13 +226,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
coro = runnable.ainvoke(input, child_config, **kwargs)
if asyncio_accepts_context():
output = await asyncio.create_task(coro, context=context) # type: ignore
else:
output = await coro
with set_config_context(child_config) as context:
coro = context.run(runnable.ainvoke, input, config, **kwargs)
if asyncio_accepts_context():
output = await asyncio.create_task(coro, context=context) # type: ignore
else:
output = await coro
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
@ -475,14 +472,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
stream = context.run(
runnable.stream,
input,
**kwargs,
)
chunk: Output = context.run(next, stream) # type: ignore
with set_config_context(child_config) as context:
stream = context.run(
runnable.stream,
input,
**kwargs,
)
chunk: Output = context.run(next, stream) # type: ignore
except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error
last_error = e
@ -539,20 +535,19 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
stream = runnable.astream(
input,
child_config,
**kwargs,
)
if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(stream), # type: ignore[arg-type]
context=context,
with set_config_context(child_config) as context:
stream = runnable.astream(
input,
child_config,
**kwargs,
)
else:
chunk = cast(Output, await py_anext(stream))
if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(stream), # type: ignore[arg-type]
context=context,
)
else:
chunk = cast(Output, await py_anext(stream))
except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error
last_error = e

View File

@ -6,7 +6,6 @@ import inspect
import json
import warnings
from abc import ABC, abstractmethod
from contextvars import copy_context
from inspect import signature
from typing import (
TYPE_CHECKING,
@ -52,7 +51,7 @@ from langchain_core.runnables import (
patch_config,
run_in_executor,
)
from langchain_core.runnables.config import _set_config_context
from langchain_core.runnables.config import set_config_context
from langchain_core.runnables.utils import asyncio_accepts_context
from langchain_core.utils.function_calling import (
_parse_google_docstring,
@ -722,14 +721,15 @@ class ChildTool(BaseTool):
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs = tool_kwargs | {"run_manager": run_manager}
if config_param := _get_runnable_config_param(self._run):
tool_kwargs = tool_kwargs | {config_param: config}
response = context.run(self._run, *tool_args, **tool_kwargs)
with set_config_context(child_config) as context:
tool_args, tool_kwargs = self._to_args_and_kwargs(
tool_input, tool_call_id
)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs = tool_kwargs | {"run_manager": run_manager}
if config_param := _get_runnable_config_param(self._run):
tool_kwargs = tool_kwargs | {config_param: config}
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
msg = (
@ -832,21 +832,20 @@ class ChildTool(BaseTool):
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
with set_config_context(child_config) as context:
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
coro = self._arun(*tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
msg = (

View File

@ -89,11 +89,11 @@ def tracing_v2_enabled(
tags=tags,
client=client,
)
token = tracing_v2_callback_var.set(cb)
try:
tracing_v2_callback_var.set(cb)
yield cb
finally:
tracing_v2_callback_var.set(None)
tracing_v2_callback_var.reset(token)
@contextmanager
@ -109,9 +109,11 @@ def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
run_id = runs_cb.traced_runs[0].id
"""
cb = RunCollectorCallbackHandler()
run_collector_var.set(cb)
yield cb
run_collector_var.set(None)
token = run_collector_var.set(cb)
try:
yield cb
finally:
run_collector_var.reset(token)
def _get_trace_callbacks(