Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
293bae133f RFC [Core] Tracing - unify tracing 2024-06-10 08:26:00 -07:00
7 changed files with 33 additions and 45 deletions

View File

@@ -45,7 +45,6 @@ from langchain_core.load.serializable import (
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
acall_func_with_variable_args,
call_func_with_variable_args,
ensure_config,
@@ -56,6 +55,7 @@ from langchain_core.runnables.config import (
merge_configs,
patch_config,
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.schema import StreamEvent
@@ -1590,7 +1590,7 @@ class Runnable(Generic[Input, Output], ABC):
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
context.run(var_child_runnable_config.set, child_config)
output = cast(
Output,
context.run(
@@ -1638,7 +1638,7 @@ class Runnable(Generic[Input, Output], ABC):
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
context.run(var_child_runnable_config.set, child_config)
coro = acall_func_with_variable_args(
func, input, config, run_manager, **kwargs
)
@@ -1847,7 +1847,7 @@ class Runnable(Generic[Input, Output], ABC):
if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager
context = copy_context()
context.run(_set_config_context, child_config)
context.run(var_child_runnable_config.set, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
(
@@ -1947,7 +1947,7 @@ class Runnable(Generic[Input, Output], ABC):
if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager
context = copy_context()
context.run(_set_config_context, child_config)
context.run(var_child_runnable_config.set, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(

View File

@@ -110,25 +110,6 @@ var_child_runnable_config = ContextVar(
)
def _set_config_context(config: RunnableConfig) -> None:
"""Set the child runnable config + tracing context
Args:
config (RunnableConfig): The config to set.
"""
from langsmith import (
RunTree, # type: ignore
run_helpers, # type: ignore
)
var_child_runnable_config.set(config)
if hasattr(RunTree, "from_runnable_config"):
# import _set_tracing_context, get_tracing_context
rt = RunTree.from_runnable_config(dict(config))
tc = run_helpers.get_tracing_context()
run_helpers._set_tracing_context({**tc, "parent": rt})
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
"""Ensure that a config is a dict with all keys present.

View File

@@ -65,9 +65,9 @@ from langchain_core.runnables import (
ensure_config,
)
from langchain_core.runnables.config import (
_set_config_context,
patch_config,
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.utils import accepts_context
@@ -402,7 +402,7 @@ class ChildTool(BaseTool):
callbacks=run_manager.get_child(),
)
context = copy_context()
context.run(_set_config_context, child_config)
context.run(var_child_runnable_config.set, child_config)
parsed_input = self._parse_input(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
@@ -502,7 +502,7 @@ class ChildTool(BaseTool):
callbacks=run_manager.get_child(),
)
context = copy_context()
context.run(_set_config_context, child_config)
context.run(var_child_runnable_config.set, child_config)
coro = (
context.run(
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs

View File

@@ -10,6 +10,7 @@ from uuid import UUID
from langsmith import Client
from langsmith import utils as ls_utils
from langsmith.run_helpers import _set_tracing_context
from tenacity import (
Retrying,
retry_if_exception_type,
@@ -65,7 +66,7 @@ def _get_executor() -> ThreadPoolExecutor:
def _run_to_dict(run: Run) -> dict:
return {
**run.dict(exclude={"child_runs", "inputs", "outputs"}),
**run.dict(exclude={"child_runs", "inputs", "outputs", "client", "parent_run"}),
"inputs": run.inputs.copy() if run.inputs is not None else None,
"outputs": run.outputs.copy() if run.outputs is not None else None,
}
@@ -125,6 +126,7 @@ class LangChainTracer(BaseTracer):
return chat_model_run
def _persist_run(self, run: Run) -> None:
breakpoint()
run_ = run.copy()
run_.reference_example_id = self.example_id
self.latest_run = run_
@@ -155,6 +157,11 @@ class LangChainTracer(BaseTracer):
def _persist_run_single(self, run: Run) -> None:
"""Persist a run."""
run.client = self.client
if run.parent_run_id:
if str(run.parent_run_id) in self.run_map:
run.parent_run = self.run_map[str(run.parent_run_id)]
_set_tracing_context({"parent": run})
run_dict = _run_to_dict(run)
run_dict["tags"] = self._get_tags(run)
extra = run_dict.get("extra", {})
@@ -169,6 +176,7 @@ class LangChainTracer(BaseTracer):
def _update_run_single(self, run: Run) -> None:
"""Update a run."""
_set_tracing_context({"parent": run.parent_run})
try:
run_dict = _run_to_dict(run)
run_dict["tags"] = self._get_tags(run)

View File

@@ -1,4 +1,5 @@
"""Schemas for tracers."""
from __future__ import annotations
import datetime
@@ -6,8 +7,8 @@ import warnings
from typing import Any, Dict, List, Optional, Type
from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from langsmith import RunTree
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep, RunBase as BaseRunDep # type: ignore
from langchain_core._api import deprecated
from langchain_core.outputs import LLMResult
@@ -110,14 +111,13 @@ class ToolRun(BaseRun):
# Begin V2 API Schemas
class Run(BaseRunV2):
# TODO: rm client API key validation
class Run(RunTree):
"""Run schema for the V2 API in the Tracer."""
child_runs: List[Run] = Field(default_factory=list)
tags: Optional[List[str]] = Field(default_factory=list)
events: List[Dict[str, Any]] = Field(default_factory=list)
trace_id: Optional[UUID] = None
dotted_order: Optional[str] = None
@root_validator(pre=True)
def assign_name(cls, values: dict) -> dict:

File diff suppressed because one or more lines are too long

View File

@@ -2931,7 +2931,6 @@ def test_prompt_with_chat_model_and_parser(
]
)
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
assert tracer.runs == snapshot