mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
[Breaking] Refactor Base Tracer(#4549)
### Refactor the BaseTracer - Remove the 'session' abstraction from the BaseTracer - Rename 'RunV2' object(s) to be called 'Run' objects (Rename previous Run objects to be RunV1 objects) - Ditto for sessions: TracerSession*V2 -> TracerSession* - Remove now deprecated conversion from v1 run objects to v2 run objects in LangChainTracerV2 - Add conversion from v2 run objects to v1 run objects in V1 tracer
This commit is contained in:
parent
1e322ffc1c
commit
928cdd57a4
@ -20,9 +20,9 @@ from langchain.callbacks.base import (
|
|||||||
)
|
)
|
||||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||||
from langchain.callbacks.tracers.base import TracerSession
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
|
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
||||||
from langchain.callbacks.tracers.schemas import TracerSessionV2
|
from langchain.callbacks.tracers.schemas import TracerSession
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AgentAction,
|
AgentAction,
|
||||||
AgentFinish,
|
AgentFinish,
|
||||||
@ -37,11 +37,13 @@ Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
|||||||
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
|
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
|
||||||
"openai_callback", default=None
|
"openai_callback", default=None
|
||||||
)
|
)
|
||||||
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
|
tracing_callback_var: ContextVar[
|
||||||
|
Optional[LangChainTracerV1]
|
||||||
|
] = ContextVar( # noqa: E501
|
||||||
"tracing_callback", default=None
|
"tracing_callback", default=None
|
||||||
)
|
)
|
||||||
tracing_v2_callback_var: ContextVar[
|
tracing_v2_callback_var: ContextVar[
|
||||||
Optional[LangChainTracerV2]
|
Optional[LangChainTracer]
|
||||||
] = ContextVar( # noqa: E501
|
] = ContextVar( # noqa: E501
|
||||||
"tracing_callback_v2", default=None
|
"tracing_callback_v2", default=None
|
||||||
)
|
)
|
||||||
@ -59,10 +61,10 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def tracing_enabled(
|
def tracing_enabled(
|
||||||
session_name: str = "default",
|
session_name: str = "default",
|
||||||
) -> Generator[TracerSession, None, None]:
|
) -> Generator[TracerSessionV1, None, None]:
|
||||||
"""Get Tracer in a context manager."""
|
"""Get Tracer in a context manager."""
|
||||||
cb = LangChainTracer()
|
cb = LangChainTracerV1()
|
||||||
session = cast(TracerSession, cb.load_session(session_name))
|
session = cast(TracerSessionV1, cb.load_session(session_name))
|
||||||
tracing_callback_var.set(cb)
|
tracing_callback_var.set(cb)
|
||||||
yield session
|
yield session
|
||||||
tracing_callback_var.set(None)
|
tracing_callback_var.set(None)
|
||||||
@ -70,9 +72,12 @@ def tracing_enabled(
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def tracing_v2_enabled(
|
def tracing_v2_enabled(
|
||||||
session_name: str = "default",
|
session_name: Optional[str] = None,
|
||||||
|
*,
|
||||||
example_id: Optional[Union[str, UUID]] = None,
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
) -> Generator[TracerSessionV2, None, None]:
|
tenant_id: Optional[str] = None,
|
||||||
|
session_extra: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Generator[TracerSession, None, None]:
|
||||||
"""Get the experimental tracer handler in a context manager."""
|
"""Get the experimental tracer handler in a context manager."""
|
||||||
# Issue a warning that this is experimental
|
# Issue a warning that this is experimental
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@ -81,11 +86,16 @@ def tracing_v2_enabled(
|
|||||||
)
|
)
|
||||||
if isinstance(example_id, str):
|
if isinstance(example_id, str):
|
||||||
example_id = UUID(example_id)
|
example_id = UUID(example_id)
|
||||||
cb = LangChainTracerV2(example_id=example_id)
|
cb = LangChainTracer(
|
||||||
session = cast(TracerSessionV2, cb.new_session(session_name))
|
tenant_id=tenant_id,
|
||||||
tracing_callback_var.set(cb)
|
session_name=session_name,
|
||||||
|
example_id=example_id,
|
||||||
|
session_extra=session_extra,
|
||||||
|
)
|
||||||
|
session = cb.ensure_session()
|
||||||
|
tracing_v2_callback_var.set(cb)
|
||||||
yield session
|
yield session
|
||||||
tracing_callback_var.set(None)
|
tracing_v2_callback_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
def _handle_event(
|
def _handle_event(
|
||||||
@ -829,32 +839,35 @@ def _configure(
|
|||||||
tracer_session = os.environ.get("LANGCHAIN_SESSION")
|
tracer_session = os.environ.get("LANGCHAIN_SESSION")
|
||||||
if tracer_session is None:
|
if tracer_session is None:
|
||||||
tracer_session = "default"
|
tracer_session = "default"
|
||||||
if verbose or tracing_enabled_ or open_ai is not None:
|
if verbose or tracing_enabled_ or tracing_v2_enabled_ or open_ai is not None:
|
||||||
if verbose and not any(
|
if verbose and not any(
|
||||||
isinstance(handler, StdOutCallbackHandler)
|
isinstance(handler, StdOutCallbackHandler)
|
||||||
for handler in callback_manager.handlers
|
for handler in callback_manager.handlers
|
||||||
):
|
):
|
||||||
callback_manager.add_handler(StdOutCallbackHandler(), False)
|
callback_manager.add_handler(StdOutCallbackHandler(), False)
|
||||||
if tracing_enabled_ and not any(
|
if tracing_enabled_ and not any(
|
||||||
isinstance(handler, LangChainTracer)
|
isinstance(handler, LangChainTracerV1)
|
||||||
for handler in callback_manager.handlers
|
for handler in callback_manager.handlers
|
||||||
):
|
):
|
||||||
if tracer:
|
if tracer:
|
||||||
callback_manager.add_handler(tracer, True)
|
callback_manager.add_handler(tracer, True)
|
||||||
else:
|
else:
|
||||||
handler = LangChainTracer()
|
handler = LangChainTracerV1()
|
||||||
handler.load_session(tracer_session)
|
handler.load_session(tracer_session)
|
||||||
callback_manager.add_handler(handler, True)
|
callback_manager.add_handler(handler, True)
|
||||||
if tracing_v2_enabled_ and not any(
|
if tracing_v2_enabled_ and not any(
|
||||||
isinstance(handler, LangChainTracerV2)
|
isinstance(handler, LangChainTracer)
|
||||||
for handler in callback_manager.handlers
|
for handler in callback_manager.handlers
|
||||||
):
|
):
|
||||||
if tracer_v2:
|
if tracer_v2:
|
||||||
callback_manager.add_handler(tracer_v2, True)
|
callback_manager.add_handler(tracer_v2, True)
|
||||||
else:
|
else:
|
||||||
handler = LangChainTracerV2()
|
try:
|
||||||
handler.load_session(tracer_session)
|
handler = LangChainTracer(session_name=tracer_session)
|
||||||
|
handler.ensure_session()
|
||||||
callback_manager.add_handler(handler, True)
|
callback_manager.add_handler(handler, True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Unable to load requested LangChainTracer", e)
|
||||||
if open_ai is not None and not any(
|
if open_ai is not None and not any(
|
||||||
isinstance(handler, OpenAICallbackHandler)
|
isinstance(handler, OpenAICallbackHandler)
|
||||||
for handler in callback_manager.handlers
|
for handler in callback_manager.handlers
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Tracers that record execution of LangChain runs."""
|
"""Tracers that record execution of LangChain runs."""
|
||||||
|
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
|
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1
|
||||||
|
|
||||||
__all__ = ["LangChainTracer"]
|
__all__ = ["LangChainTracer", "LangChainTracerV1"]
|
||||||
|
@ -7,15 +7,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.callbacks.tracers.schemas import (
|
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
||||||
ChainRun,
|
|
||||||
LLMRun,
|
|
||||||
ToolRun,
|
|
||||||
TracerSession,
|
|
||||||
TracerSessionBase,
|
|
||||||
TracerSessionCreate,
|
|
||||||
TracerSessionV2,
|
|
||||||
)
|
|
||||||
from langchain.schema import LLMResult
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
|
||||||
@ -28,89 +20,45 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
|
self.run_map: Dict[str, Run] = {}
|
||||||
self.session: Optional[Union[TracerSession, TracerSessionV2]] = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _add_child_run(
|
def _add_child_run(
|
||||||
parent_run: Union[ChainRun, ToolRun],
|
parent_run: Run,
|
||||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
child_run: Run,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add child run to a chain run or tool run."""
|
"""Add child run to a chain run or tool run."""
|
||||||
if isinstance(child_run, LLMRun):
|
parent_run.child_runs.append(child_run)
|
||||||
parent_run.child_llm_runs.append(child_run)
|
|
||||||
elif isinstance(child_run, ChainRun):
|
|
||||||
parent_run.child_chain_runs.append(child_run)
|
|
||||||
elif isinstance(child_run, ToolRun):
|
|
||||||
parent_run.child_tool_runs.append(child_run)
|
|
||||||
else:
|
|
||||||
raise TracerException(f"Invalid run type: {type(child_run)}")
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
def _persist_run(self, run: Run) -> None:
|
||||||
"""Persist a run."""
|
"""Persist a run."""
|
||||||
|
|
||||||
@abstractmethod
|
def _start_trace(self, run: Run) -> None:
|
||||||
def _persist_session(
|
|
||||||
self, session: TracerSessionBase
|
|
||||||
) -> Union[TracerSession, TracerSessionV2]:
|
|
||||||
"""Persist a tracing session."""
|
|
||||||
|
|
||||||
def _get_session_create(
|
|
||||||
self, name: Optional[str] = None, **kwargs: Any
|
|
||||||
) -> TracerSessionBase:
|
|
||||||
return TracerSessionCreate(name=name, extra=kwargs)
|
|
||||||
|
|
||||||
def new_session(
|
|
||||||
self, name: Optional[str] = None, **kwargs: Any
|
|
||||||
) -> Union[TracerSession, TracerSessionV2]:
|
|
||||||
"""NOT thread safe, do not call this method from multiple threads."""
|
|
||||||
session_create = self._get_session_create(name=name, **kwargs)
|
|
||||||
session = self._persist_session(session_create)
|
|
||||||
self.session = session
|
|
||||||
return session
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
|
|
||||||
"""Load a tracing session and set it as the Tracer's session."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
|
|
||||||
"""Load the default tracing session and set it as the Tracer's session."""
|
|
||||||
|
|
||||||
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
|
||||||
"""Start a trace for a run."""
|
"""Start a trace for a run."""
|
||||||
if run.parent_uuid:
|
if run.parent_run_id:
|
||||||
parent_run = self.run_map[run.parent_uuid]
|
parent_run = self.run_map[str(run.parent_run_id)]
|
||||||
if parent_run:
|
if parent_run:
|
||||||
if isinstance(parent_run, LLMRun):
|
|
||||||
raise TracerException(
|
|
||||||
"Cannot add child run to an LLM run. "
|
|
||||||
"LLM runs are not allowed to have children."
|
|
||||||
)
|
|
||||||
self._add_child_run(parent_run, run)
|
self._add_child_run(parent_run, run)
|
||||||
else:
|
else:
|
||||||
raise TracerException(
|
raise TracerException(
|
||||||
f"Parent run with UUID {run.parent_uuid} not found."
|
f"Parent run with UUID {run.parent_run_id} not found."
|
||||||
)
|
)
|
||||||
|
self.run_map[str(run.id)] = run
|
||||||
|
|
||||||
self.run_map[run.uuid] = run
|
def _end_trace(self, run: Run) -> None:
|
||||||
|
|
||||||
def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
|
||||||
"""End a trace for a run."""
|
"""End a trace for a run."""
|
||||||
if not run.parent_uuid:
|
if not run.parent_run_id:
|
||||||
self._persist_run(run)
|
self._persist_run(run)
|
||||||
else:
|
else:
|
||||||
parent_run = self.run_map.get(run.parent_uuid)
|
parent_run = self.run_map.get(str(run.parent_run_id))
|
||||||
if parent_run is None:
|
if parent_run is None:
|
||||||
raise TracerException(
|
raise TracerException(
|
||||||
f"Parent run with UUID {run.parent_uuid} not found."
|
f"Parent run with UUID {run.parent_run_id} not found."
|
||||||
)
|
)
|
||||||
if isinstance(parent_run, LLMRun):
|
|
||||||
raise TracerException("LLM Runs are not allowed to have children. ")
|
|
||||||
if run.child_execution_order > parent_run.child_execution_order:
|
if run.child_execution_order > parent_run.child_execution_order:
|
||||||
parent_run.child_execution_order = run.child_execution_order
|
parent_run.child_execution_order = run.child_execution_order
|
||||||
self.run_map.pop(run.uuid)
|
self.run_map.pop(str(run.id))
|
||||||
|
|
||||||
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
|
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
|
||||||
"""Get the execution order for a run."""
|
"""Get the execution order for a run."""
|
||||||
@ -121,9 +69,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
if parent_run is None:
|
if parent_run is None:
|
||||||
raise TracerException(f"Parent run with UUID {parent_run_id} not found.")
|
raise TracerException(f"Parent run with UUID {parent_run_id} not found.")
|
||||||
|
|
||||||
if isinstance(parent_run, LLMRun):
|
|
||||||
raise TracerException("LLM Runs are not allowed to have children. ")
|
|
||||||
|
|
||||||
return parent_run.child_execution_order + 1
|
return parent_run.child_execution_order + 1
|
||||||
|
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
@ -136,25 +81,22 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a trace for an LLM run."""
|
"""Start a trace for an LLM run."""
|
||||||
if self.session is None:
|
|
||||||
self.session = self.load_default_session()
|
|
||||||
|
|
||||||
run_id_ = str(run_id)
|
|
||||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||||
|
|
||||||
execution_order = self._get_execution_order(parent_run_id_)
|
execution_order = self._get_execution_order(parent_run_id_)
|
||||||
llm_run = LLMRun(
|
llm_run = Run(
|
||||||
uuid=run_id_,
|
id=run_id,
|
||||||
parent_uuid=parent_run_id_,
|
name=serialized.get("name"),
|
||||||
|
parent_run_id=parent_run_id,
|
||||||
serialized=serialized,
|
serialized=serialized,
|
||||||
prompts=prompts,
|
inputs={"prompts": prompts},
|
||||||
extra=kwargs,
|
extra=kwargs,
|
||||||
start_time=datetime.utcnow(),
|
start_time=datetime.utcnow(),
|
||||||
execution_order=execution_order,
|
execution_order=execution_order,
|
||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
session_id=self.session.id,
|
run_type=RunTypeEnum.llm,
|
||||||
)
|
)
|
||||||
self._start_trace(llm_run)
|
self._start_trace(llm_run)
|
||||||
|
self._on_llm_start(llm_run)
|
||||||
|
|
||||||
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
"""End a trace for an LLM run."""
|
"""End a trace for an LLM run."""
|
||||||
@ -163,11 +105,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
run_id_ = str(run_id)
|
run_id_ = str(run_id)
|
||||||
llm_run = self.run_map.get(run_id_)
|
llm_run = self.run_map.get(run_id_)
|
||||||
if llm_run is None or not isinstance(llm_run, LLMRun):
|
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
|
||||||
raise TracerException("No LLMRun found to be traced")
|
raise TracerException("No LLM Run found to be traced")
|
||||||
llm_run.response = response
|
llm_run.outputs = response.dict()
|
||||||
llm_run.end_time = datetime.utcnow()
|
llm_run.end_time = datetime.utcnow()
|
||||||
self._end_trace(llm_run)
|
self._end_trace(llm_run)
|
||||||
|
self._on_llm_end(llm_run)
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self,
|
self,
|
||||||
@ -182,12 +125,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
run_id_ = str(run_id)
|
run_id_ = str(run_id)
|
||||||
llm_run = self.run_map.get(run_id_)
|
llm_run = self.run_map.get(run_id_)
|
||||||
if llm_run is None or not isinstance(llm_run, LLMRun):
|
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
|
||||||
raise TracerException("No LLMRun found to be traced")
|
raise TracerException("No LLM Run found to be traced")
|
||||||
|
|
||||||
llm_run.error = repr(error)
|
llm_run.error = repr(error)
|
||||||
llm_run.end_time = datetime.utcnow()
|
llm_run.end_time = datetime.utcnow()
|
||||||
self._end_trace(llm_run)
|
self._end_trace(llm_run)
|
||||||
|
self._on_chain_error(llm_run)
|
||||||
|
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
self,
|
self,
|
||||||
@ -199,16 +142,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a trace for a chain run."""
|
"""Start a trace for a chain run."""
|
||||||
if self.session is None:
|
|
||||||
self.session = self.load_default_session()
|
|
||||||
|
|
||||||
run_id_ = str(run_id)
|
|
||||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||||
|
|
||||||
execution_order = self._get_execution_order(parent_run_id_)
|
execution_order = self._get_execution_order(parent_run_id_)
|
||||||
chain_run = ChainRun(
|
chain_run = Run(
|
||||||
uuid=run_id_,
|
id=run_id,
|
||||||
parent_uuid=parent_run_id_,
|
name=serialized.get("name"),
|
||||||
|
parent_run_id=parent_run_id,
|
||||||
serialized=serialized,
|
serialized=serialized,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
extra=kwargs,
|
extra=kwargs,
|
||||||
@ -216,23 +155,25 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
execution_order=execution_order,
|
execution_order=execution_order,
|
||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
child_runs=[],
|
child_runs=[],
|
||||||
session_id=self.session.id,
|
run_type=RunTypeEnum.chain,
|
||||||
)
|
)
|
||||||
self._start_trace(chain_run)
|
self._start_trace(chain_run)
|
||||||
|
self._on_chain_start(chain_run)
|
||||||
|
|
||||||
def on_chain_end(
|
def on_chain_end(
|
||||||
self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
|
self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""End a trace for a chain run."""
|
"""End a trace for a chain run."""
|
||||||
run_id_ = str(run_id)
|
if not run_id:
|
||||||
|
raise TracerException("No run_id provided for on_chain_end callback.")
|
||||||
chain_run = self.run_map.get(run_id_)
|
chain_run = self.run_map.get(str(run_id))
|
||||||
if chain_run is None or not isinstance(chain_run, ChainRun):
|
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
|
||||||
raise TracerException("No ChainRun found to be traced")
|
raise TracerException("No chain Run found to be traced")
|
||||||
|
|
||||||
chain_run.outputs = outputs
|
chain_run.outputs = outputs
|
||||||
chain_run.end_time = datetime.utcnow()
|
chain_run.end_time = datetime.utcnow()
|
||||||
self._end_trace(chain_run)
|
self._end_trace(chain_run)
|
||||||
|
self._on_chain_end(chain_run)
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
@ -242,15 +183,16 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle an error for a chain run."""
|
"""Handle an error for a chain run."""
|
||||||
run_id_ = str(run_id)
|
if not run_id:
|
||||||
|
raise TracerException("No run_id provided for on_chain_error callback.")
|
||||||
chain_run = self.run_map.get(run_id_)
|
chain_run = self.run_map.get(str(run_id))
|
||||||
if chain_run is None or not isinstance(chain_run, ChainRun):
|
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
|
||||||
raise TracerException("No ChainRun found to be traced")
|
raise TracerException("No chain Run found to be traced")
|
||||||
|
|
||||||
chain_run.error = repr(error)
|
chain_run.error = repr(error)
|
||||||
chain_run.end_time = datetime.utcnow()
|
chain_run.end_time = datetime.utcnow()
|
||||||
self._end_trace(chain_run)
|
self._end_trace(chain_run)
|
||||||
|
self._on_chain_error(chain_run)
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
@ -262,40 +204,36 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a trace for a tool run."""
|
"""Start a trace for a tool run."""
|
||||||
if self.session is None:
|
|
||||||
self.session = self.load_default_session()
|
|
||||||
|
|
||||||
run_id_ = str(run_id)
|
|
||||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||||
|
|
||||||
execution_order = self._get_execution_order(parent_run_id_)
|
execution_order = self._get_execution_order(parent_run_id_)
|
||||||
tool_run = ToolRun(
|
tool_run = Run(
|
||||||
uuid=run_id_,
|
id=run_id,
|
||||||
parent_uuid=parent_run_id_,
|
name=serialized.get("name"),
|
||||||
|
parent_run_id=parent_run_id,
|
||||||
serialized=serialized,
|
serialized=serialized,
|
||||||
# TODO: this is duplicate info as above, not needed.
|
inputs={"input": input_str},
|
||||||
action=str(serialized),
|
|
||||||
tool_input=input_str,
|
|
||||||
extra=kwargs,
|
extra=kwargs,
|
||||||
start_time=datetime.utcnow(),
|
start_time=datetime.utcnow(),
|
||||||
execution_order=execution_order,
|
execution_order=execution_order,
|
||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
child_runs=[],
|
child_runs=[],
|
||||||
session_id=self.session.id,
|
run_type=RunTypeEnum.tool,
|
||||||
)
|
)
|
||||||
self._start_trace(tool_run)
|
self._start_trace(tool_run)
|
||||||
|
self._on_tool_start(tool_run)
|
||||||
|
|
||||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
"""End a trace for a tool run."""
|
"""End a trace for a tool run."""
|
||||||
run_id_ = str(run_id)
|
if not run_id:
|
||||||
|
raise TracerException("No run_id provided for on_tool_end callback.")
|
||||||
|
tool_run = self.run_map.get(str(run_id))
|
||||||
|
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
|
||||||
|
raise TracerException("No tool Run found to be traced")
|
||||||
|
|
||||||
tool_run = self.run_map.get(run_id_)
|
tool_run.outputs = {"output": output}
|
||||||
if tool_run is None or not isinstance(tool_run, ToolRun):
|
|
||||||
raise TracerException("No ToolRun found to be traced")
|
|
||||||
|
|
||||||
tool_run.output = output
|
|
||||||
tool_run.end_time = datetime.utcnow()
|
tool_run.end_time = datetime.utcnow()
|
||||||
self._end_trace(tool_run)
|
self._end_trace(tool_run)
|
||||||
|
self._on_tool_end(tool_run)
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
self,
|
self,
|
||||||
@ -305,15 +243,16 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle an error for a tool run."""
|
"""Handle an error for a tool run."""
|
||||||
run_id_ = str(run_id)
|
if not run_id:
|
||||||
|
raise TracerException("No run_id provided for on_tool_error callback.")
|
||||||
tool_run = self.run_map.get(run_id_)
|
tool_run = self.run_map.get(str(run_id))
|
||||||
if tool_run is None or not isinstance(tool_run, ToolRun):
|
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
|
||||||
raise TracerException("No ToolRun found to be traced")
|
raise TracerException("No tool Run found to be traced")
|
||||||
|
|
||||||
tool_run.error = repr(error)
|
tool_run.error = repr(error)
|
||||||
tool_run.end_time = datetime.utcnow()
|
tool_run.end_time = datetime.utcnow()
|
||||||
self._end_trace(tool_run)
|
self._end_trace(tool_run)
|
||||||
|
self._on_tool_error(tool_run)
|
||||||
|
|
||||||
def __deepcopy__(self, memo: dict) -> BaseTracer:
|
def __deepcopy__(self, memo: dict) -> BaseTracer:
|
||||||
"""Deepcopy the tracer."""
|
"""Deepcopy the tracer."""
|
||||||
@ -322,3 +261,33 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
def __copy__(self) -> BaseTracer:
|
def __copy__(self) -> BaseTracer:
|
||||||
"""Copy the tracer."""
|
"""Copy the tracer."""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _on_llm_start(self, run: Run) -> None:
|
||||||
|
"""Process the LLM Run upon start."""
|
||||||
|
|
||||||
|
def _on_llm_end(self, run: Run) -> None:
|
||||||
|
"""Process the LLM Run."""
|
||||||
|
|
||||||
|
def _on_llm_error(self, run: Run) -> None:
|
||||||
|
"""Process the LLM Run upon error."""
|
||||||
|
|
||||||
|
def _on_chain_start(self, run: Run) -> None:
|
||||||
|
"""Process the Chain Run upon start."""
|
||||||
|
|
||||||
|
def _on_chain_end(self, run: Run) -> None:
|
||||||
|
"""Process the Chain Run."""
|
||||||
|
|
||||||
|
def _on_chain_error(self, run: Run) -> None:
|
||||||
|
"""Process the Chain Run upon error."""
|
||||||
|
|
||||||
|
def _on_tool_start(self, run: Run) -> None:
|
||||||
|
"""Process the Tool Run upon start."""
|
||||||
|
|
||||||
|
def _on_tool_end(self, run: Run) -> None:
|
||||||
|
"""Process the Tool Run."""
|
||||||
|
|
||||||
|
def _on_tool_error(self, run: Run) -> None:
|
||||||
|
"""Process the Tool Run upon error."""
|
||||||
|
|
||||||
|
def _on_chat_model_start(self, run: Run) -> None:
|
||||||
|
"""Process the Chat Model Run upon start."""
|
||||||
|
@ -4,27 +4,24 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
from langchain.callbacks.tracers.schemas import (
|
from langchain.callbacks.tracers.schemas import (
|
||||||
ChainRun,
|
Run,
|
||||||
LLMRun,
|
|
||||||
RunCreate,
|
RunCreate,
|
||||||
ToolRun,
|
RunTypeEnum,
|
||||||
TracerSession,
|
TracerSession,
|
||||||
TracerSessionBase,
|
TracerSessionCreate,
|
||||||
TracerSessionV2,
|
|
||||||
TracerSessionV2Create,
|
|
||||||
)
|
)
|
||||||
from langchain.schema import BaseMessage, messages_to_dict
|
from langchain.schema import BaseMessage, messages_to_dict
|
||||||
from langchain.utils import raise_for_status_with_text
|
from langchain.utils import raise_for_status_with_text
|
||||||
|
|
||||||
|
|
||||||
def _get_headers() -> Dict[str, Any]:
|
def get_headers() -> Dict[str, Any]:
|
||||||
"""Get the headers for the LangChain API."""
|
"""Get the headers for the LangChain API."""
|
||||||
headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||||
if os.getenv("LANGCHAIN_API_KEY"):
|
if os.getenv("LANGCHAIN_API_KEY"):
|
||||||
@ -32,168 +29,47 @@ def _get_headers() -> Dict[str, Any]:
|
|||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def _get_endpoint() -> str:
|
def get_endpoint() -> str:
|
||||||
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tenant_id(
|
||||||
|
tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
|
||||||
|
) -> str:
|
||||||
|
"""Get the tenant ID for the LangChain API."""
|
||||||
|
tenant_id_: Optional[str] = tenant_id or os.getenv("LANGCHAIN_TENANT_ID")
|
||||||
|
if tenant_id_:
|
||||||
|
return tenant_id_
|
||||||
|
endpoint_ = endpoint or get_endpoint()
|
||||||
|
headers_ = headers or get_headers()
|
||||||
|
response = requests.get(endpoint_ + "/tenants", headers=headers_)
|
||||||
|
raise_for_status_with_text(response)
|
||||||
|
tenants: List[Dict[str, Any]] = response.json()
|
||||||
|
if not tenants:
|
||||||
|
raise ValueError(f"No tenants found for URL {endpoint_}")
|
||||||
|
return tenants[0]["id"]
|
||||||
|
|
||||||
|
|
||||||
class LangChainTracer(BaseTracer):
|
class LangChainTracer(BaseTracer):
|
||||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
tenant_id: Optional[str] = None,
|
||||||
|
example_id: Optional[UUID] = None,
|
||||||
|
session_name: Optional[str] = None,
|
||||||
|
session_extra: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
"""Initialize the LangChain tracer."""
|
"""Initialize the LangChain tracer."""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._endpoint = _get_endpoint()
|
self.session: Optional[TracerSession] = None
|
||||||
self._headers = _get_headers()
|
self._endpoint = get_endpoint()
|
||||||
|
self._headers = get_headers()
|
||||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
self.tenant_id = tenant_id
|
||||||
"""Persist a run."""
|
|
||||||
if isinstance(run, LLMRun):
|
|
||||||
endpoint = f"{self._endpoint}/llm-runs"
|
|
||||||
elif isinstance(run, ChainRun):
|
|
||||||
endpoint = f"{self._endpoint}/chain-runs"
|
|
||||||
else:
|
|
||||||
endpoint = f"{self._endpoint}/tool-runs"
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.post(
|
|
||||||
endpoint,
|
|
||||||
data=run.json(),
|
|
||||||
headers=self._headers,
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Failed to persist run: {e}")
|
|
||||||
|
|
||||||
def _persist_session(
|
|
||||||
self, session_create: TracerSessionBase
|
|
||||||
) -> Union[TracerSession, TracerSessionV2]:
|
|
||||||
"""Persist a session."""
|
|
||||||
try:
|
|
||||||
r = requests.post(
|
|
||||||
f"{self._endpoint}/sessions",
|
|
||||||
data=session_create.json(),
|
|
||||||
headers=self._headers,
|
|
||||||
)
|
|
||||||
session = TracerSession(id=r.json()["id"], **session_create.dict())
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Failed to create session, using default session: {e}")
|
|
||||||
session = TracerSession(id=1, **session_create.dict())
|
|
||||||
return session
|
|
||||||
|
|
||||||
def _load_session(self, session_name: Optional[str] = None) -> TracerSession:
|
|
||||||
"""Load a session from the tracer."""
|
|
||||||
try:
|
|
||||||
url = f"{self._endpoint}/sessions"
|
|
||||||
if session_name:
|
|
||||||
url += f"?name={session_name}"
|
|
||||||
r = requests.get(url, headers=self._headers)
|
|
||||||
|
|
||||||
tracer_session = TracerSession(**r.json()[0])
|
|
||||||
except Exception as e:
|
|
||||||
session_type = "default" if not session_name else session_name
|
|
||||||
logging.warning(
|
|
||||||
f"Failed to load {session_type} session, using empty session: {e}"
|
|
||||||
)
|
|
||||||
tracer_session = TracerSession(id=1)
|
|
||||||
|
|
||||||
self.session = tracer_session
|
|
||||||
return tracer_session
|
|
||||||
|
|
||||||
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
|
|
||||||
"""Load a session with the given name from the tracer."""
|
|
||||||
return self._load_session(session_name)
|
|
||||||
|
|
||||||
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
|
|
||||||
"""Load the default tracing session and set it as the Tracer's session."""
|
|
||||||
return self._load_session("default")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tenant_id() -> Optional[str]:
|
|
||||||
"""Get the tenant ID for the LangChain API."""
|
|
||||||
tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID")
|
|
||||||
if tenant_id:
|
|
||||||
return tenant_id
|
|
||||||
endpoint = _get_endpoint()
|
|
||||||
headers = _get_headers()
|
|
||||||
response = requests.get(endpoint + "/tenants", headers=headers)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
tenants: List[Dict[str, Any]] = response.json()
|
|
||||||
if not tenants:
|
|
||||||
raise ValueError(f"No tenants found for URL {endpoint}")
|
|
||||||
return tenants[0]["id"]
|
|
||||||
|
|
||||||
|
|
||||||
class LangChainTracerV2(LangChainTracer):
|
|
||||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
|
||||||
|
|
||||||
def __init__(self, example_id: Optional[UUID] = None, **kwargs: Any) -> None:
|
|
||||||
"""Initialize the LangChain tracer."""
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._endpoint = _get_endpoint()
|
|
||||||
self._headers = _get_headers()
|
|
||||||
self.tenant_id = _get_tenant_id()
|
|
||||||
self.example_id = example_id
|
self.example_id = example_id
|
||||||
|
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
|
||||||
def _get_session_create(
|
self.session_extra = session_extra
|
||||||
self, name: Optional[str] = None, **kwargs: Any
|
|
||||||
) -> TracerSessionBase:
|
|
||||||
return TracerSessionV2Create(name=name, extra=kwargs, tenant_id=self.tenant_id)
|
|
||||||
|
|
||||||
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
|
|
||||||
"""Persist a session."""
|
|
||||||
session: Optional[TracerSessionV2] = None
|
|
||||||
try:
|
|
||||||
r = requests.post(
|
|
||||||
f"{self._endpoint}/sessions",
|
|
||||||
data=session_create.json(),
|
|
||||||
headers=self._headers,
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(r)
|
|
||||||
creation_args = session_create.dict()
|
|
||||||
if "id" in creation_args:
|
|
||||||
del creation_args["id"]
|
|
||||||
return TracerSessionV2(id=r.json()["id"], **creation_args)
|
|
||||||
except Exception as e:
|
|
||||||
if session_create.name is not None:
|
|
||||||
try:
|
|
||||||
return self.load_session(session_create.name)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
logging.warning(
|
|
||||||
f"Failed to create session {session_create.name},"
|
|
||||||
f" using empty session: {e}"
|
|
||||||
)
|
|
||||||
session = TracerSessionV2(id=uuid4(), **session_create.dict())
|
|
||||||
|
|
||||||
return session
|
|
||||||
|
|
||||||
def _get_default_query_params(self) -> Dict[str, Any]:
|
|
||||||
"""Get the query params for the LangChain API."""
|
|
||||||
return {"tenant_id": self.tenant_id}
|
|
||||||
|
|
||||||
def load_session(self, session_name: str) -> TracerSessionV2:
|
|
||||||
"""Load a session with the given name from the tracer."""
|
|
||||||
try:
|
|
||||||
url = f"{self._endpoint}/sessions"
|
|
||||||
params = {"tenant_id": self.tenant_id}
|
|
||||||
if session_name:
|
|
||||||
params["name"] = session_name
|
|
||||||
r = requests.get(url, headers=self._headers, params=params)
|
|
||||||
raise_for_status_with_text(r)
|
|
||||||
tracer_session = TracerSessionV2(**r.json()[0])
|
|
||||||
except Exception as e:
|
|
||||||
session_type = "default" if not session_name else session_name
|
|
||||||
logging.warning(
|
|
||||||
f"Failed to load {session_type} session, using empty session: {e}"
|
|
||||||
)
|
|
||||||
tracer_session = TracerSessionV2(id=uuid4(), tenant_id=self.tenant_id)
|
|
||||||
|
|
||||||
self.session = tracer_session
|
|
||||||
return tracer_session
|
|
||||||
|
|
||||||
def load_default_session(self) -> TracerSessionV2:
|
|
||||||
"""Load the default tracing session and set it as the Tracer's session."""
|
|
||||||
return self.load_session("default")
|
|
||||||
|
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
@ -205,81 +81,56 @@ class LangChainTracerV2(LangChainTracer):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a trace for an LLM run."""
|
"""Start a trace for an LLM run."""
|
||||||
if self.session is None:
|
|
||||||
self.session = self.load_default_session()
|
|
||||||
|
|
||||||
run_id_ = str(run_id)
|
|
||||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||||
|
|
||||||
execution_order = self._get_execution_order(parent_run_id_)
|
execution_order = self._get_execution_order(parent_run_id_)
|
||||||
llm_run = LLMRun(
|
chat_model_run = Run(
|
||||||
uuid=run_id_,
|
id=run_id,
|
||||||
parent_uuid=parent_run_id_,
|
name=serialized.get("name"),
|
||||||
|
parent_run_id=parent_run_id,
|
||||||
serialized=serialized,
|
serialized=serialized,
|
||||||
prompts=[],
|
inputs={"messages": messages_to_dict(batch) for batch in messages},
|
||||||
extra={**kwargs, "messages": messages},
|
extra=kwargs,
|
||||||
start_time=datetime.utcnow(),
|
start_time=datetime.utcnow(),
|
||||||
execution_order=execution_order,
|
execution_order=execution_order,
|
||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
session_id=self.session.id,
|
run_type=RunTypeEnum.llm,
|
||||||
)
|
)
|
||||||
self._start_trace(llm_run)
|
self._start_trace(chat_model_run)
|
||||||
|
self._on_chat_model_start(chat_model_run)
|
||||||
|
|
||||||
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
|
def ensure_tenant_id(self) -> str:
|
||||||
"""Convert a run to a Run."""
|
"""Load or use the tenant ID."""
|
||||||
session = self.session or self.load_default_session()
|
tenant_id = self.tenant_id or _get_tenant_id(
|
||||||
inputs: Dict[str, Any] = {}
|
self.tenant_id, self._endpoint, self._headers
|
||||||
outputs: Optional[Dict[str, Any]] = None
|
|
||||||
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
|
||||||
if isinstance(run, LLMRun):
|
|
||||||
run_type = "llm"
|
|
||||||
if run.extra is not None and "messages" in run.extra:
|
|
||||||
messages: List[List[BaseMessage]] = run.extra.pop("messages")
|
|
||||||
converted_messages = [messages_to_dict(batch) for batch in messages]
|
|
||||||
inputs = {"messages": converted_messages}
|
|
||||||
else:
|
|
||||||
inputs = {"prompts": run.prompts}
|
|
||||||
outputs = run.response.dict() if run.response else {}
|
|
||||||
child_runs = []
|
|
||||||
elif isinstance(run, ChainRun):
|
|
||||||
run_type = "chain"
|
|
||||||
inputs = run.inputs
|
|
||||||
outputs = run.outputs
|
|
||||||
child_runs = [
|
|
||||||
*run.child_llm_runs,
|
|
||||||
*run.child_chain_runs,
|
|
||||||
*run.child_tool_runs,
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
run_type = "tool"
|
|
||||||
inputs = {"input": run.tool_input}
|
|
||||||
outputs = {"output": run.output} if run.output else {}
|
|
||||||
child_runs = [
|
|
||||||
*run.child_llm_runs,
|
|
||||||
*run.child_chain_runs,
|
|
||||||
*run.child_tool_runs,
|
|
||||||
]
|
|
||||||
|
|
||||||
return RunCreate(
|
|
||||||
id=run.uuid,
|
|
||||||
name=run.serialized.get("name"),
|
|
||||||
start_time=run.start_time,
|
|
||||||
end_time=run.end_time,
|
|
||||||
extra=run.extra or {},
|
|
||||||
error=run.error,
|
|
||||||
execution_order=run.execution_order,
|
|
||||||
serialized=run.serialized,
|
|
||||||
inputs=inputs,
|
|
||||||
outputs=outputs,
|
|
||||||
session_id=session.id,
|
|
||||||
run_type=run_type,
|
|
||||||
child_runs=[self._convert_run(child) for child in child_runs],
|
|
||||||
)
|
)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
return tenant_id
|
||||||
|
|
||||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
def ensure_session(self) -> TracerSession:
|
||||||
|
"""Upsert a session."""
|
||||||
|
if self.session is not None:
|
||||||
|
return self.session
|
||||||
|
tenant_id = self.ensure_tenant_id()
|
||||||
|
url = f"{self._endpoint}/sessions?upsert=true"
|
||||||
|
session_create = TracerSessionCreate(
|
||||||
|
name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
|
||||||
|
)
|
||||||
|
r = requests.post(
|
||||||
|
url,
|
||||||
|
data=session_create.json(),
|
||||||
|
headers=self._headers,
|
||||||
|
)
|
||||||
|
raise_for_status_with_text(r)
|
||||||
|
self.session = TracerSession(**r.json())
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
def _persist_run_nested(self, run: Run) -> None:
|
||||||
"""Persist a run."""
|
"""Persist a run."""
|
||||||
run_create = self._convert_run(run)
|
session = self.ensure_session()
|
||||||
run_create.reference_example_id = self.example_id
|
child_runs = run.child_runs
|
||||||
|
run_dict = run.dict()
|
||||||
|
del run_dict["child_runs"]
|
||||||
|
run_create = RunCreate(**run_dict, session_id=session.id)
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{self._endpoint}/runs",
|
f"{self._endpoint}/runs",
|
||||||
@ -289,3 +140,12 @@ class LangChainTracerV2(LangChainTracer):
|
|||||||
raise_for_status_with_text(response)
|
raise_for_status_with_text(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to persist run: {e}")
|
logging.warning(f"Failed to persist run: {e}")
|
||||||
|
for child_run in child_runs:
|
||||||
|
child_run.parent_run_id = run.id
|
||||||
|
self._persist_run_nested(child_run)
|
||||||
|
|
||||||
|
def _persist_run(self, run: Run) -> None:
|
||||||
|
"""Persist a run."""
|
||||||
|
run.reference_example_id = self.example_id
|
||||||
|
# TODO: Post first then patch
|
||||||
|
self._persist_run_nested(run)
|
||||||
|
171
langchain/callbacks/tracers/langchain_v1.py
Normal file
171
langchain/callbacks/tracers/langchain_v1.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
from langchain.callbacks.tracers.langchain import get_endpoint, get_headers
|
||||||
|
from langchain.callbacks.tracers.schemas import (
|
||||||
|
ChainRun,
|
||||||
|
LLMRun,
|
||||||
|
Run,
|
||||||
|
ToolRun,
|
||||||
|
TracerSession,
|
||||||
|
TracerSessionV1,
|
||||||
|
TracerSessionV1Base,
|
||||||
|
)
|
||||||
|
from langchain.schema import get_buffer_string
|
||||||
|
from langchain.utils import raise_for_status_with_text
|
||||||
|
|
||||||
|
|
||||||
|
class LangChainTracerV1(BaseTracer):
|
||||||
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
"""Initialize the LangChain tracer."""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.session: Optional[TracerSessionV1] = None
|
||||||
|
self._endpoint = get_endpoint()
|
||||||
|
self._headers = get_headers()
|
||||||
|
|
||||||
|
def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]:
|
||||||
|
session = self.session or self.load_default_session()
|
||||||
|
if not isinstance(session, TracerSessionV1):
|
||||||
|
raise ValueError(
|
||||||
|
"LangChainTracerV1 is not compatible with"
|
||||||
|
f" session of type {type(session)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if run.run_type == "llm":
|
||||||
|
if "prompts" in run.inputs:
|
||||||
|
prompts = run.inputs["prompts"]
|
||||||
|
elif "messages" in run.inputs:
|
||||||
|
prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]]
|
||||||
|
else:
|
||||||
|
raise ValueError("No prompts found in LLM run inputs")
|
||||||
|
return LLMRun(
|
||||||
|
uuid=str(run.id) if run.id else None,
|
||||||
|
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
|
||||||
|
start_time=run.start_time,
|
||||||
|
end_time=run.end_time,
|
||||||
|
extra=run.extra,
|
||||||
|
execution_order=run.execution_order,
|
||||||
|
child_execution_order=run.child_execution_order,
|
||||||
|
serialized=run.serialized,
|
||||||
|
session_id=session.id,
|
||||||
|
error=run.error,
|
||||||
|
prompts=prompts,
|
||||||
|
response=run.outputs if run.outputs else None,
|
||||||
|
)
|
||||||
|
if run.run_type == "chain":
|
||||||
|
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
|
||||||
|
return ChainRun(
|
||||||
|
uuid=str(run.id) if run.id else None,
|
||||||
|
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
|
||||||
|
start_time=run.start_time,
|
||||||
|
end_time=run.end_time,
|
||||||
|
execution_order=run.execution_order,
|
||||||
|
child_execution_order=run.child_execution_order,
|
||||||
|
serialized=run.serialized,
|
||||||
|
session_id=session.id,
|
||||||
|
inputs=run.inputs,
|
||||||
|
outputs=run.outputs,
|
||||||
|
error=run.error,
|
||||||
|
extra=run.extra,
|
||||||
|
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
|
||||||
|
child_chain_runs=[
|
||||||
|
run for run in child_runs if isinstance(run, ChainRun)
|
||||||
|
],
|
||||||
|
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
|
||||||
|
)
|
||||||
|
if run.run_type == "tool":
|
||||||
|
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
|
||||||
|
return ToolRun(
|
||||||
|
uuid=str(run.id) if run.id else None,
|
||||||
|
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
|
||||||
|
start_time=run.start_time,
|
||||||
|
end_time=run.end_time,
|
||||||
|
execution_order=run.execution_order,
|
||||||
|
child_execution_order=run.child_execution_order,
|
||||||
|
serialized=run.serialized,
|
||||||
|
session_id=session.id,
|
||||||
|
action=str(run.serialized),
|
||||||
|
tool_input=run.inputs.get("input", ""),
|
||||||
|
output=None if run.outputs is None else run.outputs.get("output"),
|
||||||
|
error=run.error,
|
||||||
|
extra=run.extra,
|
||||||
|
child_chain_runs=[
|
||||||
|
run for run in child_runs if isinstance(run, ChainRun)
|
||||||
|
],
|
||||||
|
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
|
||||||
|
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
|
||||||
|
)
|
||||||
|
raise ValueError(f"Unknown run type: {run.run_type}")
|
||||||
|
|
||||||
|
def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
|
||||||
|
"""Persist a run."""
|
||||||
|
if isinstance(run, Run):
|
||||||
|
v1_run = self._convert_to_v1_run(run)
|
||||||
|
else:
|
||||||
|
v1_run = run
|
||||||
|
if isinstance(v1_run, LLMRun):
|
||||||
|
endpoint = f"{self._endpoint}/llm-runs"
|
||||||
|
elif isinstance(v1_run, ChainRun):
|
||||||
|
endpoint = f"{self._endpoint}/chain-runs"
|
||||||
|
else:
|
||||||
|
endpoint = f"{self._endpoint}/tool-runs"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
endpoint,
|
||||||
|
data=v1_run.json(),
|
||||||
|
headers=self._headers,
|
||||||
|
)
|
||||||
|
raise_for_status_with_text(response)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to persist run: {e}")
|
||||||
|
|
||||||
|
def _persist_session(
|
||||||
|
self, session_create: TracerSessionV1Base
|
||||||
|
) -> Union[TracerSessionV1, TracerSession]:
|
||||||
|
"""Persist a session."""
|
||||||
|
try:
|
||||||
|
r = requests.post(
|
||||||
|
f"{self._endpoint}/sessions",
|
||||||
|
data=session_create.json(),
|
||||||
|
headers=self._headers,
|
||||||
|
)
|
||||||
|
session = TracerSessionV1(id=r.json()["id"], **session_create.dict())
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to create session, using default session: {e}")
|
||||||
|
session = TracerSessionV1(id=1, **session_create.dict())
|
||||||
|
return session
|
||||||
|
|
||||||
|
def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1:
|
||||||
|
"""Load a session from the tracer."""
|
||||||
|
try:
|
||||||
|
url = f"{self._endpoint}/sessions"
|
||||||
|
if session_name:
|
||||||
|
url += f"?name={session_name}"
|
||||||
|
r = requests.get(url, headers=self._headers)
|
||||||
|
|
||||||
|
tracer_session = TracerSessionV1(**r.json()[0])
|
||||||
|
except Exception as e:
|
||||||
|
session_type = "default" if not session_name else session_name
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to load {session_type} session, using empty session: {e}"
|
||||||
|
)
|
||||||
|
tracer_session = TracerSessionV1(id=1)
|
||||||
|
|
||||||
|
self.session = tracer_session
|
||||||
|
return tracer_session
|
||||||
|
|
||||||
|
def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]:
|
||||||
|
"""Load a session with the given name from the tracer."""
|
||||||
|
return self._load_session(session_name)
|
||||||
|
|
||||||
|
def load_default_session(self) -> Union[TracerSessionV1, TracerSession]:
|
||||||
|
"""Load the default tracing session and set it as the Tracer's session."""
|
||||||
|
return self._load_session("default")
|
@ -6,47 +6,45 @@ from enum import Enum
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, root_validator
|
||||||
|
|
||||||
from langchain.schema import LLMResult
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
|
||||||
class TracerSessionBase(BaseModel):
|
class TracerSessionV1Base(BaseModel):
|
||||||
"""Base class for TracerSession."""
|
"""Base class for TracerSessionV1."""
|
||||||
|
|
||||||
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
extra: Optional[Dict[str, Any]] = None
|
extra: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class TracerSessionCreate(TracerSessionBase):
|
class TracerSessionV1Create(TracerSessionV1Base):
|
||||||
"""Create class for TracerSession."""
|
"""Create class for TracerSessionV1."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TracerSession(TracerSessionBase):
|
class TracerSessionV1(TracerSessionV1Base):
|
||||||
"""TracerSession schema."""
|
"""TracerSessionV1 schema."""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
|
|
||||||
class TracerSessionV2Base(TracerSessionBase):
|
class TracerSessionBase(TracerSessionV1Base):
|
||||||
"""A creation class for TracerSessionV2."""
|
"""A creation class for TracerSession."""
|
||||||
|
|
||||||
tenant_id: UUID
|
tenant_id: UUID
|
||||||
|
|
||||||
|
|
||||||
class TracerSessionV2Create(TracerSessionV2Base):
|
class TracerSessionCreate(TracerSessionBase):
|
||||||
"""A creation class for TracerSessionV2."""
|
"""A creation class for TracerSession."""
|
||||||
|
|
||||||
id: Optional[UUID]
|
id: Optional[UUID]
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
class TracerSession(TracerSessionBase):
|
||||||
class TracerSessionV2(TracerSessionV2Base):
|
"""TracerSessionV1 schema for the V2 API."""
|
||||||
"""TracerSession schema for the V2 API."""
|
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
|
|
||||||
@ -111,26 +109,32 @@ class RunBase(BaseModel):
|
|||||||
extra: dict
|
extra: dict
|
||||||
error: Optional[str]
|
error: Optional[str]
|
||||||
execution_order: int
|
execution_order: int
|
||||||
|
child_execution_order: int
|
||||||
serialized: dict
|
serialized: dict
|
||||||
inputs: dict
|
inputs: dict
|
||||||
outputs: Optional[dict]
|
outputs: Optional[dict]
|
||||||
session_id: UUID
|
|
||||||
reference_example_id: Optional[UUID]
|
reference_example_id: Optional[UUID]
|
||||||
run_type: RunTypeEnum
|
run_type: RunTypeEnum
|
||||||
parent_run_id: Optional[UUID]
|
parent_run_id: Optional[UUID]
|
||||||
|
|
||||||
|
|
||||||
class RunCreate(RunBase):
|
|
||||||
"""Schema to create a run in the DB."""
|
|
||||||
|
|
||||||
name: Optional[str]
|
|
||||||
child_runs: List[RunCreate] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class Run(RunBase):
|
class Run(RunBase):
|
||||||
"""Run schema when loading from the DB."""
|
"""Run schema when loading from the DB."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
child_runs: List[Run] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def assign_name(cls, values: dict) -> dict:
|
||||||
|
"""Assign name to the run."""
|
||||||
|
if "name" not in values:
|
||||||
|
values["name"] = values["serialized"]["name"]
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class RunCreate(RunBase):
|
||||||
|
name: str
|
||||||
|
session_id: UUID
|
||||||
|
|
||||||
|
|
||||||
ChainRun.update_forward_refs()
|
ChainRun.update_forward_refs()
|
||||||
|
@ -27,7 +27,7 @@ from requests import Response
|
|||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.manager import tracing_v2_enabled
|
from langchain.callbacks.manager import tracing_v2_enabled
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
|
from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
|
||||||
@ -308,7 +308,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
async def _arun_llm(
|
async def _arun_llm(
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
inputs: Dict[str, Any],
|
inputs: Dict[str, Any],
|
||||||
langchain_tracer: LangChainTracerV2,
|
langchain_tracer: LangChainTracer,
|
||||||
) -> Union[LLMResult, ChatResult]:
|
) -> Union[LLMResult, ChatResult]:
|
||||||
if isinstance(llm, BaseLLM):
|
if isinstance(llm, BaseLLM):
|
||||||
if "prompt" not in inputs:
|
if "prompt" not in inputs:
|
||||||
@ -328,7 +328,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def _arun_llm_or_chain(
|
async def _arun_llm_or_chain(
|
||||||
example: Example,
|
example: Example,
|
||||||
langchain_tracer: LangChainTracerV2,
|
langchain_tracer: LangChainTracer,
|
||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
n_repetitions: int,
|
n_repetitions: int,
|
||||||
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
||||||
@ -358,8 +358,8 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def _gather_with_concurrency(
|
async def _gather_with_concurrency(
|
||||||
n: int,
|
n: int,
|
||||||
initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracerV2, Dict]]],
|
initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracer, Dict]]],
|
||||||
*async_funcs: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]],
|
*async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]],
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Run coroutines with a concurrency limit.
|
Run coroutines with a concurrency limit.
|
||||||
@ -376,7 +376,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
tracer, job_state = await initializer()
|
tracer, job_state = await initializer()
|
||||||
|
|
||||||
async def run_coroutine_with_semaphore(
|
async def run_coroutine_with_semaphore(
|
||||||
async_func: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]]
|
async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
return await async_func(tracer, job_state)
|
return await async_func(tracer, job_state)
|
||||||
@ -387,7 +387,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
|
|
||||||
async def _tracer_initializer(
|
async def _tracer_initializer(
|
||||||
self, session_name: str
|
self, session_name: str
|
||||||
) -> Tuple[LangChainTracerV2, dict]:
|
) -> Tuple[LangChainTracer, dict]:
|
||||||
"""
|
"""
|
||||||
Initialize a tracer to share across tasks.
|
Initialize a tracer to share across tasks.
|
||||||
|
|
||||||
@ -395,11 +395,11 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
session_name: The session name for the tracer.
|
session_name: The session name for the tracer.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A LangChainTracerV2 instance with an active session.
|
A LangChainTracer instance with an active session.
|
||||||
"""
|
"""
|
||||||
job_state = {"num_processed": 0}
|
job_state = {"num_processed": 0}
|
||||||
with tracing_v2_enabled(session_name=session_name) as session:
|
with tracing_v2_enabled(session_name=session_name) as session:
|
||||||
tracer = LangChainTracerV2()
|
tracer = LangChainTracer()
|
||||||
tracer.session = session
|
tracer.session = session
|
||||||
return tracer, job_state
|
return tracer, job_state
|
||||||
|
|
||||||
@ -440,7 +440,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
results: Dict[str, List[Any]] = {}
|
results: Dict[str, List[Any]] = {}
|
||||||
|
|
||||||
async def process_example(
|
async def process_example(
|
||||||
example: Example, tracer: LangChainTracerV2, job_state: dict
|
example: Example, tracer: LangChainTracer, job_state: dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process a single example."""
|
"""Process a single example."""
|
||||||
result = await LangChainPlusClient._arun_llm_or_chain(
|
result = await LangChainPlusClient._arun_llm_or_chain(
|
||||||
@ -469,7 +469,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
def run_llm(
|
def run_llm(
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
inputs: Dict[str, Any],
|
inputs: Dict[str, Any],
|
||||||
langchain_tracer: LangChainTracerV2,
|
langchain_tracer: LangChainTracer,
|
||||||
) -> Union[LLMResult, ChatResult]:
|
) -> Union[LLMResult, ChatResult]:
|
||||||
"""Run the language model on the example."""
|
"""Run the language model on the example."""
|
||||||
if isinstance(llm, BaseLLM):
|
if isinstance(llm, BaseLLM):
|
||||||
@ -492,7 +492,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def run_llm_or_chain(
|
def run_llm_or_chain(
|
||||||
example: Example,
|
example: Example,
|
||||||
langchain_tracer: LangChainTracerV2,
|
langchain_tracer: LangChainTracer,
|
||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
n_repetitions: int,
|
n_repetitions: int,
|
||||||
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
||||||
@ -551,7 +551,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
examples = list(self.list_examples(dataset_id=str(dataset.id)))
|
examples = list(self.list_examples(dataset_id=str(dataset.id)))
|
||||||
results: Dict[str, Any] = {}
|
results: Dict[str, Any] = {}
|
||||||
with tracing_v2_enabled(session_name=session_name) as session:
|
with tracing_v2_enabled(session_name=session_name) as session:
|
||||||
tracer = LangChainTracerV2()
|
tracer = LangChainTracer()
|
||||||
tracer.session = session
|
tracer.session = session
|
||||||
|
|
||||||
for i, example in enumerate(examples):
|
for i, example in enumerate(examples):
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [{
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "1a4596ea-a631-416d-a2a4-3577c140493d",
|
"id": "1a4596ea-a631-416d-a2a4-3577c140493d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -20,13 +19,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 18,
|
||||||
"id": "904db9a5-f387-4a57-914c-c8af8d39e249",
|
"id": "904db9a5-f387-4a57-914c-c8af8d39e249",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
@ -42,7 +40,7 @@
|
|||||||
"LangChainPlusClient (API URL: http://localhost:8000)"
|
"LangChainPlusClient (API URL: http://localhost:8000)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 1,
|
"execution_count": 18,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -84,7 +82,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 19,
|
||||||
"id": "4417e0b8-a26f-4a11-b7eb-ba7a18e73885",
|
"id": "4417e0b8-a26f-4a11-b7eb-ba7a18e73885",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -96,7 +94,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 20,
|
||||||
"id": "7c801853-8e96-404d-984c-51ace59cbbef",
|
"id": "7c801853-8e96-404d-984c-51ace59cbbef",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -119,8 +117,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
@ -207,8 +204,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"data": {
|
"data": {
|
||||||
"text/html": [
|
"text/html": [
|
||||||
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
||||||
@ -220,8 +216,7 @@
|
|||||||
"execution_count": 6,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}],
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"client"
|
"client"
|
||||||
]
|
]
|
||||||
@ -257,8 +252,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
@ -436,8 +430,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"\u001b[0;31mSignature:\u001b[0m\n",
|
"\u001b[0;31mSignature:\u001b[0m\n",
|
||||||
@ -474,8 +467,7 @@
|
|||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "display_data"
|
"output_type": "display_data"
|
||||||
}
|
}],
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"?client.arun_on_dataset"
|
"?client.arun_on_dataset"
|
||||||
]
|
]
|
||||||
@ -506,8 +498,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
@ -585,8 +576,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"data": {
|
"data": {
|
||||||
"text/html": [
|
"text/html": [
|
||||||
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
||||||
@ -598,8 +588,7 @@
|
|||||||
"execution_count": 13,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}],
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# You can navigate to the UI by clicking on the link below\n",
|
"# You can navigate to the UI by clicking on the link below\n",
|
||||||
"client"
|
"client"
|
||||||
@ -625,13 +614,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 24,
|
||||||
"id": "64490d7c-9a18-49ed-a3ac-36049c522cb4",
|
"id": "64490d7c-9a18-49ed-a3ac-36049c522cb4",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
@ -641,7 +629,7 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "047a8094367f43938f74e863b3e01711",
|
"model_id": "0adb751cec11417b88072963325b481d",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
@ -723,7 +711,7 @@
|
|||||||
"4 [{'data': {'content': 'Here is the topic for a... "
|
"4 [{'data': {'content': 'Here is the topic for a... "
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 14,
|
"execution_count": 24,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -739,7 +727,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 25,
|
||||||
"id": "348acd86-a927-4d60-8d52-02e64585e4fc",
|
"id": "348acd86-a927-4d60-8d52-02e64585e4fc",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -769,7 +757,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 26,
|
||||||
"id": "a69dd183-ad5e-473d-b631-db90706e837f",
|
"id": "a69dd183-ad5e-473d-b631-db90706e837f",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -783,13 +771,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 27,
|
||||||
"id": "063da2a9-3692-4b7b-8edb-e474824fe416",
|
"id": "063da2a9-3692-4b7b-8edb-e474824fe416",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
@ -838,8 +825,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"data": {
|
"data": {
|
||||||
"text/html": [
|
"text/html": [
|
||||||
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
||||||
@ -851,8 +837,7 @@
|
|||||||
"execution_count": 19,
|
"execution_count": 19,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}],
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"client"
|
"client"
|
||||||
]
|
]
|
||||||
@ -888,8 +873,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
@ -1032,8 +1016,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
@ -1076,8 +1059,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [{
|
||||||
{
|
|
||||||
"data": {
|
"data": {
|
||||||
"text/html": [
|
"text/html": [
|
||||||
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
||||||
@ -1089,8 +1071,7 @@
|
|||||||
"execution_count": 24,
|
"execution_count": 24,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}],
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"client"
|
"client"
|
||||||
]
|
]
|
||||||
|
464
tests/unit_tests/callbacks/tracers/test_base_tracer.py
Normal file
464
tests/unit_tests/callbacks/tracers/test_base_tracer.py
Normal file
@ -0,0 +1,464 @@
|
|||||||
|
"""Test Tracer classes."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from freezegun import freeze_time
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
from langchain.callbacks.tracers.base import BaseTracer, TracerException
|
||||||
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTracer(BaseTracer):
|
||||||
|
"""Fake tracer that records LangChain execution."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the tracer."""
|
||||||
|
super().__init__()
|
||||||
|
self.runs: List[Run] = []
|
||||||
|
|
||||||
|
def _persist_run(self, run: Run) -> None:
|
||||||
|
"""Persist a run."""
|
||||||
|
self.runs.append(run)
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_llm_run() -> None:
|
||||||
|
"""Test tracer on an LLM run."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = Run(
|
||||||
|
id=uuid,
|
||||||
|
parent_run_id=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
inputs={"prompts": []},
|
||||||
|
outputs=LLMResult(generations=[[]]),
|
||||||
|
error=None,
|
||||||
|
run_type="llm",
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_chat_model_run() -> None:
|
||||||
|
"""Test tracer on a Chat Model run."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(uuid),
|
||||||
|
name="chat_model",
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "chat_model"},
|
||||||
|
inputs=dict(prompts=[""]),
|
||||||
|
outputs=LLMResult(generations=[[]]),
|
||||||
|
error=None,
|
||||||
|
run_type="llm",
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
manager = CallbackManager(handlers=[tracer])
|
||||||
|
run_manager = manager.on_chat_model_start(
|
||||||
|
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
|
||||||
|
)
|
||||||
|
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_llm_run_errors_no_start() -> None:
|
||||||
|
"""Test tracer on an LLM run without a start."""
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
with pytest.raises(TracerException):
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_multiple_llm_runs() -> None:
|
||||||
|
"""Test the tracer with multiple runs."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = Run(
|
||||||
|
id=uuid,
|
||||||
|
name="llm",
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
inputs=dict(prompts=[]),
|
||||||
|
outputs=LLMResult(generations=[[]]),
|
||||||
|
error=None,
|
||||||
|
run_type="llm",
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
num_runs = 10
|
||||||
|
for _ in range(num_runs):
|
||||||
|
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||||
|
|
||||||
|
assert tracer.runs == [compare_run] * num_runs
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_chain_run() -> None:
|
||||||
|
"""Test tracer on a Chain run."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "chain"},
|
||||||
|
inputs={},
|
||||||
|
outputs={},
|
||||||
|
error=None,
|
||||||
|
run_type="chain",
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||||
|
tracer.on_chain_end(outputs={}, run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_tool_run() -> None:
|
||||||
|
"""Test tracer on a Tool run."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
inputs={"input": "test"},
|
||||||
|
outputs={"output": "test"},
|
||||||
|
error=None,
|
||||||
|
run_type="tool",
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
|
||||||
|
tracer.on_tool_end("test", run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_nested_run() -> None:
|
||||||
|
"""Test tracer on a nested run."""
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
chain_uuid = uuid4()
|
||||||
|
tool_uuid = uuid4()
|
||||||
|
llm_uuid1 = uuid4()
|
||||||
|
llm_uuid2 = uuid4()
|
||||||
|
for _ in range(10):
|
||||||
|
tracer.on_chain_start(
|
||||||
|
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||||
|
)
|
||||||
|
tracer.on_tool_start(
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
input_str="test",
|
||||||
|
run_id=tool_uuid,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid1,
|
||||||
|
parent_run_id=tool_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||||
|
tracer.on_tool_end("test", run_id=tool_uuid)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid2,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||||
|
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
|
||||||
|
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(chain_uuid),
|
||||||
|
error=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=4,
|
||||||
|
serialized={"name": "chain"},
|
||||||
|
inputs={},
|
||||||
|
outputs={},
|
||||||
|
run_type="chain",
|
||||||
|
child_runs=[
|
||||||
|
Run(
|
||||||
|
id=tool_uuid,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=2,
|
||||||
|
child_execution_order=3,
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
inputs=dict(input="test"),
|
||||||
|
outputs=dict(output="test"),
|
||||||
|
error=None,
|
||||||
|
run_type="tool",
|
||||||
|
child_runs=[
|
||||||
|
Run(
|
||||||
|
id=str(llm_uuid1),
|
||||||
|
parent_run_id=str(tool_uuid),
|
||||||
|
error=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=3,
|
||||||
|
child_execution_order=3,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
inputs=dict(prompts=[]),
|
||||||
|
outputs=LLMResult(generations=[[]]),
|
||||||
|
run_type="llm",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Run(
|
||||||
|
id=str(llm_uuid2),
|
||||||
|
parent_run_id=str(chain_uuid),
|
||||||
|
error=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=4,
|
||||||
|
child_execution_order=4,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
inputs=dict(prompts=[]),
|
||||||
|
outputs=LLMResult(generations=[[]]),
|
||||||
|
run_type="llm",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert tracer.runs[0] == compare_run
|
||||||
|
assert tracer.runs == [compare_run] * 10
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_llm_run_on_error() -> None:
|
||||||
|
"""Test tracer on an LLM run with an error."""
|
||||||
|
exception = Exception("test")
|
||||||
|
uuid = uuid4()
|
||||||
|
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
inputs=dict(prompts=[]),
|
||||||
|
outputs=None,
|
||||||
|
error=repr(exception),
|
||||||
|
run_type="llm",
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||||
|
tracer.on_llm_error(exception, run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_chain_run_on_error() -> None:
|
||||||
|
"""Test tracer on a Chain run with an error."""
|
||||||
|
exception = Exception("test")
|
||||||
|
uuid = uuid4()
|
||||||
|
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "chain"},
|
||||||
|
inputs={},
|
||||||
|
outputs=None,
|
||||||
|
error=repr(exception),
|
||||||
|
run_type="chain",
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||||
|
tracer.on_chain_error(exception, run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_tool_run_on_error() -> None:
|
||||||
|
"""Test tracer on a Tool run with an error."""
|
||||||
|
exception = Exception("test")
|
||||||
|
uuid = uuid4()
|
||||||
|
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
inputs=dict(input="test"),
|
||||||
|
outputs=None,
|
||||||
|
action="{'name': 'tool'}",
|
||||||
|
error=repr(exception),
|
||||||
|
run_type="tool",
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
|
||||||
|
tracer.on_tool_error(exception, run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_nested_runs_on_error() -> None:
|
||||||
|
"""Test tracer on a nested run with an error."""
|
||||||
|
exception = Exception("test")
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
chain_uuid = uuid4()
|
||||||
|
tool_uuid = uuid4()
|
||||||
|
llm_uuid1 = uuid4()
|
||||||
|
llm_uuid2 = uuid4()
|
||||||
|
llm_uuid3 = uuid4()
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
tracer.on_chain_start(
|
||||||
|
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||||
|
)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid1,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid2,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||||
|
tracer.on_tool_start(
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
input_str="test",
|
||||||
|
run_id=tool_uuid,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid3,
|
||||||
|
parent_run_id=tool_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_error(exception, run_id=llm_uuid3)
|
||||||
|
tracer.on_tool_error(exception, run_id=tool_uuid)
|
||||||
|
tracer.on_chain_error(exception, run_id=chain_uuid)
|
||||||
|
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(chain_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=5,
|
||||||
|
serialized={"name": "chain"},
|
||||||
|
error=repr(exception),
|
||||||
|
inputs={},
|
||||||
|
outputs=None,
|
||||||
|
run_type="chain",
|
||||||
|
child_runs=[
|
||||||
|
Run(
|
||||||
|
id=str(llm_uuid1),
|
||||||
|
parent_run_id=str(chain_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=2,
|
||||||
|
child_execution_order=2,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
error=None,
|
||||||
|
inputs=dict(prompts=[]),
|
||||||
|
outputs=LLMResult(generations=[[]], llm_output=None),
|
||||||
|
run_type="llm",
|
||||||
|
),
|
||||||
|
Run(
|
||||||
|
id=str(llm_uuid2),
|
||||||
|
parent_run_id=str(chain_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=3,
|
||||||
|
child_execution_order=3,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
error=None,
|
||||||
|
inputs=dict(prompts=[]),
|
||||||
|
outputs=LLMResult(generations=[[]], llm_output=None),
|
||||||
|
run_type="llm",
|
||||||
|
),
|
||||||
|
Run(
|
||||||
|
id=str(tool_uuid),
|
||||||
|
parent_run_id=str(chain_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=4,
|
||||||
|
child_execution_order=5,
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
error=repr(exception),
|
||||||
|
inputs=dict(input="test"),
|
||||||
|
outputs=None,
|
||||||
|
action="{'name': 'tool'}",
|
||||||
|
child_runs=[
|
||||||
|
Run(
|
||||||
|
id=str(llm_uuid3),
|
||||||
|
parent_run_id=str(tool_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=5,
|
||||||
|
child_execution_order=5,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
error=repr(exception),
|
||||||
|
inputs=dict(prompts=[]),
|
||||||
|
outputs=None,
|
||||||
|
run_type="llm",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
run_type="tool",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert tracer.runs == [compare_run] * 3
|
676
tests/unit_tests/callbacks/tracers/test_langchain_v1.py
Normal file
676
tests/unit_tests/callbacks/tracers/test_langchain_v1.py
Normal file
@ -0,0 +1,676 @@
|
|||||||
|
"""Test Tracer classes."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from freezegun import freeze_time
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
from langchain.callbacks.tracers.base import BaseTracer, TracerException
|
||||||
|
from langchain.callbacks.tracers.langchain_v1 import (
|
||||||
|
ChainRun,
|
||||||
|
LangChainTracerV1,
|
||||||
|
LLMRun,
|
||||||
|
ToolRun,
|
||||||
|
TracerSessionV1,
|
||||||
|
)
|
||||||
|
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
TEST_SESSION_ID = 2023
|
||||||
|
|
||||||
|
|
||||||
|
def load_session(session_name: str) -> TracerSessionV1:
|
||||||
|
"""Load a tracing session."""
|
||||||
|
return TracerSessionV1(
|
||||||
|
id=TEST_SESSION_ID, name=session_name, start_time=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def new_session(name: Optional[str] = None) -> TracerSessionV1:
|
||||||
|
"""Create a new tracing session."""
|
||||||
|
return TracerSessionV1(
|
||||||
|
id=TEST_SESSION_ID, name=name or "default", start_time=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _persist_session(session: TracerSessionV1Base) -> TracerSessionV1:
|
||||||
|
"""Persist a tracing session."""
|
||||||
|
return TracerSessionV1(**{**session.dict(), "id": TEST_SESSION_ID})
|
||||||
|
|
||||||
|
|
||||||
|
def load_default_session() -> TracerSessionV1:
|
||||||
|
"""Load a tracing session."""
|
||||||
|
return TracerSessionV1(
|
||||||
|
id=TEST_SESSION_ID, name="default", start_time=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def lang_chain_tracer_v1(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV1:
|
||||||
|
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
||||||
|
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
||||||
|
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
||||||
|
tracer = LangChainTracerV1()
|
||||||
|
return tracer
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTracer(BaseTracer):
|
||||||
|
"""Fake tracer that records LangChain execution."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the tracer."""
|
||||||
|
super().__init__()
|
||||||
|
self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||||
|
|
||||||
|
def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
|
||||||
|
"""Persist a run."""
|
||||||
|
if isinstance(run, Run):
|
||||||
|
with pytest.MonkeyPatch().context() as m:
|
||||||
|
m.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
||||||
|
m.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
||||||
|
m.setenv("LANGCHAIN_API_KEY", "foo")
|
||||||
|
tracer = LangChainTracerV1()
|
||||||
|
tracer.load_default_session = load_default_session # type: ignore
|
||||||
|
run = tracer._convert_to_v1_run(run)
|
||||||
|
self.runs.append(run)
|
||||||
|
|
||||||
|
def _persist_session(self, session: TracerSessionV1Base) -> TracerSessionV1:
|
||||||
|
"""Persist a tracing session."""
|
||||||
|
return _persist_session(session)
|
||||||
|
|
||||||
|
def new_session(self, name: Optional[str] = None) -> TracerSessionV1:
|
||||||
|
"""Create a new tracing session."""
|
||||||
|
return new_session(name)
|
||||||
|
|
||||||
|
def load_session(self, session_name: str) -> TracerSessionV1:
|
||||||
|
"""Load a tracing session."""
|
||||||
|
return load_session(session_name)
|
||||||
|
|
||||||
|
def load_default_session(self) -> TracerSessionV1:
|
||||||
|
"""Load a tracing session."""
|
||||||
|
return load_default_session()
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_llm_run() -> None:
|
||||||
|
"""Test tracer on an LLM run."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = LLMRun(
|
||||||
|
uuid=str(uuid),
|
||||||
|
parent_uuid=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
response=LLMResult(generations=[[]]),
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.new_session()
|
||||||
|
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_chat_model_run() -> None:
|
||||||
|
"""Test tracer on a Chat Model run."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = LLMRun(
|
||||||
|
uuid=str(uuid),
|
||||||
|
parent_uuid=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "chat_model"},
|
||||||
|
prompts=[""],
|
||||||
|
response=LLMResult(generations=[[]]),
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.new_session()
|
||||||
|
manager = CallbackManager(handlers=[tracer])
|
||||||
|
run_manager = manager.on_chat_model_start(
|
||||||
|
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
|
||||||
|
)
|
||||||
|
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_llm_run_errors_no_start() -> None:
|
||||||
|
"""Test tracer on an LLM run without a start."""
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.new_session()
|
||||||
|
with pytest.raises(TracerException):
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_multiple_llm_runs() -> None:
|
||||||
|
"""Test the tracer with multiple runs."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = LLMRun(
|
||||||
|
uuid=str(uuid),
|
||||||
|
parent_uuid=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
response=LLMResult(generations=[[]]),
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.new_session()
|
||||||
|
num_runs = 10
|
||||||
|
for _ in range(num_runs):
|
||||||
|
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||||
|
|
||||||
|
assert tracer.runs == [compare_run] * num_runs
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_chain_run() -> None:
|
||||||
|
"""Test tracer on a Chain run."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = ChainRun(
|
||||||
|
uuid=str(uuid),
|
||||||
|
parent_uuid=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "chain"},
|
||||||
|
inputs={},
|
||||||
|
outputs={},
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.new_session()
|
||||||
|
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||||
|
tracer.on_chain_end(outputs={}, run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_tool_run() -> None:
|
||||||
|
"""Test tracer on a Tool run."""
|
||||||
|
uuid = uuid4()
|
||||||
|
compare_run = ToolRun(
|
||||||
|
uuid=str(uuid),
|
||||||
|
parent_uuid=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
tool_input="test",
|
||||||
|
output="test",
|
||||||
|
action="{'name': 'tool'}",
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.new_session()
|
||||||
|
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
|
||||||
|
tracer.on_tool_end("test", run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_nested_run() -> None:
|
||||||
|
"""Test tracer on a nested run."""
|
||||||
|
tracer = FakeTracer()
|
||||||
|
tracer.new_session()
|
||||||
|
|
||||||
|
chain_uuid = uuid4()
|
||||||
|
tool_uuid = uuid4()
|
||||||
|
llm_uuid1 = uuid4()
|
||||||
|
llm_uuid2 = uuid4()
|
||||||
|
for _ in range(10):
|
||||||
|
tracer.on_chain_start(
|
||||||
|
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||||
|
)
|
||||||
|
tracer.on_tool_start(
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
input_str="test",
|
||||||
|
run_id=tool_uuid,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid1,
|
||||||
|
parent_run_id=tool_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||||
|
tracer.on_tool_end("test", run_id=tool_uuid)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid2,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||||
|
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
|
||||||
|
|
||||||
|
compare_run = ChainRun(
|
||||||
|
uuid=str(chain_uuid),
|
||||||
|
error=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=4,
|
||||||
|
serialized={"name": "chain"},
|
||||||
|
inputs={},
|
||||||
|
outputs={},
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
child_chain_runs=[],
|
||||||
|
child_tool_runs=[
|
||||||
|
ToolRun(
|
||||||
|
uuid=str(tool_uuid),
|
||||||
|
parent_uuid=str(chain_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=2,
|
||||||
|
child_execution_order=3,
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
tool_input="test",
|
||||||
|
output="test",
|
||||||
|
action="{'name': 'tool'}",
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=None,
|
||||||
|
child_chain_runs=[],
|
||||||
|
child_tool_runs=[],
|
||||||
|
child_llm_runs=[
|
||||||
|
LLMRun(
|
||||||
|
uuid=str(llm_uuid1),
|
||||||
|
parent_uuid=str(tool_uuid),
|
||||||
|
error=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=3,
|
||||||
|
child_execution_order=3,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
response=LLMResult(generations=[[]]),
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
child_llm_runs=[
|
||||||
|
LLMRun(
|
||||||
|
uuid=str(llm_uuid2),
|
||||||
|
parent_uuid=str(chain_uuid),
|
||||||
|
error=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=4,
|
||||||
|
child_execution_order=4,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
response=LLMResult(generations=[[]]),
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert tracer.runs[0] == compare_run
|
||||||
|
assert tracer.runs == [compare_run] * 10
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_llm_run_on_error() -> None:
|
||||||
|
"""Test tracer on an LLM run with an error."""
|
||||||
|
exception = Exception("test")
|
||||||
|
uuid = uuid4()
|
||||||
|
|
||||||
|
compare_run = LLMRun(
|
||||||
|
uuid=str(uuid),
|
||||||
|
parent_uuid=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
response=None,
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=repr(exception),
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.new_session()
|
||||||
|
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||||
|
tracer.on_llm_error(exception, run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_chain_run_on_error() -> None:
|
||||||
|
"""Test tracer on a Chain run with an error."""
|
||||||
|
exception = Exception("test")
|
||||||
|
uuid = uuid4()
|
||||||
|
|
||||||
|
compare_run = ChainRun(
|
||||||
|
uuid=str(uuid),
|
||||||
|
parent_uuid=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "chain"},
|
||||||
|
inputs={},
|
||||||
|
outputs=None,
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=repr(exception),
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.new_session()
|
||||||
|
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||||
|
tracer.on_chain_error(exception, run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_tool_run_on_error() -> None:
|
||||||
|
"""Test tracer on a Tool run with an error."""
|
||||||
|
exception = Exception("test")
|
||||||
|
uuid = uuid4()
|
||||||
|
|
||||||
|
compare_run = ToolRun(
|
||||||
|
uuid=str(uuid),
|
||||||
|
parent_uuid=None,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
tool_input="test",
|
||||||
|
output=None,
|
||||||
|
action="{'name': 'tool'}",
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=repr(exception),
|
||||||
|
)
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
tracer.new_session()
|
||||||
|
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
|
||||||
|
tracer.on_tool_error(exception, run_id=uuid)
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_nested_runs_on_error() -> None:
|
||||||
|
"""Test tracer on a nested run with an error."""
|
||||||
|
exception = Exception("test")
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
tracer.new_session()
|
||||||
|
chain_uuid = uuid4()
|
||||||
|
tool_uuid = uuid4()
|
||||||
|
llm_uuid1 = uuid4()
|
||||||
|
llm_uuid2 = uuid4()
|
||||||
|
llm_uuid3 = uuid4()
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
tracer.on_chain_start(
|
||||||
|
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||||
|
)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid1,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid2,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||||
|
tracer.on_tool_start(
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
input_str="test",
|
||||||
|
run_id=tool_uuid,
|
||||||
|
parent_run_id=chain_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_start(
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
prompts=[],
|
||||||
|
run_id=llm_uuid3,
|
||||||
|
parent_run_id=tool_uuid,
|
||||||
|
)
|
||||||
|
tracer.on_llm_error(exception, run_id=llm_uuid3)
|
||||||
|
tracer.on_tool_error(exception, run_id=tool_uuid)
|
||||||
|
tracer.on_chain_error(exception, run_id=chain_uuid)
|
||||||
|
|
||||||
|
compare_run = ChainRun(
|
||||||
|
uuid=str(chain_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=5,
|
||||||
|
serialized={"name": "chain"},
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=repr(exception),
|
||||||
|
inputs={},
|
||||||
|
outputs=None,
|
||||||
|
child_llm_runs=[
|
||||||
|
LLMRun(
|
||||||
|
uuid=str(llm_uuid1),
|
||||||
|
parent_uuid=str(chain_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=2,
|
||||||
|
child_execution_order=2,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=None,
|
||||||
|
prompts=[],
|
||||||
|
response=LLMResult(generations=[[]], llm_output=None),
|
||||||
|
),
|
||||||
|
LLMRun(
|
||||||
|
uuid=str(llm_uuid2),
|
||||||
|
parent_uuid=str(chain_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=3,
|
||||||
|
child_execution_order=3,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=None,
|
||||||
|
prompts=[],
|
||||||
|
response=LLMResult(generations=[[]], llm_output=None),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
child_chain_runs=[],
|
||||||
|
child_tool_runs=[
|
||||||
|
ToolRun(
|
||||||
|
uuid=str(tool_uuid),
|
||||||
|
parent_uuid=str(chain_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=4,
|
||||||
|
child_execution_order=5,
|
||||||
|
serialized={"name": "tool"},
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=repr(exception),
|
||||||
|
tool_input="test",
|
||||||
|
output=None,
|
||||||
|
action="{'name': 'tool'}",
|
||||||
|
child_llm_runs=[
|
||||||
|
LLMRun(
|
||||||
|
uuid=str(llm_uuid3),
|
||||||
|
parent_uuid=str(tool_uuid),
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
extra={},
|
||||||
|
execution_order=5,
|
||||||
|
child_execution_order=5,
|
||||||
|
serialized={"name": "llm"},
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
error=repr(exception),
|
||||||
|
prompts=[],
|
||||||
|
response=None,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
child_chain_runs=[],
|
||||||
|
child_tool_runs=[],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert tracer.runs == [compare_run] * 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_tracer_session_v1() -> TracerSessionV1:
|
||||||
|
return TracerSessionV1(id=2, name="Sample session")
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_convert_run(
|
||||||
|
lang_chain_tracer_v1: LangChainTracerV1,
|
||||||
|
sample_tracer_session_v1: TracerSessionV1,
|
||||||
|
) -> None:
|
||||||
|
"""Test converting a run to a V1 run."""
|
||||||
|
llm_run = Run(
|
||||||
|
id="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||||
|
name="llm_run",
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
session_id=TEST_SESSION_ID,
|
||||||
|
inputs={"prompts": []},
|
||||||
|
outputs=LLMResult(generations=[[]]).dict(),
|
||||||
|
serialized={},
|
||||||
|
extra={},
|
||||||
|
run_type=RunTypeEnum.llm,
|
||||||
|
)
|
||||||
|
chain_run = Run(
|
||||||
|
id="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||||
|
name="chain_run",
|
||||||
|
execution_order=1,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
child_execution_order=1,
|
||||||
|
serialized={},
|
||||||
|
inputs={},
|
||||||
|
outputs={},
|
||||||
|
child_runs=[llm_run],
|
||||||
|
extra={},
|
||||||
|
run_type=RunTypeEnum.chain,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_run = Run(
|
||||||
|
id="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||||
|
name="tool_run",
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
inputs={"input": "test"},
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
outputs=None,
|
||||||
|
serialized={},
|
||||||
|
child_runs=[],
|
||||||
|
extra={},
|
||||||
|
run_type=RunTypeEnum.tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_llm_run = LLMRun(
|
||||||
|
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||||
|
name="llm_run",
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
session_id=2,
|
||||||
|
prompts=[],
|
||||||
|
response=LLMResult(generations=[[]]),
|
||||||
|
serialized={},
|
||||||
|
extra={},
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_chain_run = ChainRun(
|
||||||
|
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||||
|
name="chain_run",
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
session_id=2,
|
||||||
|
serialized={},
|
||||||
|
inputs={},
|
||||||
|
outputs={},
|
||||||
|
child_llm_runs=[expected_llm_run],
|
||||||
|
child_chain_runs=[],
|
||||||
|
child_tool_runs=[],
|
||||||
|
extra={},
|
||||||
|
)
|
||||||
|
expected_tool_run = ToolRun(
|
||||||
|
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||||
|
name="tool_run",
|
||||||
|
execution_order=1,
|
||||||
|
child_execution_order=1,
|
||||||
|
session_id=2,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
|
tool_input="test",
|
||||||
|
action="{}",
|
||||||
|
serialized={},
|
||||||
|
child_llm_runs=[],
|
||||||
|
child_chain_runs=[],
|
||||||
|
child_tool_runs=[],
|
||||||
|
extra={},
|
||||||
|
)
|
||||||
|
lang_chain_tracer_v1.session = sample_tracer_session_v1
|
||||||
|
converted_llm_run = lang_chain_tracer_v1._convert_to_v1_run(llm_run)
|
||||||
|
converted_chain_run = lang_chain_tracer_v1._convert_to_v1_run(chain_run)
|
||||||
|
converted_tool_run = lang_chain_tracer_v1._convert_to_v1_run(tool_run)
|
||||||
|
|
||||||
|
assert isinstance(converted_llm_run, LLMRun)
|
||||||
|
assert isinstance(converted_chain_run, ChainRun)
|
||||||
|
assert isinstance(converted_tool_run, ToolRun)
|
||||||
|
assert converted_llm_run == expected_llm_run
|
||||||
|
assert converted_tool_run == expected_tool_run
|
||||||
|
assert converted_chain_run == expected_chain_run
|
@ -3,621 +3,90 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Tuple, Union
|
from typing import Tuple
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import patch
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.callbacks.tracers.base import (
|
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
|
||||||
BaseTracer,
|
|
||||||
ChainRun,
|
|
||||||
LLMRun,
|
|
||||||
ToolRun,
|
|
||||||
TracerException,
|
|
||||||
TracerSession,
|
|
||||||
)
|
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
|
||||||
from langchain.callbacks.tracers.schemas import (
|
|
||||||
RunCreate,
|
|
||||||
TracerSessionBase,
|
|
||||||
TracerSessionV2,
|
|
||||||
TracerSessionV2Create,
|
|
||||||
)
|
|
||||||
from langchain.schema import LLMResult
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
TEST_SESSION_ID = 2023
|
|
||||||
|
|
||||||
|
|
||||||
def load_session(session_name: str) -> TracerSession:
|
|
||||||
"""Load a tracing session."""
|
|
||||||
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
|
|
||||||
|
|
||||||
|
|
||||||
def _persist_session(session: TracerSessionBase) -> TracerSession:
|
|
||||||
"""Persist a tracing session."""
|
|
||||||
return TracerSession(id=TEST_SESSION_ID, **session.dict())
|
|
||||||
|
|
||||||
|
|
||||||
def load_default_session() -> TracerSession:
|
|
||||||
"""Load a tracing session."""
|
|
||||||
return TracerSession(id=1, name="default", start_time=datetime.utcnow())
|
|
||||||
|
|
||||||
|
|
||||||
class FakeTracer(BaseTracer):
|
|
||||||
"""Fake tracer that records LangChain execution."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initialize the tracer."""
|
|
||||||
super().__init__()
|
|
||||||
self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
|
||||||
|
|
||||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
|
||||||
"""Persist a run."""
|
|
||||||
self.runs.append(run)
|
|
||||||
|
|
||||||
def _persist_session(self, session: TracerSessionBase) -> TracerSession:
|
|
||||||
"""Persist a tracing session."""
|
|
||||||
return _persist_session(session)
|
|
||||||
|
|
||||||
def load_session(self, session_name: str) -> TracerSession:
|
|
||||||
"""Load a tracing session."""
|
|
||||||
return load_session(session_name)
|
|
||||||
|
|
||||||
def load_default_session(self) -> TracerSession:
|
|
||||||
"""Load a tracing session."""
|
|
||||||
return load_default_session()
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_llm_run() -> None:
|
|
||||||
"""Test tracer on an LLM run."""
|
|
||||||
uuid = uuid4()
|
|
||||||
compare_run = LLMRun(
|
|
||||||
uuid=str(uuid),
|
|
||||||
parent_uuid=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
serialized={},
|
|
||||||
prompts=[],
|
|
||||||
response=LLMResult(generations=[[]]),
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=None,
|
|
||||||
)
|
|
||||||
tracer = FakeTracer()
|
|
||||||
|
|
||||||
tracer.new_session()
|
|
||||||
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
|
|
||||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
|
||||||
assert tracer.runs == [compare_run]
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_chat_model_run() -> None:
|
|
||||||
"""Test tracer on a Chat Model run."""
|
|
||||||
uuid = uuid4()
|
|
||||||
compare_run = LLMRun(
|
|
||||||
uuid=str(uuid),
|
|
||||||
parent_uuid=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
serialized={},
|
|
||||||
prompts=[""],
|
|
||||||
response=LLMResult(generations=[[]]),
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=None,
|
|
||||||
)
|
|
||||||
tracer = FakeTracer()
|
|
||||||
|
|
||||||
tracer.new_session()
|
|
||||||
manager = CallbackManager(handlers=[tracer])
|
|
||||||
run_manager = manager.on_chat_model_start(serialized={}, messages=[[]], run_id=uuid)
|
|
||||||
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
|
||||||
assert tracer.runs == [compare_run]
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_llm_run_errors_no_start() -> None:
|
|
||||||
"""Test tracer on an LLM run without a start."""
|
|
||||||
tracer = FakeTracer()
|
|
||||||
|
|
||||||
tracer.new_session()
|
|
||||||
with pytest.raises(TracerException):
|
|
||||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_multiple_llm_runs() -> None:
|
|
||||||
"""Test the tracer with multiple runs."""
|
|
||||||
uuid = uuid4()
|
|
||||||
compare_run = LLMRun(
|
|
||||||
uuid=str(uuid),
|
|
||||||
parent_uuid=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
serialized={},
|
|
||||||
prompts=[],
|
|
||||||
response=LLMResult(generations=[[]]),
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=None,
|
|
||||||
)
|
|
||||||
tracer = FakeTracer()
|
|
||||||
|
|
||||||
tracer.new_session()
|
|
||||||
num_runs = 10
|
|
||||||
for _ in range(num_runs):
|
|
||||||
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
|
|
||||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
|
||||||
|
|
||||||
assert tracer.runs == [compare_run] * num_runs
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_chain_run() -> None:
|
|
||||||
"""Test tracer on a Chain run."""
|
|
||||||
uuid = uuid4()
|
|
||||||
compare_run = ChainRun(
|
|
||||||
uuid=str(uuid),
|
|
||||||
parent_uuid=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
serialized={},
|
|
||||||
inputs={},
|
|
||||||
outputs={},
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=None,
|
|
||||||
)
|
|
||||||
tracer = FakeTracer()
|
|
||||||
|
|
||||||
tracer.new_session()
|
|
||||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
|
|
||||||
tracer.on_chain_end(outputs={}, run_id=uuid)
|
|
||||||
assert tracer.runs == [compare_run]
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_tool_run() -> None:
|
|
||||||
"""Test tracer on a Tool run."""
|
|
||||||
uuid = uuid4()
|
|
||||||
compare_run = ToolRun(
|
|
||||||
uuid=str(uuid),
|
|
||||||
parent_uuid=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
serialized={},
|
|
||||||
tool_input="test",
|
|
||||||
output="test",
|
|
||||||
action="{}",
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=None,
|
|
||||||
)
|
|
||||||
tracer = FakeTracer()
|
|
||||||
|
|
||||||
tracer.new_session()
|
|
||||||
tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
|
|
||||||
tracer.on_tool_end("test", run_id=uuid)
|
|
||||||
assert tracer.runs == [compare_run]
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_nested_run() -> None:
|
|
||||||
"""Test tracer on a nested run."""
|
|
||||||
tracer = FakeTracer()
|
|
||||||
tracer.new_session()
|
|
||||||
|
|
||||||
chain_uuid = uuid4()
|
|
||||||
tool_uuid = uuid4()
|
|
||||||
llm_uuid1 = uuid4()
|
|
||||||
llm_uuid2 = uuid4()
|
|
||||||
for _ in range(10):
|
|
||||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
|
|
||||||
tracer.on_tool_start(
|
|
||||||
serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
|
|
||||||
)
|
|
||||||
tracer.on_llm_start(
|
|
||||||
serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid
|
|
||||||
)
|
|
||||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
|
||||||
tracer.on_tool_end("test", run_id=tool_uuid)
|
|
||||||
tracer.on_llm_start(
|
|
||||||
serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
|
|
||||||
)
|
|
||||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
|
||||||
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
|
|
||||||
|
|
||||||
compare_run = ChainRun(
|
|
||||||
uuid=str(chain_uuid),
|
|
||||||
error=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=4,
|
|
||||||
serialized={},
|
|
||||||
inputs={},
|
|
||||||
outputs={},
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
child_chain_runs=[],
|
|
||||||
child_tool_runs=[
|
|
||||||
ToolRun(
|
|
||||||
uuid=str(tool_uuid),
|
|
||||||
parent_uuid=str(chain_uuid),
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=2,
|
|
||||||
child_execution_order=3,
|
|
||||||
serialized={},
|
|
||||||
tool_input="test",
|
|
||||||
output="test",
|
|
||||||
action="{}",
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=None,
|
|
||||||
child_chain_runs=[],
|
|
||||||
child_tool_runs=[],
|
|
||||||
child_llm_runs=[
|
|
||||||
LLMRun(
|
|
||||||
uuid=str(llm_uuid1),
|
|
||||||
parent_uuid=str(tool_uuid),
|
|
||||||
error=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=3,
|
|
||||||
child_execution_order=3,
|
|
||||||
serialized={},
|
|
||||||
prompts=[],
|
|
||||||
response=LLMResult(generations=[[]]),
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
child_llm_runs=[
|
|
||||||
LLMRun(
|
|
||||||
uuid=str(llm_uuid2),
|
|
||||||
parent_uuid=str(chain_uuid),
|
|
||||||
error=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=4,
|
|
||||||
child_execution_order=4,
|
|
||||||
serialized={},
|
|
||||||
prompts=[],
|
|
||||||
response=LLMResult(generations=[[]]),
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert tracer.runs == [compare_run] * 10
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_llm_run_on_error() -> None:
|
|
||||||
"""Test tracer on an LLM run with an error."""
|
|
||||||
exception = Exception("test")
|
|
||||||
uuid = uuid4()
|
|
||||||
|
|
||||||
compare_run = LLMRun(
|
|
||||||
uuid=str(uuid),
|
|
||||||
parent_uuid=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
serialized={},
|
|
||||||
prompts=[],
|
|
||||||
response=None,
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=repr(exception),
|
|
||||||
)
|
|
||||||
tracer = FakeTracer()
|
|
||||||
|
|
||||||
tracer.new_session()
|
|
||||||
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
|
|
||||||
tracer.on_llm_error(exception, run_id=uuid)
|
|
||||||
assert tracer.runs == [compare_run]
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_chain_run_on_error() -> None:
|
|
||||||
"""Test tracer on a Chain run with an error."""
|
|
||||||
exception = Exception("test")
|
|
||||||
uuid = uuid4()
|
|
||||||
|
|
||||||
compare_run = ChainRun(
|
|
||||||
uuid=str(uuid),
|
|
||||||
parent_uuid=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
serialized={},
|
|
||||||
inputs={},
|
|
||||||
outputs=None,
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=repr(exception),
|
|
||||||
)
|
|
||||||
tracer = FakeTracer()
|
|
||||||
|
|
||||||
tracer.new_session()
|
|
||||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
|
|
||||||
tracer.on_chain_error(exception, run_id=uuid)
|
|
||||||
assert tracer.runs == [compare_run]
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_tool_run_on_error() -> None:
|
|
||||||
"""Test tracer on a Tool run with an error."""
|
|
||||||
exception = Exception("test")
|
|
||||||
uuid = uuid4()
|
|
||||||
|
|
||||||
compare_run = ToolRun(
|
|
||||||
uuid=str(uuid),
|
|
||||||
parent_uuid=None,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
serialized={},
|
|
||||||
tool_input="test",
|
|
||||||
output=None,
|
|
||||||
action="{}",
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=repr(exception),
|
|
||||||
)
|
|
||||||
tracer = FakeTracer()
|
|
||||||
|
|
||||||
tracer.new_session()
|
|
||||||
tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
|
|
||||||
tracer.on_tool_error(exception, run_id=uuid)
|
|
||||||
assert tracer.runs == [compare_run]
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
def test_tracer_nested_runs_on_error() -> None:
|
|
||||||
"""Test tracer on a nested run with an error."""
|
|
||||||
exception = Exception("test")
|
|
||||||
|
|
||||||
tracer = FakeTracer()
|
|
||||||
tracer.new_session()
|
|
||||||
chain_uuid = uuid4()
|
|
||||||
tool_uuid = uuid4()
|
|
||||||
llm_uuid1 = uuid4()
|
|
||||||
llm_uuid2 = uuid4()
|
|
||||||
llm_uuid3 = uuid4()
|
|
||||||
|
|
||||||
for _ in range(3):
|
|
||||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
|
|
||||||
tracer.on_llm_start(
|
|
||||||
serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid
|
|
||||||
)
|
|
||||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
|
||||||
tracer.on_llm_start(
|
|
||||||
serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
|
|
||||||
)
|
|
||||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
|
||||||
tracer.on_tool_start(
|
|
||||||
serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
|
|
||||||
)
|
|
||||||
tracer.on_llm_start(
|
|
||||||
serialized={}, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid
|
|
||||||
)
|
|
||||||
tracer.on_llm_error(exception, run_id=llm_uuid3)
|
|
||||||
tracer.on_tool_error(exception, run_id=tool_uuid)
|
|
||||||
tracer.on_chain_error(exception, run_id=chain_uuid)
|
|
||||||
|
|
||||||
compare_run = ChainRun(
|
|
||||||
uuid=str(chain_uuid),
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=5,
|
|
||||||
serialized={},
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=repr(exception),
|
|
||||||
inputs={},
|
|
||||||
outputs=None,
|
|
||||||
child_llm_runs=[
|
|
||||||
LLMRun(
|
|
||||||
uuid=str(llm_uuid1),
|
|
||||||
parent_uuid=str(chain_uuid),
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=2,
|
|
||||||
child_execution_order=2,
|
|
||||||
serialized={},
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=None,
|
|
||||||
prompts=[],
|
|
||||||
response=LLMResult(generations=[[]], llm_output=None),
|
|
||||||
),
|
|
||||||
LLMRun(
|
|
||||||
uuid=str(llm_uuid2),
|
|
||||||
parent_uuid=str(chain_uuid),
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=3,
|
|
||||||
child_execution_order=3,
|
|
||||||
serialized={},
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=None,
|
|
||||||
prompts=[],
|
|
||||||
response=LLMResult(generations=[[]], llm_output=None),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
child_chain_runs=[],
|
|
||||||
child_tool_runs=[
|
|
||||||
ToolRun(
|
|
||||||
uuid=str(tool_uuid),
|
|
||||||
parent_uuid=str(chain_uuid),
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=4,
|
|
||||||
child_execution_order=5,
|
|
||||||
serialized={},
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=repr(exception),
|
|
||||||
tool_input="test",
|
|
||||||
output=None,
|
|
||||||
action="{}",
|
|
||||||
child_llm_runs=[
|
|
||||||
LLMRun(
|
|
||||||
uuid=str(llm_uuid3),
|
|
||||||
parent_uuid=str(tool_uuid),
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
extra={},
|
|
||||||
execution_order=5,
|
|
||||||
child_execution_order=5,
|
|
||||||
serialized={},
|
|
||||||
session_id=TEST_SESSION_ID,
|
|
||||||
error=repr(exception),
|
|
||||||
prompts=[],
|
|
||||||
response=None,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
child_chain_runs=[],
|
|
||||||
child_tool_runs=[],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert tracer.runs == [compare_run] * 3
|
|
||||||
|
|
||||||
|
|
||||||
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
|
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
|
||||||
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
|
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2:
|
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
|
||||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
||||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
||||||
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
||||||
tracer = LangChainTracerV2()
|
tracer = LangChainTracer()
|
||||||
return tracer
|
return tracer
|
||||||
|
|
||||||
|
|
||||||
# Mock a sample TracerSessionV2 object
|
# Mock a sample TracerSession object
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_tracer_session_v2() -> TracerSessionV2:
|
def sample_tracer_session_v2() -> TracerSession:
|
||||||
return TracerSessionV2(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
|
return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
|
||||||
|
|
||||||
|
|
||||||
# Mock a sample LLMRun, ChainRun, and ToolRun objects
|
@freeze_time("2023-01-01")
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
|
def sample_runs() -> Tuple[Run, Run, Run]:
|
||||||
llm_run = LLMRun(
|
llm_run = Run(
|
||||||
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
|
id="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||||
name="llm_run",
|
name="llm_run",
|
||||||
execution_order=1,
|
execution_order=1,
|
||||||
child_execution_order=1,
|
child_execution_order=1,
|
||||||
|
parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
session_id=1,
|
session_id=1,
|
||||||
prompts=[],
|
inputs={"prompts": []},
|
||||||
response=LLMResult(generations=[[]]),
|
outputs=LLMResult(generations=[[]]).dict(),
|
||||||
serialized={},
|
serialized={},
|
||||||
extra={},
|
extra={},
|
||||||
|
run_type=RunTypeEnum.llm,
|
||||||
)
|
)
|
||||||
chain_run = ChainRun(
|
chain_run = Run(
|
||||||
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
|
id="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||||
name="chain_run",
|
name="chain_run",
|
||||||
execution_order=1,
|
execution_order=1,
|
||||||
|
start_time=datetime.utcnow(),
|
||||||
|
end_time=datetime.utcnow(),
|
||||||
child_execution_order=1,
|
child_execution_order=1,
|
||||||
session_id=1,
|
|
||||||
serialized={},
|
serialized={},
|
||||||
inputs={},
|
inputs={},
|
||||||
outputs=None,
|
outputs={},
|
||||||
child_llm_runs=[llm_run],
|
child_runs=[llm_run],
|
||||||
child_chain_runs=[],
|
|
||||||
child_tool_runs=[],
|
|
||||||
extra={},
|
extra={},
|
||||||
|
run_type=RunTypeEnum.chain,
|
||||||
)
|
)
|
||||||
tool_run = ToolRun(
|
|
||||||
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
|
tool_run = Run(
|
||||||
|
id="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||||
name="tool_run",
|
name="tool_run",
|
||||||
execution_order=1,
|
execution_order=1,
|
||||||
child_execution_order=1,
|
child_execution_order=1,
|
||||||
session_id=1,
|
inputs={"input": "test"},
|
||||||
tool_input="test",
|
start_time=datetime.utcnow(),
|
||||||
action="{}",
|
end_time=datetime.utcnow(),
|
||||||
|
outputs=None,
|
||||||
serialized={},
|
serialized={},
|
||||||
child_llm_runs=[],
|
child_runs=[],
|
||||||
child_chain_runs=[],
|
|
||||||
child_tool_runs=[],
|
|
||||||
extra={},
|
extra={},
|
||||||
|
run_type=RunTypeEnum.tool,
|
||||||
)
|
)
|
||||||
return llm_run, chain_run, tool_run
|
return llm_run, chain_run, tool_run
|
||||||
|
|
||||||
|
|
||||||
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
|
||||||
expected = {"tenant_id": "test-tenant-id"}
|
|
||||||
result = lang_chain_tracer_v2._get_default_query_params()
|
|
||||||
assert result == expected
|
|
||||||
|
|
||||||
|
|
||||||
@patch("langchain.callbacks.tracers.langchain.requests.get")
|
|
||||||
def test_load_session(
|
|
||||||
mock_requests_get: Mock,
|
|
||||||
lang_chain_tracer_v2: LangChainTracerV2,
|
|
||||||
sample_tracer_session_v2: TracerSessionV2,
|
|
||||||
) -> None:
|
|
||||||
"""Test that load_session method returns a TracerSessionV2 object."""
|
|
||||||
mock_requests_get.return_value.json.return_value = [sample_tracer_session_v2.dict()]
|
|
||||||
result = lang_chain_tracer_v2.load_session("test-session-name")
|
|
||||||
mock_requests_get.assert_called_with(
|
|
||||||
"http://test-endpoint.com/sessions",
|
|
||||||
headers={"Content-Type": "application/json", "x-api-key": "foo"},
|
|
||||||
params={"tenant_id": "test-tenant-id", "name": "test-session-name"},
|
|
||||||
)
|
|
||||||
assert result == sample_tracer_session_v2
|
|
||||||
|
|
||||||
|
|
||||||
def test_convert_run(
|
|
||||||
lang_chain_tracer_v2: LangChainTracerV2,
|
|
||||||
sample_tracer_session_v2: TracerSessionV2,
|
|
||||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
|
||||||
) -> None:
|
|
||||||
llm_run, chain_run, tool_run = sample_runs
|
|
||||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
|
||||||
converted_llm_run = lang_chain_tracer_v2._convert_run(llm_run)
|
|
||||||
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
|
|
||||||
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
|
|
||||||
|
|
||||||
assert isinstance(converted_llm_run, RunCreate)
|
|
||||||
assert isinstance(converted_chain_run, RunCreate)
|
|
||||||
assert isinstance(converted_tool_run, RunCreate)
|
|
||||||
|
|
||||||
|
|
||||||
def test_persist_run(
|
def test_persist_run(
|
||||||
lang_chain_tracer_v2: LangChainTracerV2,
|
lang_chain_tracer_v2: LangChainTracer,
|
||||||
sample_tracer_session_v2: TracerSessionV2,
|
sample_tracer_session_v2: TracerSession,
|
||||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
sample_runs: Tuple[Run, Run, Run],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that persist_run method calls requests.post once per method call."""
|
"""Test that persist_run method calls requests.post once per method call."""
|
||||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
||||||
@ -625,25 +94,25 @@ def test_persist_run(
|
|||||||
) as get:
|
) as get:
|
||||||
post.return_value.raise_for_status.return_value = None
|
post.return_value.raise_for_status.return_value = None
|
||||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||||
llm_run, chain_run, tool_run = sample_runs
|
for run in sample_runs:
|
||||||
lang_chain_tracer_v2._persist_run(llm_run)
|
lang_chain_tracer_v2.run_map[str(run.id)] = run
|
||||||
lang_chain_tracer_v2._persist_run(chain_run)
|
for run in sample_runs:
|
||||||
lang_chain_tracer_v2._persist_run(tool_run)
|
lang_chain_tracer_v2._end_trace(run)
|
||||||
|
|
||||||
assert post.call_count == 3
|
assert post.call_count == 3
|
||||||
assert get.call_count == 0
|
assert get.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
def test_persist_run_with_example_id(
|
def test_persist_run_with_example_id(
|
||||||
lang_chain_tracer_v2: LangChainTracerV2,
|
lang_chain_tracer_v2: LangChainTracer,
|
||||||
sample_tracer_session_v2: TracerSessionV2,
|
sample_tracer_session_v2: TracerSession,
|
||||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
sample_runs: Tuple[Run, Run, Run],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the example ID is assigned only to the parent run and not the children."""
|
"""Test the example ID is assigned only to the parent run and not the children."""
|
||||||
example_id = uuid4()
|
example_id = uuid4()
|
||||||
llm_run, chain_run, tool_run = sample_runs
|
llm_run, chain_run, tool_run = sample_runs
|
||||||
chain_run.child_tool_runs = [tool_run]
|
chain_run.child_runs = [tool_run]
|
||||||
tool_run.child_llm_runs = [llm_run]
|
tool_run.child_runs = [llm_run]
|
||||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
||||||
"langchain.callbacks.tracers.langchain.requests.get"
|
"langchain.callbacks.tracers.langchain.requests.get"
|
||||||
) as get:
|
) as get:
|
||||||
@ -652,55 +121,14 @@ def test_persist_run_with_example_id(
|
|||||||
lang_chain_tracer_v2.example_id = example_id
|
lang_chain_tracer_v2.example_id = example_id
|
||||||
lang_chain_tracer_v2._persist_run(chain_run)
|
lang_chain_tracer_v2._persist_run(chain_run)
|
||||||
|
|
||||||
assert post.call_count == 1
|
assert post.call_count == 3
|
||||||
assert get.call_count == 0
|
assert get.call_count == 0
|
||||||
posted_data = json.loads(post.call_args[1]["data"])
|
posted_data = [
|
||||||
assert posted_data["id"] == chain_run.uuid
|
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
|
||||||
assert posted_data["reference_example_id"] == str(example_id)
|
]
|
||||||
|
assert posted_data[0]["id"] == str(chain_run.id)
|
||||||
def assert_child_run_no_example_id(run: dict) -> None:
|
assert posted_data[0]["reference_example_id"] == str(example_id)
|
||||||
assert not run.get("reference_example_id")
|
assert posted_data[1]["id"] == str(tool_run.id)
|
||||||
for child_run in run.get("child_runs", []):
|
assert not posted_data[1].get("reference_example_id")
|
||||||
assert_child_run_no_example_id(child_run)
|
assert posted_data[2]["id"] == str(llm_run.id)
|
||||||
|
assert not posted_data[2].get("reference_example_id")
|
||||||
for child_run in posted_data["child_runs"]:
|
|
||||||
assert_child_run_no_example_id(child_run)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
|
||||||
"""Test creating the 'SessionCreate' object."""
|
|
||||||
lang_chain_tracer_v2.tenant_id = str(_TENANT_ID)
|
|
||||||
session_create = lang_chain_tracer_v2._get_session_create(name="test")
|
|
||||||
assert isinstance(session_create, TracerSessionV2Create)
|
|
||||||
assert session_create.name == "test"
|
|
||||||
assert session_create.tenant_id == _TENANT_ID
|
|
||||||
|
|
||||||
|
|
||||||
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
|
||||||
def test_persist_session(
|
|
||||||
mock_requests_post: Mock,
|
|
||||||
lang_chain_tracer_v2: LangChainTracerV2,
|
|
||||||
sample_tracer_session_v2: TracerSessionV2,
|
|
||||||
) -> None:
|
|
||||||
"""Test persist_session returns a TracerSessionV2 with the updated ID."""
|
|
||||||
session_create = TracerSessionV2Create(**sample_tracer_session_v2.dict())
|
|
||||||
new_id = str(uuid4())
|
|
||||||
mock_requests_post.return_value.json.return_value = {"id": new_id}
|
|
||||||
result = lang_chain_tracer_v2._persist_session(session_create)
|
|
||||||
assert isinstance(result, TracerSessionV2)
|
|
||||||
res = sample_tracer_session_v2.dict()
|
|
||||||
res["id"] = UUID(new_id)
|
|
||||||
assert result.dict() == res
|
|
||||||
|
|
||||||
|
|
||||||
@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session")
|
|
||||||
def test_load_default_session(
|
|
||||||
mock_load_session: Mock,
|
|
||||||
lang_chain_tracer_v2: LangChainTracerV2,
|
|
||||||
sample_tracer_session_v2: TracerSessionV2,
|
|
||||||
) -> None:
|
|
||||||
"""Test load_default_session attempts to load with the default name."""
|
|
||||||
mock_load_session.return_value = sample_tracer_session_v2
|
|
||||||
result = lang_chain_tracer_v2.load_default_session()
|
|
||||||
assert result == sample_tracer_session_v2
|
|
||||||
mock_load_session.assert_called_with("default")
|
|
||||||
|
@ -8,8 +8,8 @@ from unittest import mock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.callbacks.tracers.schemas import TracerSessionV2
|
from langchain.callbacks.tracers.schemas import TracerSession
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.client.langchain import (
|
from langchain.client.langchain import (
|
||||||
LangChainPlusClient,
|
LangChainPlusClient,
|
||||||
@ -196,10 +196,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||||
]
|
]
|
||||||
|
|
||||||
def mock_load_session(
|
def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
|
||||||
self: Any, name: str, *args: Any, **kwargs: Any
|
return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4())
|
||||||
) -> TracerSessionV2:
|
|
||||||
return TracerSessionV2(name=name, tenant_id=_TENANT_ID, id=uuid.uuid4())
|
|
||||||
|
|
||||||
with mock.patch.object(
|
with mock.patch.object(
|
||||||
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
||||||
@ -208,7 +206,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
), mock.patch.object(
|
), mock.patch.object(
|
||||||
LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
|
LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
|
||||||
), mock.patch.object(
|
), mock.patch.object(
|
||||||
LangChainTracerV2, "load_session", new=mock_load_session
|
LangChainTracer, "ensure_session", new=mock_ensure_session
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
|
monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
|
||||||
client = LangChainPlusClient(
|
client = LangChainPlusClient(
|
||||||
|
Loading…
Reference in New Issue
Block a user