mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 16:01:33 +00:00
Update V2 Tracer (#4193)
- Update the RunCreate object to work with recent changes - Add optional Example ID to the tracer - Adjust default persist_session behavior to attempt to load the session if it exists - Raise more useful HTTP errors for logging - Add unit testing - Fix the default ID to be a UUID for v2 tracer sessions Broken out from the big draft here: https://github.com/hwchase17/langchain/pull/4061
This commit is contained in:
@@ -18,7 +18,12 @@ from langchain.callbacks.tracers.base import (
|
||||
TracerSession,
|
||||
)
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
||||
from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2
|
||||
from langchain.callbacks.tracers.schemas import (
|
||||
RunCreate,
|
||||
TracerSessionBase,
|
||||
TracerSessionV2,
|
||||
TracerSessionV2Create,
|
||||
)
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
@@ -541,14 +546,12 @@ def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
|
||||
return llm_run, chain_run, tool_run
|
||||
|
||||
|
||||
# Test _get_default_query_params method
|
||||
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||
expected = {"tenant_id": "test-tenant-id"}
|
||||
result = lang_chain_tracer_v2._get_default_query_params()
|
||||
assert result == expected
|
||||
|
||||
|
||||
# Test load_session method
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.get")
|
||||
def test_load_session(
|
||||
mock_requests_get: Mock,
|
||||
@@ -577,23 +580,65 @@ def test_convert_run(
|
||||
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
|
||||
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
|
||||
|
||||
assert isinstance(converted_llm_run, Run)
|
||||
assert isinstance(converted_chain_run, Run)
|
||||
assert isinstance(converted_tool_run, Run)
|
||||
assert isinstance(converted_llm_run, RunCreate)
|
||||
assert isinstance(converted_chain_run, RunCreate)
|
||||
assert isinstance(converted_tool_run, RunCreate)
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
||||
def test_persist_run(
|
||||
mock_requests_post: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
mock_requests_post.return_value.raise_for_status.return_value = None
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2._persist_run(llm_run)
|
||||
lang_chain_tracer_v2._persist_run(chain_run)
|
||||
lang_chain_tracer_v2._persist_run(tool_run)
|
||||
"""Test that persist_run method calls requests.post once per method call."""
|
||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
||||
"langchain.callbacks.tracers.langchain.requests.get"
|
||||
) as get:
|
||||
post.return_value.raise_for_status.return_value = None
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2._persist_run(llm_run)
|
||||
lang_chain_tracer_v2._persist_run(chain_run)
|
||||
lang_chain_tracer_v2._persist_run(tool_run)
|
||||
|
||||
assert mock_requests_post.call_count == 3
|
||||
assert post.call_count == 3
|
||||
assert get.call_count == 0
|
||||
|
||||
|
||||
def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||
"""Test creating the 'SessionCreate' object."""
|
||||
lang_chain_tracer_v2.tenant_id = str(_TENANT_ID)
|
||||
session_create = lang_chain_tracer_v2._get_session_create(name="test")
|
||||
assert isinstance(session_create, TracerSessionV2Create)
|
||||
assert session_create.name == "test"
|
||||
assert session_create.tenant_id == _TENANT_ID
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
||||
def test_persist_session(
|
||||
mock_requests_post: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
) -> None:
|
||||
"""Test persist_session returns a TracerSessionV2 with the updated ID."""
|
||||
session_create = TracerSessionV2Create(**sample_tracer_session_v2.dict())
|
||||
new_id = str(uuid4())
|
||||
mock_requests_post.return_value.json.return_value = {"id": new_id}
|
||||
result = lang_chain_tracer_v2._persist_session(session_create)
|
||||
assert isinstance(result, TracerSessionV2)
|
||||
res = sample_tracer_session_v2.dict()
|
||||
res["id"] = UUID(new_id)
|
||||
assert result.dict() == res
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session")
|
||||
def test_load_default_session(
|
||||
mock_load_session: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
) -> None:
|
||||
"""Test load_default_session attempts to load with the default name."""
|
||||
mock_load_session.return_value = sample_tracer_session_v2
|
||||
result = lang_chain_tracer_v2.load_default_session()
|
||||
assert result == sample_tracer_session_v2
|
||||
mock_load_session.assert_called_with("default")
|
||||
|
Reference in New Issue
Block a user