diff --git a/libs/community/langchain_community/callbacks/tracers/__init__.py b/libs/community/langchain_community/callbacks/tracers/__init__.py index 6cbed4ac5db..8af691585a6 100644 --- a/libs/community/langchain_community/callbacks/tracers/__init__.py +++ b/libs/community/langchain_community/callbacks/tracers/__init__.py @@ -1,6 +1,7 @@ """Tracers that record execution of LangChain runs.""" from langchain_core.tracers.langchain import LangChainTracer +from langchain_core.tracers.langchain_v1 import LangChainTracerV1 from langchain_core.tracers.stdout import ( ConsoleCallbackHandler, FunctionCallbackHandler, @@ -12,5 +13,6 @@ __all__ = [ "ConsoleCallbackHandler", "FunctionCallbackHandler", "LangChainTracer", + "LangChainTracerV1", "WandbTracer", ] diff --git a/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py b/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py index 7f074454d5c..d4941b6a377 100644 --- a/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -5,7 +5,7 @@ import os from aiohttp import ClientSession from langchain_core.callbacks.manager import atrace_as_chain_group, trace_as_chain_group from langchain_core.prompts import PromptTemplate -from langchain_core.tracers.context import tracing_v2_enabled +from langchain_core.tracers.context import tracing_enabled, tracing_v2_enabled from langchain_community.chat_models import ChatOpenAI from langchain_community.llms import OpenAI @@ -76,6 +76,63 @@ async def test_tracing_concurrent() -> None: await aiosession.close() +async def test_tracing_concurrent_bw_compat_environ() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools + + os.environ["LANGCHAIN_HANDLER"] = "langchain" + if "LANGCHAIN_TRACING" in os.environ: + del os.environ["LANGCHAIN_TRACING"] + aiosession = ClientSession() + llm = OpenAI(temperature=0) + async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession) + agent = initialize_agent( + async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + tasks = [agent.arun(q) for q in questions[:3]] + await asyncio.gather(*tasks) + await aiosession.close() + if "LANGCHAIN_HANDLER" in os.environ: + del os.environ["LANGCHAIN_HANDLER"] + + +def test_tracing_context_manager() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools + + llm = OpenAI(temperature=0) + tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + if "LANGCHAIN_TRACING" in os.environ: + del os.environ["LANGCHAIN_TRACING"] + with tracing_enabled() as session: + assert session + agent.run(questions[0]) # this should be traced + + agent.run(questions[0]) # this should not be traced + + +async def test_tracing_context_manager_async() -> None: + from langchain.agents import AgentType, initialize_agent, load_tools + + llm = OpenAI(temperature=0) + async_tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + if "LANGCHAIN_TRACING" in os.environ: + del os.environ["LANGCHAIN_TRACING"] + + # start a background task + task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced + with tracing_enabled() as session: + assert session + tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced + await asyncio.gather(*tasks) + + await task + + async def test_tracing_v2_environment_variable() -> None: from langchain.agents import AgentType, initialize_agent, load_tools diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 771edb72de2..b18f83c3aca 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -1847,6 +1847,7 @@ def _configure( _configure_hooks, _get_tracer_project, _tracing_v2_is_enabled, + tracing_callback_var, tracing_v2_callback_var, ) @@ -1885,12 +1886,20 @@ def _configure( callback_manager.add_metadata(inheritable_metadata or {}) callback_manager.add_metadata(local_metadata or {}, False) + tracer = tracing_callback_var.get() + tracing_enabled_ = ( + env_var_is_set("LANGCHAIN_TRACING") + or tracer is not None + or env_var_is_set("LANGCHAIN_HANDLER") + ) + tracer_v2 = tracing_v2_callback_var.get() tracing_v2_enabled_ = _tracing_v2_is_enabled() tracer_project = _get_tracer_project() debug = _get_debug() - if verbose or debug or tracing_v2_enabled_: + if verbose or debug or tracing_enabled_ or tracing_v2_enabled_: from langchain_core.tracers.langchain import LangChainTracer + from langchain_core.tracers.langchain_v1 import LangChainTracerV1 from langchain_core.tracers.stdout import ConsoleCallbackHandler if verbose and not any( @@ -1898,7 +1907,6 @@ def _configure( for handler in callback_manager.handlers ): if debug: - # We will use ConsoleCallbackHandler instead of StdOutCallbackHandler pass else: callback_manager.add_handler(StdOutCallbackHandler(), False) @@ -1907,6 +1915,16 @@ def _configure( for handler in callback_manager.handlers ): callback_manager.add_handler(ConsoleCallbackHandler(), True) + if tracing_enabled_ and not any( + isinstance(handler, LangChainTracerV1) + for handler in callback_manager.handlers + ): + if tracer: + callback_manager.add_handler(tracer, True) + else: + handler = LangChainTracerV1() + handler.load_session(tracer_project) + callback_manager.add_handler(handler, True) if tracing_v2_enabled_ and not any( isinstance(handler, LangChainTracer) for handler in callback_manager.handlers diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index fcd588ee354..6405fad3b95 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -19,7 +19,9 @@ from langsmith import utils as ls_utils from langsmith.run_helpers import get_run_tree_context from langchain_core.tracers.langchain import LangChainTracer +from langchain_core.tracers.langchain_v1 import LangChainTracerV1 from langchain_core.tracers.run_collector import RunCollectorCallbackHandler +from langchain_core.tracers.schemas import TracerSessionV1 from langchain_core.utils.env import env_var_is_set if TYPE_CHECKING: @@ -28,6 +30,10 @@ if TYPE_CHECKING: from langchain_core.callbacks.base import BaseCallbackHandler, Callbacks from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager +tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501 + "tracing_callback", default=None +) + tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501 "tracing_callback_v2", default=None ) @@ -36,6 +42,32 @@ run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVa ) +@contextmanager +def tracing_enabled( + session_name: str = "default", +) -> Generator[TracerSessionV1, None, None]: + """Get the Deprecated LangChainTracer in a context manager. + + Args: + session_name (str, optional): The name of the session. + Defaults to "default". + + Returns: + TracerSessionV1: The LangChainTracer session. + + Example: + >>> with tracing_enabled() as session: + ... # Use the LangChainTracer session + """ + cb = LangChainTracerV1() + session = cast(TracerSessionV1, cb.load_session(session_name)) + try: + tracing_callback_var.set(cb) + yield session + finally: + tracing_callback_var.set(None) + + @contextmanager def tracing_v2_enabled( project_name: Optional[str] = None, @@ -136,7 +168,6 @@ def _tracing_v2_is_enabled() -> bool: env_var_is_set("LANGCHAIN_TRACING_V2") or tracing_v2_callback_var.get() is not None or get_run_tree_context() is not None - or env_var_is_set("LANGCHAIN_TRACING") ) diff --git a/libs/core/langchain_core/tracers/langchain_v1.py b/libs/core/langchain_core/tracers/langchain_v1.py new file mode 100644 index 00000000000..f2178b24a61 --- /dev/null +++ b/libs/core/langchain_core/tracers/langchain_v1.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import logging +import os +from typing import Any, Dict, Optional, Union + +import requests + +from langchain_core.messages import get_buffer_string +from langchain_core.tracers.base import BaseTracer +from langchain_core.tracers.schemas import ( + ChainRun, + LLMRun, + Run, + ToolRun, + TracerSession, + TracerSessionV1, + TracerSessionV1Base, +) +from langchain_core.utils import raise_for_status_with_text + +logger = logging.getLogger(__name__) + + +def get_headers() -> Dict[str, Any]: + """Get the headers for the LangChain API.""" + headers: Dict[str, Any] = {"Content-Type": "application/json"} + if os.getenv("LANGCHAIN_API_KEY"): + headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + return headers + + +def _get_endpoint() -> str: + return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") + + +class LangChainTracerV1(BaseTracer): + """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LangChain tracer.""" + super().__init__(**kwargs) + self.session: Optional[TracerSessionV1] = None + self._endpoint = _get_endpoint() + self._headers = get_headers() + + def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]: + session = self.session or self.load_default_session() + if not isinstance(session, TracerSessionV1): + raise ValueError( + "LangChainTracerV1 is not compatible with" + f" session of type {type(session)}" + ) + + if run.run_type == "llm": + if "prompts" in run.inputs: + prompts = run.inputs["prompts"] + elif "messages" in run.inputs: + prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]] + else: + raise ValueError("No prompts found in LLM run inputs") + return LLMRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + extra=run.extra, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + error=run.error, + prompts=prompts, + response=run.outputs if run.outputs else None, + ) + if run.run_type == "chain": + child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] + return ChainRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + inputs=run.inputs, + outputs=run.outputs, + error=run.error, + extra=run.extra, + child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], + child_chain_runs=[ + run for run in child_runs if isinstance(run, ChainRun) + ], + child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], + ) + if run.run_type == "tool": + child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] + return ToolRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + action=str(run.serialized), + tool_input=run.inputs.get("input", ""), + output=None if run.outputs is None else run.outputs.get("output"), + error=run.error, + extra=run.extra, + child_chain_runs=[ + run for run in child_runs if isinstance(run, ChainRun) + ], + child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], + child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], + ) + raise ValueError(f"Unknown run type: {run.run_type}") + + def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + if isinstance(run, Run): + v1_run = self._convert_to_v1_run(run) + else: + v1_run = run + if isinstance(v1_run, LLMRun): + endpoint = f"{self._endpoint}/llm-runs" + elif isinstance(v1_run, ChainRun): + endpoint = f"{self._endpoint}/chain-runs" + else: + endpoint = f"{self._endpoint}/tool-runs" + + try: + response = requests.post( + endpoint, + data=v1_run.json(), + headers=self._headers, + ) + raise_for_status_with_text(response) + except Exception as e: + logger.warning(f"Failed to persist run: {e}") + + def _persist_session( + self, session_create: TracerSessionV1Base + ) -> Union[TracerSessionV1, TracerSession]: + """Persist a session.""" + try: + r = requests.post( + f"{self._endpoint}/sessions", + data=session_create.json(), + headers=self._headers, + ) + session = TracerSessionV1(id=r.json()["id"], **session_create.dict()) + except Exception as e: + logger.warning(f"Failed to create session, using default session: {e}") + session = TracerSessionV1(id=1, **session_create.dict()) + return session + + def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1: + """Load a session from the tracer.""" + try: + url = f"{self._endpoint}/sessions" + if session_name: + url += f"?name={session_name}" + r = requests.get(url, headers=self._headers) + + tracer_session = TracerSessionV1(**r.json()[0]) + except Exception as e: + session_type = "default" if not session_name else session_name + logger.warning( + f"Failed to load {session_type} session, using empty session: {e}" + ) + tracer_session = TracerSessionV1(id=1) + + self.session = tracer_session + return tracer_session + + def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]: + """Load a session with the given name from the tracer.""" + return self._load_session(session_name) + + def load_default_session(self) -> Union[TracerSessionV1, TracerSession]: + """Load the default tracing session and set it as the Tracer's session.""" + return self._load_session("default") diff --git a/libs/core/langchain_core/tracers/schemas.py b/libs/core/langchain_core/tracers/schemas.py index bf9359ccc2f..126351fdfba 100644 --- a/libs/core/langchain_core/tracers/schemas.py +++ b/libs/core/langchain_core/tracers/schemas.py @@ -1,12 +1,102 @@ """Schemas for tracers.""" from __future__ import annotations -from typing import Any, Dict, List, Optional +import datetime +import warnings +from typing import Any, Dict, List, Optional, Type from uuid import UUID from langsmith.schemas import RunBase as BaseRunV2 +from langsmith.schemas import RunTypeEnum as RunTypeEnumDep -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.outputs import LLMResult +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator + + +def RunTypeEnum() -> Type[RunTypeEnumDep]: + """RunTypeEnum.""" + warnings.warn( + "RunTypeEnum is deprecated. Please directly use a string instead" + " (e.g. 'llm', 'chain', 'tool').", + DeprecationWarning, + ) + return RunTypeEnumDep + + +class TracerSessionV1Base(BaseModel): + """Base class for TracerSessionV1.""" + + start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + name: Optional[str] = None + extra: Optional[Dict[str, Any]] = None + + +class TracerSessionV1Create(TracerSessionV1Base): + """Create class for TracerSessionV1.""" + + +class TracerSessionV1(TracerSessionV1Base): + """TracerSessionV1 schema.""" + + id: int + + +class TracerSessionBase(TracerSessionV1Base): + """Base class for TracerSession.""" + + tenant_id: UUID + + +class TracerSession(TracerSessionBase): + """TracerSessionV1 schema for the V2 API.""" + + id: UUID + + +class BaseRun(BaseModel): + """Base class for Run.""" + + uuid: str + parent_uuid: Optional[str] = None + start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + extra: Optional[Dict[str, Any]] = None + execution_order: int + child_execution_order: int + serialized: Dict[str, Any] + session_id: int + error: Optional[str] = None + + +class LLMRun(BaseRun): + """Class for LLMRun.""" + + prompts: List[str] + response: Optional[LLMResult] = None + + +class ChainRun(BaseRun): + """Class for ChainRun.""" + + inputs: Dict[str, Any] + outputs: Optional[Dict[str, Any]] = None + child_llm_runs: List[LLMRun] = Field(default_factory=list) + child_chain_runs: List[ChainRun] = Field(default_factory=list) + child_tool_runs: List[ToolRun] = Field(default_factory=list) + + +class ToolRun(BaseRun): + """Class for ToolRun.""" + + tool_input: str + output: Optional[str] = None + action: str + child_llm_runs: List[LLMRun] = Field(default_factory=list) + child_chain_runs: List[ChainRun] = Field(default_factory=list) + child_tool_runs: List[ToolRun] = Field(default_factory=list) + + +# Begin V2 API Schemas class Run(BaseRunV2): @@ -33,8 +123,20 @@ class Run(BaseRunV2): return values +ChainRun.update_forward_refs() +ToolRun.update_forward_refs() Run.update_forward_refs() __all__ = [ + "BaseRun", + "ChainRun", + "LLMRun", "Run", + "RunTypeEnum", + "ToolRun", + "TracerSession", + "TracerSessionBase", + "TracerSessionV1", + "TracerSessionV1Base", + "TracerSessionV1Create", ] diff --git a/libs/core/tests/unit_tests/callbacks/tracers/__init__.py b/libs/core/tests/unit_tests/callbacks/tracers/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/libs/core/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_base_tracer.py similarity index 98% rename from libs/core/tests/unit_tests/callbacks/tracers/test_base_tracer.py rename to libs/core/tests/unit_tests/tracers/test_base_tracer.py index 121f2e4b3e4..d3657d25873 100644 --- a/libs/core/tests/unit_tests/callbacks/tracers/test_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_base_tracer.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import List +from typing import Any, List from uuid import uuid4 import pytest @@ -31,17 +31,17 @@ class FakeTracer(BaseTracer): self.runs.append(run) -def _compare_run_with_error(run: Run, expected_run: Run) -> None: +def _compare_run_with_error(run: Any, expected_run: Any) -> None: if run.child_runs: assert len(expected_run.child_runs) == len(run.child_runs) for received, expected in zip(run.child_runs, expected_run.child_runs): _compare_run_with_error(received, expected) - received_dict = run.dict(exclude={"child_runs"}) - received_err = received_dict.pop("error") - expected_dict = expected_run.dict(exclude={"child_runs"}) - expected_err = expected_dict.pop("error") + received = run.dict(exclude={"child_runs"}) + received_err = received.pop("error") + expected = expected_run.dict(exclude={"child_runs"}) + expected_err = expected.pop("error") - assert received_dict == expected_dict + assert received == expected if expected_err is not None: assert received_err is not None assert expected_err in received_err @@ -406,7 +406,6 @@ def test_tracer_llm_run_on_error_callback() -> None: tracer = FakeTracerWithLlmErrorCallback() tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_error(exception, run_id=uuid) - assert tracer.error_run is not None _compare_run_with_error(tracer.error_run, compare_run) diff --git a/libs/core/tests/unit_tests/callbacks/tracers/test_langchain.py b/libs/core/tests/unit_tests/tracers/test_langchain.py similarity index 100% rename from libs/core/tests/unit_tests/callbacks/tracers/test_langchain.py rename to libs/core/tests/unit_tests/tracers/test_langchain.py diff --git a/libs/core/tests/unit_tests/tracers/test_langchain_v1.py b/libs/core/tests/unit_tests/tracers/test_langchain_v1.py new file mode 100644 index 00000000000..41edb4d7bdd --- /dev/null +++ b/libs/core/tests/unit_tests/tracers/test_langchain_v1.py @@ -0,0 +1,562 @@ +"""Test Tracer classes.""" +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any, List, Optional, Union +from uuid import uuid4 + +import pytest +from freezegun import freeze_time + +from langchain_core.callbacks import CallbackManager +from langchain_core.messages import HumanMessage +from langchain_core.outputs import LLMResult +from langchain_core.tracers.base import BaseTracer, TracerException +from langchain_core.tracers.langchain_v1 import ( + ChainRun, + LangChainTracerV1, + LLMRun, + ToolRun, + TracerSessionV1, +) +from langchain_core.tracers.schemas import Run, TracerSessionV1Base + +TEST_SESSION_ID = 2023 + +SERIALIZED = {"id": ["llm"]} +SERIALIZED_CHAT = {"id": ["chat_model"]} + + +def load_session(session_name: str) -> TracerSessionV1: + """Load a tracing session.""" + return TracerSessionV1( + id=TEST_SESSION_ID, name=session_name, start_time=datetime.now(timezone.utc) + ) + + +def new_session(name: Optional[str] = None) -> TracerSessionV1: + """Create a new tracing session.""" + return TracerSessionV1( + id=TEST_SESSION_ID, + name=name or "default", + start_time=datetime.now(timezone.utc), + ) + + +def _persist_session(session: TracerSessionV1Base) -> TracerSessionV1: + """Persist a tracing session.""" + return TracerSessionV1(**{**session.dict(), "id": TEST_SESSION_ID}) + + +def load_default_session() -> TracerSessionV1: + """Load a tracing session.""" + return TracerSessionV1( + id=TEST_SESSION_ID, name="default", start_time=datetime.now(timezone.utc) + ) + + +@pytest.fixture +def lang_chain_tracer_v1(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV1: + monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id") + monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com") + monkeypatch.setenv("LANGCHAIN_API_KEY", "foo") + tracer = LangChainTracerV1() + return tracer + + +class FakeTracer(BaseTracer): + """Fake tracer that records LangChain execution.""" + + def __init__(self) -> None: + """Initialize the tracer.""" + super().__init__() + self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] + + def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + if isinstance(run, Run): + with pytest.MonkeyPatch().context() as m: + m.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id") + m.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com") + m.setenv("LANGCHAIN_API_KEY", "foo") + tracer = LangChainTracerV1() + tracer.load_default_session = load_default_session # type: ignore + run = tracer._convert_to_v1_run(run) + self.runs.append(run) + + def _persist_session(self, session: TracerSessionV1Base) -> TracerSessionV1: + """Persist a tracing session.""" + return _persist_session(session) + + def new_session(self, name: Optional[str] = None) -> TracerSessionV1: + """Create a new tracing session.""" + return new_session(name) + + def load_session(self, session_name: str) -> TracerSessionV1: + """Load a tracing session.""" + return load_session(session_name) + + def load_default_session(self) -> TracerSessionV1: + """Load a tracing session.""" + return load_default_session() + + +def _compare_run_with_error(run: Any, expected_run: Any) -> None: + received = run.dict() + received_err = received.pop("error") + expected = expected_run.dict() + expected_err = expected.pop("error") + assert received == expected + assert expected_err in received_err + + +@freeze_time("2023-01-01") +def test_tracer_llm_run() -> None: + """Test tracer on an LLM run.""" + uuid = uuid4() + compare_run = LLMRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=1, + child_execution_order=1, + serialized=SERIALIZED, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_chat_model_run() -> None: + """Test tracer on a Chat Model run.""" + tracer = FakeTracer() + + tracer.new_session() + manager = CallbackManager(handlers=[tracer]) + run_managers = manager.on_chat_model_start( + serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]] + ) + compare_run = LLMRun( + uuid=str(run_managers[0].run_id), + parent_uuid=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=1, + child_execution_order=1, + serialized=SERIALIZED_CHAT, + prompts=["Human: "], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + error=None, + ) + for run_manager in run_managers: + run_manager.on_llm_end(response=LLMResult(generations=[[]])) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_llm_run_errors_no_start() -> None: + """Test tracer on an LLM run without a start.""" + tracer = FakeTracer() + + tracer.new_session() + with pytest.raises(TracerException): + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4()) + + +@freeze_time("2023-01-01") +def test_tracer_multiple_llm_runs() -> None: + """Test the tracer with multiple runs.""" + uuid = uuid4() + compare_run = LLMRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=1, + child_execution_order=1, + serialized=SERIALIZED, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + num_runs = 10 + for _ in range(num_runs): + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) + + assert tracer.runs == [compare_run] * num_runs + + +@freeze_time("2023-01-01") +def test_tracer_chain_run() -> None: + """Test tracer on a Chain run.""" + uuid = uuid4() + compare_run = ChainRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "chain"}, + inputs={}, + outputs={}, + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid) + tracer.on_chain_end(outputs={}, run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_tool_run() -> None: + """Test tracer on a Tool run.""" + uuid = uuid4() + compare_run = ToolRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "tool"}, + tool_input="test", + output="test", + action="{'name': 'tool'}", + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid) + tracer.on_tool_end("test", run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_nested_run() -> None: + """Test tracer on a nested run.""" + tracer = FakeTracer() + tracer.new_session() + + chain_uuid = uuid4() + tool_uuid = uuid4() + llm_uuid1 = uuid4() + llm_uuid2 = uuid4() + for _ in range(10): + tracer.on_chain_start( + serialized={"name": "chain"}, inputs={}, run_id=chain_uuid + ) + tracer.on_tool_start( + serialized={"name": "tool"}, + input_str="test", + run_id=tool_uuid, + parent_run_id=chain_uuid, + ) + tracer.on_llm_start( + serialized=SERIALIZED, + prompts=[], + run_id=llm_uuid1, + parent_run_id=tool_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_tool_end("test", run_id=tool_uuid) + tracer.on_llm_start( + serialized=SERIALIZED, + prompts=[], + run_id=llm_uuid2, + parent_run_id=chain_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_chain_end(outputs={}, run_id=chain_uuid) + + compare_run = ChainRun( + uuid=str(chain_uuid), + error=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=1, + child_execution_order=4, + serialized={"name": "chain"}, + inputs={}, + outputs={}, + session_id=TEST_SESSION_ID, + child_chain_runs=[], + child_tool_runs=[ + ToolRun( + uuid=str(tool_uuid), + parent_uuid=str(chain_uuid), + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=2, + child_execution_order=3, + serialized={"name": "tool"}, + tool_input="test", + output="test", + action="{'name': 'tool'}", + session_id=TEST_SESSION_ID, + error=None, + child_chain_runs=[], + child_tool_runs=[], + child_llm_runs=[ + LLMRun( + uuid=str(llm_uuid1), + parent_uuid=str(tool_uuid), + error=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=3, + child_execution_order=3, + serialized=SERIALIZED, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + ) + ], + ), + ], + child_llm_runs=[ + LLMRun( + uuid=str(llm_uuid2), + parent_uuid=str(chain_uuid), + error=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=4, + child_execution_order=4, + serialized=SERIALIZED, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + ), + ], + ) + assert tracer.runs[0] == compare_run + assert tracer.runs == [compare_run] * 10 + + +@freeze_time("2023-01-01") +def test_tracer_llm_run_on_error() -> None: + """Test tracer on an LLM run with an error.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = LLMRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=1, + child_execution_order=1, + serialized=SERIALIZED, + prompts=[], + response=None, + session_id=TEST_SESSION_ID, + error=repr(exception), + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) + tracer.on_llm_error(exception, run_id=uuid) + _compare_run_with_error(tracer.runs[0], compare_run) + + +@freeze_time("2023-01-01") +def test_tracer_chain_run_on_error() -> None: + """Test tracer on a Chain run with an error.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = ChainRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "chain"}, + inputs={}, + outputs=None, + session_id=TEST_SESSION_ID, + error=repr(exception), + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid) + tracer.on_chain_error(exception, run_id=uuid) + _compare_run_with_error(tracer.runs[0], compare_run) + + +@freeze_time("2023-01-01") +def test_tracer_tool_run_on_error() -> None: + """Test tracer on a Tool run with an error.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = ToolRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "tool"}, + tool_input="test", + output=None, + action="{'name': 'tool'}", + session_id=TEST_SESSION_ID, + error=repr(exception), + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid) + tracer.on_tool_error(exception, run_id=uuid) + _compare_run_with_error(tracer.runs[0], compare_run) + + +@pytest.fixture +def sample_tracer_session_v1() -> TracerSessionV1: + return TracerSessionV1(id=2, name="Sample session") + + +@freeze_time("2023-01-01") +def test_convert_run( + lang_chain_tracer_v1: LangChainTracerV1, + sample_tracer_session_v1: TracerSessionV1, +) -> None: + """Test converting a run to a V1 run.""" + llm_run = Run( + id="57a08cc4-73d2-4236-8370-549099d07fad", + name="llm_run", + execution_order=1, + child_execution_order=1, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + session_id=TEST_SESSION_ID, + inputs={"prompts": []}, + outputs=LLMResult(generations=[[]]).dict(), + serialized={}, + extra={}, + run_type="llm", + ) + chain_run = Run( + id="57a08cc4-73d2-4236-8371-549099d07fad", + name="chain_run", + execution_order=1, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + child_execution_order=1, + serialized={}, + inputs={}, + outputs={}, + child_runs=[llm_run], + extra={}, + run_type="chain", + ) + + tool_run = Run( + id="57a08cc4-73d2-4236-8372-549099d07fad", + name="tool_run", + execution_order=1, + child_execution_order=1, + inputs={"input": "test"}, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + outputs=None, + serialized={}, + child_runs=[], + extra={}, + run_type="tool", + ) + + expected_llm_run = LLMRun( + uuid="57a08cc4-73d2-4236-8370-549099d07fad", + name="llm_run", + execution_order=1, + child_execution_order=1, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + session_id=2, + prompts=[], + response=LLMResult(generations=[[]]), + serialized={}, + extra={}, + ) + + expected_chain_run = ChainRun( + uuid="57a08cc4-73d2-4236-8371-549099d07fad", + name="chain_run", + execution_order=1, + child_execution_order=1, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + session_id=2, + serialized={}, + inputs={}, + outputs={}, + child_llm_runs=[expected_llm_run], + child_chain_runs=[], + child_tool_runs=[], + extra={}, + ) + expected_tool_run = ToolRun( + uuid="57a08cc4-73d2-4236-8372-549099d07fad", + name="tool_run", + execution_order=1, + child_execution_order=1, + session_id=2, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + tool_input="test", + action="{}", + serialized={}, + child_llm_runs=[], + child_chain_runs=[], + child_tool_runs=[], + extra={}, + ) + lang_chain_tracer_v1.session = sample_tracer_session_v1 + converted_llm_run = lang_chain_tracer_v1._convert_to_v1_run(llm_run) + converted_chain_run = lang_chain_tracer_v1._convert_to_v1_run(chain_run) + converted_tool_run = lang_chain_tracer_v1._convert_to_v1_run(tool_run) + + assert isinstance(converted_llm_run, LLMRun) + assert isinstance(converted_chain_run, ChainRun) + assert isinstance(converted_tool_run, ToolRun) + assert converted_llm_run == expected_llm_run + assert converted_tool_run == expected_tool_run + assert converted_chain_run == expected_chain_run diff --git a/libs/core/tests/unit_tests/callbacks/tracers/test_schemas.py b/libs/core/tests/unit_tests/tracers/test_schemas.py similarity index 67% rename from libs/core/tests/unit_tests/callbacks/tracers/test_schemas.py rename to libs/core/tests/unit_tests/tracers/test_schemas.py index 90ab8307ae8..a452b158752 100644 --- a/libs/core/tests/unit_tests/callbacks/tracers/test_schemas.py +++ b/libs/core/tests/unit_tests/tracers/test_schemas.py @@ -5,7 +5,17 @@ from langchain_core.tracers.schemas import __all__ as schemas_all def test_public_api() -> None: """Test for changes in the public API.""" expected_all = [ + "BaseRun", + "ChainRun", + "LLMRun", "Run", + "RunTypeEnum", + "ToolRun", + "TracerSession", + "TracerSessionBase", + "TracerSessionV1", + "TracerSessionV1Base", + "TracerSessionV1Create", ] assert sorted(schemas_all) == expected_all diff --git a/libs/langchain/langchain/callbacks/__init__.py b/libs/langchain/langchain/callbacks/__init__.py index fe34d21f562..66adf85a69e 100644 --- a/libs/langchain/langchain/callbacks/__init__.py +++ b/libs/langchain/langchain/callbacks/__init__.py @@ -44,6 +44,7 @@ from langchain_core.callbacks import ( ) from langchain_core.tracers.context import ( collect_runs, + tracing_enabled, tracing_v2_enabled, ) from langchain_core.tracers.langchain import LangChainTracer @@ -79,6 +80,7 @@ __all__ = [ "WandbCallbackHandler", "WhyLabsCallbackHandler", "get_openai_callback", + "tracing_enabled", "tracing_v2_enabled", "collect_runs", "wandb_tracing_enabled", diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index f9477684ce7..8343c29977e 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -30,6 +30,7 @@ from langchain_core.callbacks.manager import ( ) from langchain_core.tracers.context import ( collect_runs, + tracing_enabled, tracing_v2_enabled, ) from langchain_core.utils.env import env_var_is_set @@ -52,6 +53,7 @@ __all__ = [ "CallbackManagerForChainGroup", "AsyncCallbackManager", "AsyncCallbackManagerForChainGroup", + "tracing_enabled", "tracing_v2_enabled", "collect_runs", "atrace_as_chain_group", diff --git a/libs/langchain/langchain/callbacks/tracers/__init__.py b/libs/langchain/langchain/callbacks/tracers/__init__.py index 5cedf7b3bfd..1f470e6bb84 100644 --- a/libs/langchain/langchain/callbacks/tracers/__init__.py +++ b/libs/langchain/langchain/callbacks/tracers/__init__.py @@ -1,6 +1,7 @@ """Tracers that record execution of LangChain runs.""" from langchain_core.tracers.langchain import LangChainTracer +from langchain_core.tracers.langchain_v1 import LangChainTracerV1 from langchain_core.tracers.stdout import ( ConsoleCallbackHandler, FunctionCallbackHandler, @@ -14,5 +15,6 @@ __all__ = [ "FunctionCallbackHandler", "LoggingCallbackHandler", "LangChainTracer", + "LangChainTracerV1", "WandbTracer", ] diff --git a/libs/langchain/langchain/callbacks/tracers/langchain_v1.py b/libs/langchain/langchain/callbacks/tracers/langchain_v1.py new file mode 100644 index 00000000000..a12b47401f7 --- /dev/null +++ b/libs/langchain/langchain/callbacks/tracers/langchain_v1.py @@ -0,0 +1,3 @@ +from langchain_core.tracers.langchain_v1 import LangChainTracerV1 + +__all__ = ["LangChainTracerV1"] diff --git a/libs/langchain/langchain/callbacks/tracers/schemas.py b/libs/langchain/langchain/callbacks/tracers/schemas.py index 725755d0ee8..e8f34027d34 100644 --- a/libs/langchain/langchain/callbacks/tracers/schemas.py +++ b/libs/langchain/langchain/callbacks/tracers/schemas.py @@ -1,7 +1,27 @@ from langchain_core.tracers.schemas import ( + BaseRun, + ChainRun, + LLMRun, Run, + RunTypeEnum, + ToolRun, + TracerSession, + TracerSessionBase, + TracerSessionV1, + TracerSessionV1Base, + TracerSessionV1Create, ) __all__ = [ + "BaseRun", + "ChainRun", + "LLMRun", "Run", + "RunTypeEnum", + "ToolRun", + "TracerSession", + "TracerSessionBase", + "TracerSessionV1", + "TracerSessionV1Base", + "TracerSessionV1Create", ] diff --git a/libs/langchain/langchain/schema/callbacks/manager.py b/libs/langchain/langchain/schema/callbacks/manager.py index 5754feeb546..a459e1bb697 100644 --- a/libs/langchain/langchain/schema/callbacks/manager.py +++ b/libs/langchain/langchain/schema/callbacks/manager.py @@ -22,11 +22,13 @@ from langchain_core.callbacks.manager import ( from langchain_core.tracers.context import ( collect_runs, register_configure_hook, + tracing_enabled, tracing_v2_enabled, ) from langchain_core.utils.env import env_var_is_set __all__ = [ + "tracing_enabled", "tracing_v2_enabled", "collect_runs", "trace_as_chain_group", diff --git a/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py b/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py new file mode 100644 index 00000000000..fca2d7590f1 --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py @@ -0,0 +1,3 @@ +from langchain_core.tracers.langchain_v1 import LangChainTracerV1, get_headers + +__all__ = ["get_headers", "LangChainTracerV1"] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/schemas.py b/libs/langchain/langchain/schema/callbacks/tracers/schemas.py index 725755d0ee8..6fb49dbf724 100644 --- a/libs/langchain/langchain/schema/callbacks/tracers/schemas.py +++ b/libs/langchain/langchain/schema/callbacks/tracers/schemas.py @@ -1,7 +1,27 @@ from langchain_core.tracers.schemas import ( + BaseRun, + ChainRun, + LLMRun, Run, + RunTypeEnum, + ToolRun, + TracerSession, + TracerSessionBase, + TracerSessionV1, + TracerSessionV1Base, + TracerSessionV1Create, ) __all__ = [ + "RunTypeEnum", + "TracerSessionV1Base", + "TracerSessionV1Create", + "TracerSessionV1", + "TracerSessionBase", + "TracerSession", + "BaseRun", + "LLMRun", + "ChainRun", + "ToolRun", "Run", ] diff --git a/libs/langchain/tests/unit_tests/callbacks/test_imports.py b/libs/langchain/tests/unit_tests/callbacks/test_imports.py index d6f1784e4dd..b7adb053e7f 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_imports.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_imports.py @@ -25,6 +25,7 @@ EXPECTED_ALL = [ "WandbCallbackHandler", "WhyLabsCallbackHandler", "get_openai_callback", + "tracing_enabled", "tracing_v2_enabled", "collect_runs", "wandb_tracing_enabled", diff --git a/libs/langchain/tests/unit_tests/callbacks/test_manager.py b/libs/langchain/tests/unit_tests/callbacks/test_manager.py index e515eb61a10..8ee369e81fd 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_manager.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_manager.py @@ -18,6 +18,7 @@ EXPECTED_ALL = [ "CallbackManagerForChainGroup", "AsyncCallbackManager", "AsyncCallbackManagerForChainGroup", + "tracing_enabled", "tracing_v2_enabled", "collect_runs", "atrace_as_chain_group",