mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 21:47:12 +00:00
[Breaking] Refactor Base Tracer (#4549)
### Refactor the BaseTracer - Remove the 'session' abstraction from the BaseTracer - Rename 'RunV2' object(s) to be called 'Run' objects (Rename previous Run objects to be RunV1 objects) - Ditto for sessions: TracerSession*V2 -> TracerSession* - Remove now deprecated conversion from v1 run objects to v2 run objects in LangChainTracerV2 - Add conversion from v2 run objects to v1 run objects in V1 tracer
This commit is contained in:
464
tests/unit_tests/callbacks/tracers/test_base_tracer.py
Normal file
464
tests/unit_tests/callbacks/tracers/test_base_tracer.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""Test Tracer classes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.callbacks.tracers.base import BaseTracer, TracerException
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution in memory.

    Used by the tests below to capture the ``Run`` tree that ``BaseTracer``
    assembles, so assertions can compare it against hand-built expectations.
    """

    def __init__(self) -> None:
        """Initialize the tracer with an empty run log."""
        super().__init__()
        # Completed top-level runs, appended in the order they finish.
        self.runs: List[Run] = []

    def _persist_run(self, run: Run) -> None:
        """Persist a run by recording it in memory instead of sending it anywhere."""
        self.runs.append(run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """Test tracer on an LLM run.

    ``freeze_time`` pins ``datetime.utcnow()`` so the expected run's
    start/end times match what the tracer records.
    """
    uuid = uuid4()
    compare_run = Run(
        id=uuid,
        parent_run_id=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]),
        error=None,
        run_type="llm",
    )
    tracer = FakeTracer()

    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
    """Test tracer on a Chat Model run.

    Chat model runs are routed through ``CallbackManager`` because the
    tracer itself only exposes LLM/chain/tool hooks; they are recorded
    with ``run_type="llm"``.
    """
    uuid = uuid4()
    compare_run = Run(
        id=str(uuid),
        name="chat_model",
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chat_model"},
        inputs=dict(prompts=[""]),
        outputs=LLMResult(generations=[[]]),
        error=None,
        run_type="llm",
    )
    tracer = FakeTracer()
    manager = CallbackManager(handlers=[tracer])
    run_manager = manager.on_chat_model_start(
        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
    )
    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Test tracer on an LLM run without a start.

    Ending a run that was never started must raise ``TracerException``.
    """
    tracer = FakeTracer()

    with pytest.raises(TracerException):
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
    """Test the tracer with multiple runs.

    Reusing the same run id across sequential start/end pairs should
    persist one identical run per iteration.
    """
    uuid = uuid4()
    compare_run = Run(
        id=uuid,
        name="llm",
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        inputs=dict(prompts=[]),
        outputs=LLMResult(generations=[[]]),
        error=None,
        run_type="llm",
    )
    tracer = FakeTracer()

    num_runs = 10
    for _ in range(num_runs):
        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)

    assert tracer.runs == [compare_run] * num_runs
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
    """Test tracer on a Chain run."""
    uuid = uuid4()
    compare_run = Run(
        id=str(uuid),
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        error=None,
        run_type="chain",
    )
    tracer = FakeTracer()

    tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
    tracer.on_chain_end(outputs={}, run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
    """Test tracer on a Tool run."""
    uuid = uuid4()
    compare_run = Run(
        id=str(uuid),
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        inputs={"input": "test"},
        outputs={"output": "test"},
        error=None,
        run_type="tool",
    )
    tracer = FakeTracer()
    tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
    tracer.on_tool_end("test", run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
    """Test tracer on a nested run.

    Builds chain -> (tool -> llm, llm) ten times and checks that the
    persisted top-level run carries the full child tree with correct
    execution_order / child_execution_order bookkeeping.
    """
    tracer = FakeTracer()

    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    for _ in range(10):
        tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
        )
        tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=tool_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_tool_end("test", run_id=tool_uuid)
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        tracer.on_chain_end(outputs={}, run_id=chain_uuid)

    compare_run = Run(
        id=str(chain_uuid),
        error=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=4,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        run_type="chain",
        child_runs=[
            Run(
                id=tool_uuid,
                parent_run_id=chain_uuid,
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=2,
                child_execution_order=3,
                serialized={"name": "tool"},
                inputs=dict(input="test"),
                outputs=dict(output="test"),
                error=None,
                run_type="tool",
                child_runs=[
                    Run(
                        id=str(llm_uuid1),
                        parent_run_id=str(tool_uuid),
                        error=None,
                        start_time=datetime.utcnow(),
                        end_time=datetime.utcnow(),
                        extra={},
                        execution_order=3,
                        child_execution_order=3,
                        serialized={"name": "llm"},
                        inputs=dict(prompts=[]),
                        outputs=LLMResult(generations=[[]]),
                        run_type="llm",
                    )
                ],
            ),
            Run(
                id=str(llm_uuid2),
                parent_run_id=str(chain_uuid),
                error=None,
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=4,
                child_execution_order=4,
                serialized={"name": "llm"},
                inputs=dict(prompts=[]),
                outputs=LLMResult(generations=[[]]),
                run_type="llm",
            ),
        ],
    )
    assert tracer.runs[0] == compare_run
    assert tracer.runs == [compare_run] * 10
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
    """Test tracer on an LLM run with an error.

    A failed run is persisted with ``outputs=None`` and the exception's
    ``repr`` in ``error``.
    """
    exception = Exception("test")
    uuid = uuid4()

    compare_run = Run(
        id=str(uuid),
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        inputs=dict(prompts=[]),
        outputs=None,
        error=repr(exception),
        run_type="llm",
    )
    tracer = FakeTracer()

    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
    tracer.on_llm_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
    """Test tracer on a Chain run with an error."""
    exception = Exception("test")
    uuid = uuid4()

    compare_run = Run(
        id=str(uuid),
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs=None,
        error=repr(exception),
        run_type="chain",
    )
    tracer = FakeTracer()

    tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
    tracer.on_chain_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
    """Test tracer on a Tool run with an error."""
    exception = Exception("test")
    uuid = uuid4()

    compare_run = Run(
        id=str(uuid),
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        inputs=dict(input="test"),
        outputs=None,
        # ``action`` mirrors the serialized tool; extra field tolerated by Run.
        action="{'name': 'tool'}",
        error=repr(exception),
        run_type="tool",
    )
    tracer = FakeTracer()

    tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
    tracer.on_tool_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
    """Test tracer on a nested run with an error.

    Two sibling LLM runs succeed, then a tool child's LLM run fails and
    the error propagates (as ``repr``) through the tool and chain runs.
    """
    exception = Exception("test")

    tracer = FakeTracer()
    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    llm_uuid3 = uuid4()

    for _ in range(3):
        tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
        )
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid3,
            parent_run_id=tool_uuid,
        )
        tracer.on_llm_error(exception, run_id=llm_uuid3)
        tracer.on_tool_error(exception, run_id=tool_uuid)
        tracer.on_chain_error(exception, run_id=chain_uuid)

    compare_run = Run(
        id=str(chain_uuid),
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=5,
        serialized={"name": "chain"},
        error=repr(exception),
        inputs={},
        outputs=None,
        run_type="chain",
        child_runs=[
            Run(
                id=str(llm_uuid1),
                parent_run_id=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=2,
                child_execution_order=2,
                serialized={"name": "llm"},
                error=None,
                inputs=dict(prompts=[]),
                outputs=LLMResult(generations=[[]], llm_output=None),
                run_type="llm",
            ),
            Run(
                id=str(llm_uuid2),
                parent_run_id=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=3,
                child_execution_order=3,
                serialized={"name": "llm"},
                error=None,
                inputs=dict(prompts=[]),
                outputs=LLMResult(generations=[[]], llm_output=None),
                run_type="llm",
            ),
            Run(
                id=str(tool_uuid),
                parent_run_id=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=4,
                child_execution_order=5,
                serialized={"name": "tool"},
                error=repr(exception),
                inputs=dict(input="test"),
                outputs=None,
                action="{'name': 'tool'}",
                child_runs=[
                    Run(
                        id=str(llm_uuid3),
                        parent_run_id=str(tool_uuid),
                        start_time=datetime.utcnow(),
                        end_time=datetime.utcnow(),
                        extra={},
                        execution_order=5,
                        child_execution_order=5,
                        serialized={"name": "llm"},
                        error=repr(exception),
                        inputs=dict(prompts=[]),
                        outputs=None,
                        run_type="llm",
                    )
                ],
                run_type="tool",
            ),
        ],
    )
    assert tracer.runs == [compare_run] * 3
