mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
[Enhancement] Add support for directly providing a run_id (#18990)
The root run id (~trace id's) is useful for assigning feedback, but the current recommended approach is to use callbacks to retrieve it, which has some drawbacks: 1. Doesn't work for streaming until after the first event 2. Doesn't let you call other endpoints with the same trace ID in parallel (since you have to wait until the call is completed/started to use This PR lets you provide = "run_id" in the runnable config. Couple considerations: 1. For batch calls, we split the trace up into separate trees (to permit better rendering). We keep the provided run ID for the first one and generate a unique one for other elements of the batch. 2. For nested calls, the provided ID is ONLY used on the top root/trace. ### Example Usage ``` chain.invoke("foo", {"run_id": uuid.uuid4()}) ```
This commit is contained in:
parent
bd329e9aad
commit
780337488e
@ -1183,6 +1183,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[CallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
@ -1197,8 +1198,9 @@ class CallbackManager(BaseCallbackManager):
|
||||
prompt as an LLM run.
|
||||
"""
|
||||
managers = []
|
||||
for prompt in prompts:
|
||||
run_id_ = uuid.uuid4()
|
||||
for i, prompt in enumerate(prompts):
|
||||
# Can't have duplicate runs with the same run ID (if provided)
|
||||
run_id_ = run_id if i == 0 and run_id is not None else uuid.uuid4()
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_llm_start",
|
||||
@ -1231,6 +1233,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[CallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
@ -1247,6 +1250,10 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
managers = []
|
||||
for message_list in messages:
|
||||
if run_id is not None:
|
||||
run_id_ = run_id
|
||||
run_id = None
|
||||
else:
|
||||
run_id_ = uuid.uuid4()
|
||||
handle_event(
|
||||
self.handlers,
|
||||
@ -1520,6 +1527,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[AsyncCallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
@ -1539,6 +1547,10 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
managers = []
|
||||
|
||||
for prompt in prompts:
|
||||
if run_id is not None:
|
||||
run_id_ = run_id
|
||||
run_id = None
|
||||
else:
|
||||
run_id_ = uuid.uuid4()
|
||||
|
||||
tasks.append(
|
||||
@ -1577,6 +1589,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[AsyncCallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
@ -1595,6 +1608,10 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
managers = []
|
||||
|
||||
for message_list in messages:
|
||||
if run_id is not None:
|
||||
run_id_ = run_id
|
||||
run_id = None
|
||||
else:
|
||||
run_id_ = uuid.uuid4()
|
||||
|
||||
tasks.append(
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
@ -234,6 +235,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
@ -312,6 +314,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
@ -371,6 +374,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Pass a sequence of prompts to the model and return model generations.
|
||||
@ -415,6 +419,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=run_name,
|
||||
run_id=run_id,
|
||||
batch_size=len(messages),
|
||||
)
|
||||
results = []
|
||||
@ -456,6 +461,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Asynchronously pass a sequence of prompts to a model and return generations.
|
||||
@ -502,6 +508,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
options=options,
|
||||
name=run_name,
|
||||
batch_size=len(messages),
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
|
@ -7,6 +7,7 @@ import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
@ -271,6 +272,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
**kwargs,
|
||||
)
|
||||
.generations[0][0]
|
||||
@ -293,6 +295,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
**kwargs,
|
||||
)
|
||||
return llm_result.generations[0][0].text
|
||||
@ -423,6 +426,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
generation: Optional[GenerationChunk] = None
|
||||
@ -499,6 +503,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
generation: Optional[GenerationChunk] = None
|
||||
@ -632,6 +637,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
tags: Optional[Union[List[str], List[List[str]]]] = None,
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
run_name: Optional[Union[str, List[str]]] = None,
|
||||
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Pass a sequence of prompts to a model and return generations.
|
||||
@ -717,7 +723,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
)
|
||||
] * len(prompts)
|
||||
run_name_list = [cast(Optional[str], run_name)] * len(prompts)
|
||||
|
||||
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
@ -744,9 +750,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
options=options,
|
||||
name=run_name,
|
||||
batch_size=len(prompts),
|
||||
run_id=run_id_,
|
||||
)[0]
|
||||
for callback_manager, prompt, run_name in zip(
|
||||
callback_managers, prompts, run_name_list
|
||||
for callback_manager, prompt, run_name, run_id_ in zip(
|
||||
callback_managers, prompts, run_name_list, run_ids_list
|
||||
)
|
||||
]
|
||||
output = self._generate_helper(
|
||||
@ -782,6 +789,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
generations = [existing_prompts[i] for i in range(len(prompts))]
|
||||
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
|
||||
|
||||
@staticmethod
|
||||
def _get_run_ids_list(
|
||||
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list
|
||||
) -> list:
|
||||
if run_id is None:
|
||||
return [None] * len(prompts)
|
||||
if isinstance(run_id, list):
|
||||
if len(run_id) != len(prompts):
|
||||
raise ValueError(
|
||||
"Number of manually provided run_id's does not match batch length."
|
||||
f" {len(run_id)} != {len(prompts)}"
|
||||
)
|
||||
return run_id
|
||||
return [run_id] + [None] * (len(prompts) - 1)
|
||||
|
||||
async def _agenerate_helper(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@ -833,6 +855,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
tags: Optional[Union[List[str], List[List[str]]]] = None,
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
run_name: Optional[Union[str, List[str]]] = None,
|
||||
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Asynchronously pass a sequence of prompts to a model and return generations.
|
||||
@ -909,7 +932,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
)
|
||||
] * len(prompts)
|
||||
run_name_list = [cast(Optional[str], run_name)] * len(prompts)
|
||||
|
||||
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
@ -937,9 +960,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
options=options,
|
||||
name=run_name,
|
||||
batch_size=len(prompts),
|
||||
run_id=run_id_,
|
||||
)
|
||||
for callback_manager, prompt, run_name in zip(
|
||||
callback_managers, prompts, run_name_list
|
||||
for callback_manager, prompt, run_name, run_id_ in zip(
|
||||
callback_managers, prompts, run_name_list, run_ids_list
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@ -230,6 +230,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
dumpd(self),
|
||||
query,
|
||||
name=run_name,
|
||||
run_id=kwargs.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
_kwargs = kwargs if self._expects_other_args else {}
|
||||
@ -286,6 +287,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
dumpd(self),
|
||||
query,
|
||||
name=run_name,
|
||||
run_id=kwargs.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
_kwargs = kwargs if self._expects_other_args else {}
|
||||
|
@ -1448,6 +1448,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
@ -1495,6 +1496,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
@ -1547,6 +1549,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for callback_manager, input, config in zip(
|
||||
callback_managers, input, configs
|
||||
@ -1619,6 +1622,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for callback_manager, input, config in zip(
|
||||
callback_managers, input, configs
|
||||
@ -1694,6 +1698,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
@ -1781,6 +1786,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
@ -2262,7 +2268,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name") or self.get_name()
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
@ -2296,7 +2305,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name") or self.get_name()
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
@ -2354,6 +2366,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
@ -2478,6 +2491,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
)
|
||||
@ -2885,7 +2899,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name") or self.get_name()
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
# gather results from all steps
|
||||
@ -2925,7 +2942,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name") or self.get_name()
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
# gather results from all steps
|
||||
|
@ -183,6 +183,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
try:
|
||||
@ -231,6 +232,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
for idx, branch in enumerate(self.branches):
|
||||
@ -282,6 +284,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
@ -356,6 +359,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
|
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
import warnings
|
||||
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar, copy_context
|
||||
@ -95,6 +97,12 @@ class RunnableConfig(TypedDict, total=False):
|
||||
configurable.
|
||||
"""
|
||||
|
||||
run_id: Optional[uuid.UUID]
|
||||
"""
|
||||
Unique identifier for the tracer run for this call. If not provided, a new UUID
|
||||
will be generated.
|
||||
"""
|
||||
|
||||
|
||||
var_child_runnable_config = ContextVar(
|
||||
"child_runnable_config", default=RunnableConfig()
|
||||
@ -116,6 +124,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
metadata={},
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
)
|
||||
if var_config := var_child_runnable_config.get():
|
||||
empty.update(
|
||||
@ -158,11 +167,21 @@ def get_config_list(
|
||||
f"but got {len(config)} configs for {length} inputs"
|
||||
)
|
||||
|
||||
return (
|
||||
list(map(ensure_config, config))
|
||||
if isinstance(config, list)
|
||||
else [ensure_config(config) for _ in range(length)]
|
||||
if isinstance(config, list):
|
||||
return list(map(ensure_config, config))
|
||||
if length > 1 and isinstance(config, dict) and config.get("run_id") is not None:
|
||||
warnings.warn(
|
||||
"Provided run_id be used only for the first element of the batch.",
|
||||
category=RuntimeWarning,
|
||||
)
|
||||
subsequent = cast(
|
||||
RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}
|
||||
)
|
||||
return [
|
||||
ensure_config(subsequent) if i else ensure_config(config)
|
||||
for i in range(length)
|
||||
]
|
||||
return [ensure_config(config) for i in range(length)]
|
||||
|
||||
|
||||
def patch_config(
|
||||
@ -199,6 +218,8 @@ def patch_config(
|
||||
config["callbacks"] = callbacks
|
||||
if "run_name" in config:
|
||||
del config["run_name"]
|
||||
if "run_id" in config:
|
||||
del config["run_id"]
|
||||
if recursion_limit is not None:
|
||||
config["recursion_limit"] = recursion_limit
|
||||
if max_concurrency is not None:
|
||||
|
@ -156,7 +156,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
first_error = None
|
||||
last_error = None
|
||||
@ -200,7 +203,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
first_error = None
|
||||
@ -270,6 +276,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
@ -362,6 +369,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
)
|
||||
@ -436,7 +444,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
first_error = None
|
||||
last_error = None
|
||||
@ -493,7 +504,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
first_error = None
|
||||
last_error = None
|
||||
|
15
libs/core/langchain_core/runnables/learnable.py
Normal file
15
libs/core/langchain_core/runnables/learnable.py
Normal file
@ -0,0 +1,15 @@
|
||||
# from langchain_core.runnables.base import RunnableBinding
|
||||
|
||||
|
||||
# class RunnableLearnable(RunnableBinding):
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# super().__init__(*args, **kwargs)
|
||||
# self.parameters = []
|
||||
|
||||
# def backward(self):
|
||||
# for param in self.parameters:
|
||||
# param.backward()
|
||||
|
||||
# def update(self, optimizer):
|
||||
# for param in self.parameters:
|
||||
# optimizer.update(param)
|
@ -20,6 +20,7 @@ tool for the job.
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from inspect import signature
|
||||
@ -243,6 +244,7 @@ class ChildTool(BaseTool):
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -259,6 +261,7 @@ class ChildTool(BaseTool):
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -339,6 +342,7 @@ class ChildTool(BaseTool):
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
@ -362,6 +366,7 @@ class ChildTool(BaseTool):
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
color=start_color,
|
||||
name=run_name,
|
||||
run_id=run_id,
|
||||
# Inputs by definition should always be dicts.
|
||||
# For now, it's unclear whether this assumption is ever violated,
|
||||
# but if it is we will send a `None` value to the callback instead
|
||||
@ -430,6 +435,7 @@ class ChildTool(BaseTool):
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
@ -453,6 +459,7 @@ class ChildTool(BaseTool):
|
||||
color=start_color,
|
||||
name=run_name,
|
||||
inputs=tool_input,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
|
@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import uuid
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
@ -136,6 +137,22 @@ class FakeTracer(BaseTracer):
|
||||
|
||||
self.runs.append(self._copy_run(run))
|
||||
|
||||
def flattened_runs(self) -> List[Run]:
|
||||
q = [] + self.runs
|
||||
result = []
|
||||
while q:
|
||||
parent = q.pop()
|
||||
result.append(parent)
|
||||
if parent.child_runs:
|
||||
q.extend(parent.child_runs)
|
||||
return result
|
||||
|
||||
@property
|
||||
def run_ids(self) -> List[Optional[uuid.UUID]]:
|
||||
runs = self.flattened_runs()
|
||||
uuids_map = {v: k for k, v in self.uuids_map.items()}
|
||||
return [uuids_map.get(r.id) for r in runs]
|
||||
|
||||
|
||||
class FakeRunnable(Runnable[str, int]):
|
||||
def invoke(
|
||||
@ -1367,6 +1384,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
|
||||
recursion_limit=25,
|
||||
configurable={"hello": "there"},
|
||||
metadata={"hello": "there", "bye": "now"},
|
||||
run_id=None,
|
||||
),
|
||||
)
|
||||
spy.reset_mock()
|
||||
@ -1508,6 +1526,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
mocker.call(
|
||||
@ -1517,6 +1536,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
]
|
||||
@ -1542,6 +1562,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
)
|
||||
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
|
||||
@ -1552,6 +1573,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
)
|
||||
|
||||
@ -1620,6 +1642,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
mocker.call(
|
||||
@ -1629,6 +1652,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
]
|
||||
@ -4822,27 +4846,45 @@ async def test_runnable_gen_context_config() -> None:
|
||||
}
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert runnable.invoke(None, {"callbacks": [tracer]}) == 6
|
||||
run_id = uuid.uuid4()
|
||||
assert runnable.invoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
run_ids = tracer.run_ids
|
||||
assert run_id in run_ids
|
||||
assert len(run_ids) == len(set(run_ids))
|
||||
tracer.runs.clear()
|
||||
|
||||
assert list(runnable.stream(None)) == [1, 2, 3]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3]
|
||||
run_id = uuid.uuid4()
|
||||
assert list(runnable.stream(None, {"callbacks": [tracer], "run_id": run_id})) == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
]
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
run_ids = tracer.run_ids
|
||||
assert run_id in run_ids
|
||||
assert len(run_ids) == len(set(run_ids))
|
||||
tracer.runs.clear()
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6]
|
||||
run_id = uuid.uuid4()
|
||||
|
||||
with pytest.warns(RuntimeWarning):
|
||||
assert runnable.batch(
|
||||
[None, None], {"callbacks": [tracer], "run_id": run_id}
|
||||
) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
@ -4865,19 +4907,30 @@ async def test_runnable_gen_context_config() -> None:
|
||||
arunnable = RunnableGenerator(agen)
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6
|
||||
|
||||
run_id = uuid.uuid4()
|
||||
assert await arunnable.ainvoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
run_ids = tracer.run_ids
|
||||
assert run_id in run_ids
|
||||
assert len(run_ids) == len(set(run_ids))
|
||||
tracer.runs.clear()
|
||||
|
||||
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [
|
||||
run_id = uuid.uuid4()
|
||||
assert [
|
||||
p
|
||||
async for p in arunnable.astream(
|
||||
None, {"callbacks": [tracer], "run_id": run_id}
|
||||
)
|
||||
] == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
@ -4887,9 +4940,16 @@ async def test_runnable_gen_context_config() -> None:
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
run_ids = tracer.run_ids
|
||||
assert run_id in run_ids
|
||||
assert len(run_ids) == len(set(run_ids))
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6]
|
||||
run_id = uuid.uuid4()
|
||||
with pytest.warns(RuntimeWarning):
|
||||
assert await arunnable.abatch(
|
||||
[None, None], {"callbacks": [tracer], "run_id": run_id}
|
||||
) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
|
Loading…
Reference in New Issue
Block a user