|
676
tests/unit_tests/callbacks/tracers/test_langchain_v1.py
Normal file
676
tests/unit_tests/callbacks/tracers/test_langchain_v1.py
Normal file
@@ -0,0 +1,676 @@
|
||||
"""Test Tracer classes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.callbacks.tracers.base import BaseTracer, TracerException
|
||||
from langchain.callbacks.tracers.langchain_v1 import (
|
||||
ChainRun,
|
||||
LangChainTracerV1,
|
||||
LLMRun,
|
||||
ToolRun,
|
||||
TracerSessionV1,
|
||||
)
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
# Fixed session id shared by every fake/stubbed session in this module.
TEST_SESSION_ID = 2023


def load_session(session_name: str) -> TracerSessionV1:
    """Load a tracing session (stub: always returns the test session)."""
    return TracerSessionV1(
        id=TEST_SESSION_ID, name=session_name, start_time=datetime.utcnow()
    )


def new_session(name: Optional[str] = None) -> TracerSessionV1:
    """Create a new tracing session (stub: fixed id, defaulted name)."""
    return TracerSessionV1(
        id=TEST_SESSION_ID, name=name or "default", start_time=datetime.utcnow()
    )


def _persist_session(session: TracerSessionV1Base) -> TracerSessionV1:
    """Persist a tracing session by stamping it with the test session id."""
    return TracerSessionV1(**{**session.dict(), "id": TEST_SESSION_ID})


def load_default_session() -> TracerSessionV1:
    """Load the default tracing session (stub)."""
    return TracerSessionV1(
        id=TEST_SESSION_ID, name="default", start_time=datetime.utcnow()
    )
|
||||
|
||||
|
||||
@pytest.fixture
def lang_chain_tracer_v1(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV1:
    """Build a LangChainTracerV1 with dummy env credentials set."""
    monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
    monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
    monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
    tracer = LangChainTracerV1()
    return tracer
|
||||
|
||||
|
||||
class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution as V1 run objects.

    V2 ``Run`` objects are converted to V1 runs via ``LangChainTracerV1``
    before being stored, so assertions can compare against
    LLMRun/ChainRun/ToolRun expectations.
    """

    def __init__(self) -> None:
        """Initialize the tracer with an empty run log."""
        super().__init__()
        self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run, converting v2 ``Run`` objects to v1 first."""
        if isinstance(run, Run):
            # LangChainTracerV1 needs env credentials at construction time;
            # patch them only for the duration of the conversion.
            with pytest.MonkeyPatch().context() as m:
                m.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
                m.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
                m.setenv("LANGCHAIN_API_KEY", "foo")
                tracer = LangChainTracerV1()
                # Avoid hitting a real endpoint for session lookup.
                tracer.load_default_session = load_default_session  # type: ignore
                run = tracer._convert_to_v1_run(run)
        self.runs.append(run)

    def _persist_session(self, session: TracerSessionV1Base) -> TracerSessionV1:
        """Persist a tracing session."""
        return _persist_session(session)

    def new_session(self, name: Optional[str] = None) -> TracerSessionV1:
        """Create a new tracing session."""
        return new_session(name)

    def load_session(self, session_name: str) -> TracerSessionV1:
        """Load a tracing session."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSessionV1:
        """Load the default tracing session."""
        return load_default_session()
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """Test tracer on an LLM run (V1 representation)."""
    uuid = uuid4()
    compare_run = LLMRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
    """Test tracer on a Chat Model run (V1 representation).

    Chat model messages get rendered into string prompts for the V1 LLMRun.
    """
    uuid = uuid4()
    compare_run = LLMRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chat_model"},
        prompts=[""],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()

    tracer.new_session()
    manager = CallbackManager(handlers=[tracer])
    run_manager = manager.on_chat_model_start(
        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
    )
    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Test tracer on an LLM run without a start (must raise)."""
    tracer = FakeTracer()

    tracer.new_session()
    with pytest.raises(TracerException):
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
    """Test the tracer with multiple runs (V1 representation)."""
    uuid = uuid4()
    compare_run = LLMRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()

    tracer.new_session()
    num_runs = 10
    for _ in range(num_runs):
        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)

    assert tracer.runs == [compare_run] * num_runs
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
    """Test tracer on a Chain run (V1 representation)."""
    uuid = uuid4()
    compare_run = ChainRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
    tracer.on_chain_end(outputs={}, run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
    """Test tracer on a Tool run (V1 representation)."""
    uuid = uuid4()
    compare_run = ToolRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        tool_input="test",
        output="test",
        # V1 ToolRun stores the stringified serialized tool as ``action``.
        action="{'name': 'tool'}",
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
    tracer.on_tool_end("test", run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
    """Test tracer on a nested run (V1 representation).

    Same shape as the base-tracer nested test, but the expectation is a
    V1 ChainRun with typed child lists (child_llm_runs/child_chain_runs/
    child_tool_runs) instead of a single ``child_runs`` list.
    """
    tracer = FakeTracer()
    tracer.new_session()

    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    for _ in range(10):
        tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
        )
        tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=tool_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_tool_end("test", run_id=tool_uuid)
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        tracer.on_chain_end(outputs={}, run_id=chain_uuid)

    compare_run = ChainRun(
        uuid=str(chain_uuid),
        error=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=4,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        session_id=TEST_SESSION_ID,
        child_chain_runs=[],
        child_tool_runs=[
            ToolRun(
                uuid=str(tool_uuid),
                parent_uuid=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=2,
                child_execution_order=3,
                serialized={"name": "tool"},
                tool_input="test",
                output="test",
                action="{'name': 'tool'}",
                session_id=TEST_SESSION_ID,
                error=None,
                child_chain_runs=[],
                child_tool_runs=[],
                child_llm_runs=[
                    LLMRun(
                        uuid=str(llm_uuid1),
                        parent_uuid=str(tool_uuid),
                        error=None,
                        start_time=datetime.utcnow(),
                        end_time=datetime.utcnow(),
                        extra={},
                        execution_order=3,
                        child_execution_order=3,
                        serialized={"name": "llm"},
                        prompts=[],
                        response=LLMResult(generations=[[]]),
                        session_id=TEST_SESSION_ID,
                    )
                ],
            ),
        ],
        child_llm_runs=[
            LLMRun(
                uuid=str(llm_uuid2),
                parent_uuid=str(chain_uuid),
                error=None,
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=4,
                child_execution_order=4,
                serialized={"name": "llm"},
                prompts=[],
                response=LLMResult(generations=[[]]),
                session_id=TEST_SESSION_ID,
            ),
        ],
    )
    assert tracer.runs[0] == compare_run
    assert tracer.runs == [compare_run] * 10
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
    """Test tracer on an LLM run with an error (V1 representation)."""
    exception = Exception("test")
    uuid = uuid4()

    compare_run = LLMRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        prompts=[],
        response=None,
        session_id=TEST_SESSION_ID,
        error=repr(exception),
    )
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
    tracer.on_llm_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
    """Test tracer on a Chain run with an error (V1 representation)."""
    exception = Exception("test")
    uuid = uuid4()

    compare_run = ChainRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs=None,
        session_id=TEST_SESSION_ID,
        error=repr(exception),
    )
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
    tracer.on_chain_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
    """Test tracer on a Tool run with an error (V1 representation)."""
    exception = Exception("test")
    uuid = uuid4()

    compare_run = ToolRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        tool_input="test",
        output=None,
        action="{'name': 'tool'}",
        session_id=TEST_SESSION_ID,
        error=repr(exception),
    )
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
    tracer.on_tool_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
    """Test tracer on a nested run with an error (V1 representation).

    Two sibling LLM runs succeed; a tool child's LLM run then fails and
    the error (as ``repr``) propagates through the tool and chain runs.
    """
    exception = Exception("test")

    tracer = FakeTracer()
    tracer.new_session()
    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    llm_uuid3 = uuid4()

    for _ in range(3):
        tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
        )
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid3,
            parent_run_id=tool_uuid,
        )
        tracer.on_llm_error(exception, run_id=llm_uuid3)
        tracer.on_tool_error(exception, run_id=tool_uuid)
        tracer.on_chain_error(exception, run_id=chain_uuid)

    compare_run = ChainRun(
        uuid=str(chain_uuid),
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=5,
        serialized={"name": "chain"},
        session_id=TEST_SESSION_ID,
        error=repr(exception),
        inputs={},
        outputs=None,
        child_llm_runs=[
            LLMRun(
                uuid=str(llm_uuid1),
                parent_uuid=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=2,
                child_execution_order=2,
                serialized={"name": "llm"},
                session_id=TEST_SESSION_ID,
                error=None,
                prompts=[],
                response=LLMResult(generations=[[]], llm_output=None),
            ),
            LLMRun(
                uuid=str(llm_uuid2),
                parent_uuid=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=3,
                child_execution_order=3,
                serialized={"name": "llm"},
                session_id=TEST_SESSION_ID,
                error=None,
                prompts=[],
                response=LLMResult(generations=[[]], llm_output=None),
            ),
        ],
        child_chain_runs=[],
        child_tool_runs=[
            ToolRun(
                uuid=str(tool_uuid),
                parent_uuid=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=4,
                child_execution_order=5,
                serialized={"name": "tool"},
                session_id=TEST_SESSION_ID,
                error=repr(exception),
                tool_input="test",
                output=None,
                action="{'name': 'tool'}",
                child_llm_runs=[
                    LLMRun(
                        uuid=str(llm_uuid3),
                        parent_uuid=str(tool_uuid),
                        start_time=datetime.utcnow(),
                        end_time=datetime.utcnow(),
                        extra={},
                        execution_order=5,
                        child_execution_order=5,
                        serialized={"name": "llm"},
                        session_id=TEST_SESSION_ID,
                        error=repr(exception),
                        prompts=[],
                        response=None,
                    )
                ],
                child_chain_runs=[],
                child_tool_runs=[],
            ),
        ],
    )
    assert tracer.runs == [compare_run] * 3
|
||||
|
||||
|
||||
@pytest.fixture
def sample_tracer_session_v1() -> TracerSessionV1:
    """Provide a minimal V1 tracer session for conversion tests."""
    return TracerSessionV1(id=2, name="Sample session")
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_convert_run(
|
||||
lang_chain_tracer_v1: LangChainTracerV1,
|
||||
sample_tracer_session_v1: TracerSessionV1,
|
||||
) -> None:
|
||||
"""Test converting a run to a V1 run."""
|
||||
llm_run = Run(
|
||||
id="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||
name="llm_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
session_id=TEST_SESSION_ID,
|
||||
inputs={"prompts": []},
|
||||
outputs=LLMResult(generations=[[]]).dict(),
|
||||
serialized={},
|
||||
extra={},
|
||||
run_type=RunTypeEnum.llm,
|
||||
)
|
||||
chain_run = Run(
|
||||
id="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||
name="chain_run",
|
||||
execution_order=1,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
child_runs=[llm_run],
|
||||
extra={},
|
||||
run_type=RunTypeEnum.chain,
|
||||
)
|
||||
|
||||
tool_run = Run(
|
||||
id="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||
name="tool_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
inputs={"input": "test"},
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
outputs=None,
|
||||
serialized={},
|
||||
child_runs=[],
|
||||
extra={},
|
||||
run_type=RunTypeEnum.tool,
|
||||
)
|
||||
|
||||
expected_llm_run = LLMRun(
|
||||
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||
name="llm_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
session_id=2,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
serialized={},
|
||||
extra={},
|
||||
)
|
||||
|
||||
expected_chain_run = ChainRun(
|
||||
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||
name="chain_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
session_id=2,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
child_llm_runs=[expected_llm_run],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
extra={},
|
||||
)
|
||||
expected_tool_run = ToolRun(
|
||||
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||
name="tool_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
session_id=2,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
tool_input="test",
|
||||
action="{}",
|
||||
serialized={},
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
extra={},
|
||||
)
|
||||
lang_chain_tracer_v1.session = sample_tracer_session_v1
|
||||
converted_llm_run = lang_chain_tracer_v1._convert_to_v1_run(llm_run)
|
||||
converted_chain_run = lang_chain_tracer_v1._convert_to_v1_run(chain_run)
|
||||
converted_tool_run = lang_chain_tracer_v1._convert_to_v1_run(tool_run)
|
||||
|
||||
assert isinstance(converted_llm_run, LLMRun)
|
||||
assert isinstance(converted_chain_run, ChainRun)
|
||||
assert isinstance(converted_tool_run, ToolRun)
|
||||
assert converted_llm_run == expected_llm_run
|
||||
assert converted_tool_run == expected_tool_run
|
||||
assert converted_chain_run == expected_chain_run
|
@@ -3,621 +3,90 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Union
|
||||
from unittest.mock import Mock, patch
|
||||
from typing import Tuple
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.callbacks.tracers.base import (
|
||||
BaseTracer,
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
ToolRun,
|
||||
TracerException,
|
||||
TracerSession,
|
||||
)
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
||||
from langchain.callbacks.tracers.schemas import (
|
||||
RunCreate,
|
||||
TracerSessionBase,
|
||||
TracerSessionV2,
|
||||
TracerSessionV2Create,
|
||||
)
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
|
||||
|
||||
def load_session(session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
|
||||
|
||||
|
||||
def _persist_session(session: TracerSessionBase) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return TracerSession(id=TEST_SESSION_ID, **session.dict())
|
||||
|
||||
|
||||
def load_default_session() -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return TracerSession(id=1, name="default", start_time=datetime.utcnow())
|
||||
|
||||
|
||||
class FakeTracer(BaseTracer):
|
||||
"""Fake tracer that records LangChain execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the tracer."""
|
||||
super().__init__()
|
||||
self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
self.runs.append(run)
|
||||
|
||||
def _persist_session(self, session: TracerSessionBase) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_default_session()
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run() -> None:
|
||||
"""Test tracer on an LLM run."""
|
||||
uuid = uuid4()
|
||||
compare_run = LLMRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chat_model_run() -> None:
|
||||
"""Test tracer on a Chat Model run."""
|
||||
uuid = uuid4()
|
||||
compare_run = LLMRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
prompts=[""],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
manager = CallbackManager(handlers=[tracer])
|
||||
run_manager = manager.on_chat_model_start(serialized={}, messages=[[]], run_id=uuid)
|
||||
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_start() -> None:
|
||||
"""Test tracer on an LLM run without a start."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_multiple_llm_runs() -> None:
|
||||
"""Test the tracer with multiple runs."""
|
||||
uuid = uuid4()
|
||||
compare_run = LLMRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chain_run() -> None:
|
||||
"""Test tracer on a Chain run."""
|
||||
uuid = uuid4()
|
||||
compare_run = ChainRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
|
||||
tracer.on_chain_end(outputs={}, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_tool_run() -> None:
|
||||
"""Test tracer on a Tool run."""
|
||||
uuid = uuid4()
|
||||
compare_run = ToolRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
|
||||
tracer.on_tool_end("test", run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_nested_run() -> None:
|
||||
"""Test tracer on a nested run."""
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
|
||||
chain_uuid = uuid4()
|
||||
tool_uuid = uuid4()
|
||||
llm_uuid1 = uuid4()
|
||||
llm_uuid2 = uuid4()
|
||||
for _ in range(10):
|
||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
|
||||
tracer.on_tool_start(
|
||||
serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_tool_end("test", run_id=tool_uuid)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
|
||||
|
||||
compare_run = ChainRun(
|
||||
uuid=str(chain_uuid),
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=4,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[
|
||||
ToolRun(
|
||||
uuid=str(tool_uuid),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
child_execution_order=3,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid1),
|
||||
parent_uuid=str(tool_uuid),
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid2),
|
||||
parent_uuid=str(chain_uuid),
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
child_execution_order=4,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
),
|
||||
],
|
||||
)
|
||||
assert tracer.runs == [compare_run] * 10
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_on_error() -> None:
|
||||
"""Test tracer on an LLM run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = LLMRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=None,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chain_run_on_error() -> None:
|
||||
"""Test tracer on a Chain run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = ChainRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs=None,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
|
||||
tracer.on_chain_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_tool_run_on_error() -> None:
|
||||
"""Test tracer on a Tool run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = ToolRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
|
||||
tracer.on_tool_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_nested_runs_on_error() -> None:
|
||||
"""Test tracer on a nested run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
chain_uuid = uuid4()
|
||||
tool_uuid = uuid4()
|
||||
llm_uuid1 = uuid4()
|
||||
llm_uuid2 = uuid4()
|
||||
llm_uuid3 = uuid4()
|
||||
|
||||
for _ in range(3):
|
||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||
tracer.on_tool_start(
|
||||
serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid
|
||||
)
|
||||
tracer.on_llm_error(exception, run_id=llm_uuid3)
|
||||
tracer.on_tool_error(exception, run_id=tool_uuid)
|
||||
tracer.on_chain_error(exception, run_id=chain_uuid)
|
||||
|
||||
compare_run = ChainRun(
|
||||
uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=5,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
inputs={},
|
||||
outputs=None,
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid1),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
child_execution_order=2,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid2),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[
|
||||
ToolRun(
|
||||
uuid=str(tool_uuid),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
child_execution_order=5,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="{}",
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid3),
|
||||
parent_uuid=str(tool_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=5,
|
||||
child_execution_order=5,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
prompts=[],
|
||||
response=None,
|
||||
)
|
||||
],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
assert tracer.runs == [compare_run] * 3
|
||||
|
||||
|
||||
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
|
||||
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2:
|
||||
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
|
||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
||||
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
||||
tracer = LangChainTracerV2()
|
||||
tracer = LangChainTracer()
|
||||
return tracer
|
||||
|
||||
|
||||
# Mock a sample TracerSessionV2 object
|
||||
# Mock a sample TracerSession object
|
||||
@pytest.fixture
|
||||
def sample_tracer_session_v2() -> TracerSessionV2:
|
||||
return TracerSessionV2(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
|
||||
def sample_tracer_session_v2() -> TracerSession:
|
||||
return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
|
||||
|
||||
|
||||
# Mock a sample LLMRun, ChainRun, and ToolRun objects
|
||||
@freeze_time("2023-01-01")
|
||||
@pytest.fixture
|
||||
def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
|
||||
llm_run = LLMRun(
|
||||
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||
def sample_runs() -> Tuple[Run, Run, Run]:
|
||||
llm_run = Run(
|
||||
id="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||
name="llm_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
session_id=1,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
inputs={"prompts": []},
|
||||
outputs=LLMResult(generations=[[]]).dict(),
|
||||
serialized={},
|
||||
extra={},
|
||||
run_type=RunTypeEnum.llm,
|
||||
)
|
||||
chain_run = ChainRun(
|
||||
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||
chain_run = Run(
|
||||
id="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||
name="chain_run",
|
||||
execution_order=1,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
child_execution_order=1,
|
||||
session_id=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs=None,
|
||||
child_llm_runs=[llm_run],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
outputs={},
|
||||
child_runs=[llm_run],
|
||||
extra={},
|
||||
run_type=RunTypeEnum.chain,
|
||||
)
|
||||
tool_run = ToolRun(
|
||||
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||
|
||||
tool_run = Run(
|
||||
id="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||
name="tool_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
session_id=1,
|
||||
tool_input="test",
|
||||
action="{}",
|
||||
inputs={"input": "test"},
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
outputs=None,
|
||||
serialized={},
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
child_runs=[],
|
||||
extra={},
|
||||
run_type=RunTypeEnum.tool,
|
||||
)
|
||||
return llm_run, chain_run, tool_run
|
||||
|
||||
|
||||
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||
expected = {"tenant_id": "test-tenant-id"}
|
||||
result = lang_chain_tracer_v2._get_default_query_params()
|
||||
assert result == expected
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.get")
|
||||
def test_load_session(
|
||||
mock_requests_get: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
) -> None:
|
||||
"""Test that load_session method returns a TracerSessionV2 object."""
|
||||
mock_requests_get.return_value.json.return_value = [sample_tracer_session_v2.dict()]
|
||||
result = lang_chain_tracer_v2.load_session("test-session-name")
|
||||
mock_requests_get.assert_called_with(
|
||||
"http://test-endpoint.com/sessions",
|
||||
headers={"Content-Type": "application/json", "x-api-key": "foo"},
|
||||
params={"tenant_id": "test-tenant-id", "name": "test-session-name"},
|
||||
)
|
||||
assert result == sample_tracer_session_v2
|
||||
|
||||
|
||||
def test_convert_run(
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
converted_llm_run = lang_chain_tracer_v2._convert_run(llm_run)
|
||||
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
|
||||
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
|
||||
|
||||
assert isinstance(converted_llm_run, RunCreate)
|
||||
assert isinstance(converted_chain_run, RunCreate)
|
||||
assert isinstance(converted_tool_run, RunCreate)
|
||||
|
||||
|
||||
def test_persist_run(
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||
lang_chain_tracer_v2: LangChainTracer,
|
||||
sample_tracer_session_v2: TracerSession,
|
||||
sample_runs: Tuple[Run, Run, Run],
|
||||
) -> None:
|
||||
"""Test that persist_run method calls requests.post once per method call."""
|
||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
||||
@@ -625,25 +94,25 @@ def test_persist_run(
|
||||
) as get:
|
||||
post.return_value.raise_for_status.return_value = None
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2._persist_run(llm_run)
|
||||
lang_chain_tracer_v2._persist_run(chain_run)
|
||||
lang_chain_tracer_v2._persist_run(tool_run)
|
||||
for run in sample_runs:
|
||||
lang_chain_tracer_v2.run_map[str(run.id)] = run
|
||||
for run in sample_runs:
|
||||
lang_chain_tracer_v2._end_trace(run)
|
||||
|
||||
assert post.call_count == 3
|
||||
assert get.call_count == 0
|
||||
|
||||
|
||||
def test_persist_run_with_example_id(
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||
lang_chain_tracer_v2: LangChainTracer,
|
||||
sample_tracer_session_v2: TracerSession,
|
||||
sample_runs: Tuple[Run, Run, Run],
|
||||
) -> None:
|
||||
"""Test the example ID is assigned only to the parent run and not the children."""
|
||||
example_id = uuid4()
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
chain_run.child_tool_runs = [tool_run]
|
||||
tool_run.child_llm_runs = [llm_run]
|
||||
chain_run.child_runs = [tool_run]
|
||||
tool_run.child_runs = [llm_run]
|
||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
||||
"langchain.callbacks.tracers.langchain.requests.get"
|
||||
) as get:
|
||||
@@ -652,55 +121,14 @@ def test_persist_run_with_example_id(
|
||||
lang_chain_tracer_v2.example_id = example_id
|
||||
lang_chain_tracer_v2._persist_run(chain_run)
|
||||
|
||||
assert post.call_count == 1
|
||||
assert post.call_count == 3
|
||||
assert get.call_count == 0
|
||||
posted_data = json.loads(post.call_args[1]["data"])
|
||||
assert posted_data["id"] == chain_run.uuid
|
||||
assert posted_data["reference_example_id"] == str(example_id)
|
||||
|
||||
def assert_child_run_no_example_id(run: dict) -> None:
|
||||
assert not run.get("reference_example_id")
|
||||
for child_run in run.get("child_runs", []):
|
||||
assert_child_run_no_example_id(child_run)
|
||||
|
||||
for child_run in posted_data["child_runs"]:
|
||||
assert_child_run_no_example_id(child_run)
|
||||
|
||||
|
||||
def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||
"""Test creating the 'SessionCreate' object."""
|
||||
lang_chain_tracer_v2.tenant_id = str(_TENANT_ID)
|
||||
session_create = lang_chain_tracer_v2._get_session_create(name="test")
|
||||
assert isinstance(session_create, TracerSessionV2Create)
|
||||
assert session_create.name == "test"
|
||||
assert session_create.tenant_id == _TENANT_ID
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
||||
def test_persist_session(
|
||||
mock_requests_post: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
) -> None:
|
||||
"""Test persist_session returns a TracerSessionV2 with the updated ID."""
|
||||
session_create = TracerSessionV2Create(**sample_tracer_session_v2.dict())
|
||||
new_id = str(uuid4())
|
||||
mock_requests_post.return_value.json.return_value = {"id": new_id}
|
||||
result = lang_chain_tracer_v2._persist_session(session_create)
|
||||
assert isinstance(result, TracerSessionV2)
|
||||
res = sample_tracer_session_v2.dict()
|
||||
res["id"] = UUID(new_id)
|
||||
assert result.dict() == res
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session")
|
||||
def test_load_default_session(
|
||||
mock_load_session: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
) -> None:
|
||||
"""Test load_default_session attempts to load with the default name."""
|
||||
mock_load_session.return_value = sample_tracer_session_v2
|
||||
result = lang_chain_tracer_v2.load_default_session()
|
||||
assert result == sample_tracer_session_v2
|
||||
mock_load_session.assert_called_with("default")
|
||||
posted_data = [
|
||||
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
|
||||
]
|
||||
assert posted_data[0]["id"] == str(chain_run.id)
|
||||
assert posted_data[0]["reference_example_id"] == str(example_id)
|
||||
assert posted_data[1]["id"] == str(tool_run.id)
|
||||
assert not posted_data[1].get("reference_example_id")
|
||||
assert posted_data[2]["id"] == str(llm_run.id)
|
||||
assert not posted_data[2].get("reference_example_id")
|
||||
|
@@ -8,8 +8,8 @@ from unittest import mock
|
||||
import pytest
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionV2
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.callbacks.tracers.schemas import TracerSession
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.client.langchain import (
|
||||
LangChainPlusClient,
|
||||
@@ -196,10 +196,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||
]
|
||||
|
||||
def mock_load_session(
|
||||
self: Any, name: str, *args: Any, **kwargs: Any
|
||||
) -> TracerSessionV2:
|
||||
return TracerSessionV2(name=name, tenant_id=_TENANT_ID, id=uuid.uuid4())
|
||||
def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
|
||||
return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4())
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
||||
@@ -208,7 +206,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
), mock.patch.object(
|
||||
LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
|
||||
), mock.patch.object(
|
||||
LangChainTracerV2, "load_session", new=mock_load_session
|
||||
LangChainTracer, "ensure_session", new=mock_ensure_session
|
||||
):
|
||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
|
||||
client = LangChainPlusClient(
|
||||
|
Reference in New Issue
Block a user