mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 17:08:47 +00:00
Update V2 Tracer (#4193)
- Update the RunCreate object to work with recent changes - Add optional Example ID to the tracer - Adjust default persist_session behavior to attempt to load the session if it exists - Raise more useful HTTP errors for logging - Add unit testing - Fix the default ID to be a UUID for v2 tracer sessions Broken out from the big draft here: https://github.com/hwchase17/langchain/pull/4061
This commit is contained in:
parent
c3044b1bf0
commit
a30f42da4e
@ -58,6 +58,7 @@ def tracing_enabled(
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def tracing_v2_enabled(
|
def tracing_v2_enabled(
|
||||||
session_name: str = "default",
|
session_name: str = "default",
|
||||||
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
) -> Generator[TracerSessionV2, None, None]:
|
) -> Generator[TracerSessionV2, None, None]:
|
||||||
"""Get the experimental tracer handler in a context manager."""
|
"""Get the experimental tracer handler in a context manager."""
|
||||||
# Issue a warning that this is experimental
|
# Issue a warning that this is experimental
|
||||||
@ -65,8 +66,10 @@ def tracing_v2_enabled(
|
|||||||
"The experimental tracing v2 is in development. "
|
"The experimental tracing v2 is in development. "
|
||||||
"This is not yet stable and may change in the future."
|
"This is not yet stable and may change in the future."
|
||||||
)
|
)
|
||||||
cb = LangChainTracerV2()
|
if isinstance(example_id, str):
|
||||||
session = cb.load_session(session_name)
|
example_id = UUID(example_id)
|
||||||
|
cb = LangChainTracerV2(example_id=example_id)
|
||||||
|
session = cast(TracerSessionV2, cb.new_session(session_name))
|
||||||
tracing_callback_var.set(cb)
|
tracing_callback_var.set(cb)
|
||||||
yield session
|
yield session
|
||||||
tracing_callback_var.set(None)
|
tracing_callback_var.set(None)
|
||||||
|
@ -29,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
|
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
|
||||||
self.session: Optional[Union[TracerSessionV2, TracerSession]] = None
|
self.session: Optional[Union[TracerSession, TracerSessionV2]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _add_child_run(
|
def _add_child_run(
|
||||||
@ -165,7 +165,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
llm_run = self.run_map.get(run_id_)
|
llm_run = self.run_map.get(run_id_)
|
||||||
if llm_run is None or not isinstance(llm_run, LLMRun):
|
if llm_run is None or not isinstance(llm_run, LLMRun):
|
||||||
raise TracerException("No LLMRun found to be traced")
|
raise TracerException("No LLMRun found to be traced")
|
||||||
|
|
||||||
llm_run.response = response
|
llm_run.response = response
|
||||||
llm_run.end_time = datetime.utcnow()
|
llm_run.end_time = datetime.utcnow()
|
||||||
self._end_trace(llm_run)
|
self._end_trace(llm_run)
|
||||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@ -11,13 +12,14 @@ from langchain.callbacks.tracers.base import BaseTracer
|
|||||||
from langchain.callbacks.tracers.schemas import (
|
from langchain.callbacks.tracers.schemas import (
|
||||||
ChainRun,
|
ChainRun,
|
||||||
LLMRun,
|
LLMRun,
|
||||||
Run,
|
RunCreate,
|
||||||
ToolRun,
|
ToolRun,
|
||||||
TracerSession,
|
TracerSession,
|
||||||
TracerSessionBase,
|
TracerSessionBase,
|
||||||
TracerSessionV2,
|
TracerSessionV2,
|
||||||
TracerSessionV2Create,
|
TracerSessionV2Create,
|
||||||
)
|
)
|
||||||
|
from langchain.utils import raise_for_status_with_text
|
||||||
|
|
||||||
|
|
||||||
def _get_headers() -> Dict[str, Any]:
|
def _get_headers() -> Dict[str, Any]:
|
||||||
@ -51,11 +53,12 @@ class LangChainTracer(BaseTracer):
|
|||||||
endpoint = f"{self._endpoint}/tool-runs"
|
endpoint = f"{self._endpoint}/tool-runs"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
requests.post(
|
response = requests.post(
|
||||||
endpoint,
|
endpoint,
|
||||||
data=run.json(),
|
data=run.json(),
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
)
|
)
|
||||||
|
raise_for_status_with_text(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to persist run: {e}")
|
logging.warning(f"Failed to persist run: {e}")
|
||||||
|
|
||||||
@ -111,7 +114,7 @@ def _get_tenant_id() -> Optional[str]:
|
|||||||
endpoint = _get_endpoint()
|
endpoint = _get_endpoint()
|
||||||
headers = _get_headers()
|
headers = _get_headers()
|
||||||
response = requests.get(endpoint + "/tenants", headers=headers)
|
response = requests.get(endpoint + "/tenants", headers=headers)
|
||||||
response.raise_for_status()
|
raise_for_status_with_text(response)
|
||||||
tenants: List[Dict[str, Any]] = response.json()
|
tenants: List[Dict[str, Any]] = response.json()
|
||||||
if not tenants:
|
if not tenants:
|
||||||
raise ValueError(f"No tenants found for URL {endpoint}")
|
raise ValueError(f"No tenants found for URL {endpoint}")
|
||||||
@ -121,12 +124,13 @@ def _get_tenant_id() -> Optional[str]:
|
|||||||
class LangChainTracerV2(LangChainTracer):
|
class LangChainTracerV2(LangChainTracer):
|
||||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, example_id: Optional[UUID] = None, **kwargs: Any) -> None:
|
||||||
"""Initialize the LangChain tracer."""
|
"""Initialize the LangChain tracer."""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._endpoint = _get_endpoint()
|
self._endpoint = _get_endpoint()
|
||||||
self._headers = _get_headers()
|
self._headers = _get_headers()
|
||||||
self.tenant_id = _get_tenant_id()
|
self.tenant_id = _get_tenant_id()
|
||||||
|
self.example_id = example_id
|
||||||
|
|
||||||
def _get_session_create(
|
def _get_session_create(
|
||||||
self, name: Optional[str] = None, **kwargs: Any
|
self, name: Optional[str] = None, **kwargs: Any
|
||||||
@ -135,16 +139,30 @@ class LangChainTracerV2(LangChainTracer):
|
|||||||
|
|
||||||
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
|
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
|
||||||
"""Persist a session."""
|
"""Persist a session."""
|
||||||
|
session: Optional[TracerSessionV2] = None
|
||||||
try:
|
try:
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
f"{self._endpoint}/sessions",
|
f"{self._endpoint}/sessions",
|
||||||
data=session_create.json(),
|
data=session_create.json(),
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
)
|
)
|
||||||
session = TracerSessionV2(id=r.json()["id"], **session_create.dict())
|
raise_for_status_with_text(r)
|
||||||
|
creation_args = session_create.dict()
|
||||||
|
if "id" in creation_args:
|
||||||
|
del creation_args["id"]
|
||||||
|
return TracerSessionV2(id=r.json()["id"], **creation_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to create session, using default session: {e}")
|
if session_create.name is not None:
|
||||||
session = self.load_session("default")
|
try:
|
||||||
|
return self.load_session(session_create.name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to create session {session_create.name},"
|
||||||
|
f" using empty session: {e}"
|
||||||
|
)
|
||||||
|
session = TracerSessionV2(id=uuid4(), **session_create.dict())
|
||||||
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def _get_default_query_params(self) -> Dict[str, Any]:
|
def _get_default_query_params(self) -> Dict[str, Any]:
|
||||||
@ -159,13 +177,14 @@ class LangChainTracerV2(LangChainTracer):
|
|||||||
if session_name:
|
if session_name:
|
||||||
params["name"] = session_name
|
params["name"] = session_name
|
||||||
r = requests.get(url, headers=self._headers, params=params)
|
r = requests.get(url, headers=self._headers, params=params)
|
||||||
|
raise_for_status_with_text(r)
|
||||||
tracer_session = TracerSessionV2(**r.json()[0])
|
tracer_session = TracerSessionV2(**r.json()[0])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
session_type = "default" if not session_name else session_name
|
session_type = "default" if not session_name else session_name
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Failed to load {session_type} session, using empty session: {e}"
|
f"Failed to load {session_type} session, using empty session: {e}"
|
||||||
)
|
)
|
||||||
tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id)
|
tracer_session = TracerSessionV2(id=uuid4(), tenant_id=self.tenant_id)
|
||||||
|
|
||||||
self.session = tracer_session
|
self.session = tracer_session
|
||||||
return tracer_session
|
return tracer_session
|
||||||
@ -174,7 +193,7 @@ class LangChainTracerV2(LangChainTracer):
|
|||||||
"""Load the default tracing session and set it as the Tracer's session."""
|
"""Load the default tracing session and set it as the Tracer's session."""
|
||||||
return self.load_session("default")
|
return self.load_session("default")
|
||||||
|
|
||||||
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
|
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
|
||||||
"""Convert a run to a Run."""
|
"""Convert a run to a Run."""
|
||||||
session = self.session or self.load_default_session()
|
session = self.session or self.load_default_session()
|
||||||
inputs: Dict[str, Any] = {}
|
inputs: Dict[str, Any] = {}
|
||||||
@ -204,9 +223,9 @@ class LangChainTracerV2(LangChainTracer):
|
|||||||
*run.child_tool_runs,
|
*run.child_tool_runs,
|
||||||
]
|
]
|
||||||
|
|
||||||
return Run(
|
return RunCreate(
|
||||||
id=run.uuid,
|
id=run.uuid,
|
||||||
name=run.serialized.get("name", f"{run_type}-{run.uuid}"),
|
name=run.serialized.get("name"),
|
||||||
start_time=run.start_time,
|
start_time=run.start_time,
|
||||||
end_time=run.end_time,
|
end_time=run.end_time,
|
||||||
extra=run.extra or {},
|
extra=run.extra or {},
|
||||||
@ -217,7 +236,7 @@ class LangChainTracerV2(LangChainTracer):
|
|||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
session_id=session.id,
|
session_id=session.id,
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
parent_run_id=run.parent_uuid,
|
reference_example_id=self.example_id,
|
||||||
child_runs=[self._convert_run(child) for child in child_runs],
|
child_runs=[self._convert_run(child) for child in child_runs],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -225,11 +244,11 @@ class LangChainTracerV2(LangChainTracer):
|
|||||||
"""Persist a run."""
|
"""Persist a run."""
|
||||||
run_create = self._convert_run(run)
|
run_create = self._convert_run(run)
|
||||||
try:
|
try:
|
||||||
result = requests.post(
|
response = requests.post(
|
||||||
f"{self._endpoint}/runs",
|
f"{self._endpoint}/runs",
|
||||||
data=run_create.json(),
|
data=run_create.json(),
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
)
|
)
|
||||||
result.raise_for_status()
|
raise_for_status_with_text(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to persist run: {e}")
|
logging.warning(f"Failed to persist run: {e}")
|
||||||
|
@ -37,9 +37,11 @@ class TracerSessionV2Base(TracerSessionBase):
|
|||||||
tenant_id: UUID
|
tenant_id: UUID
|
||||||
|
|
||||||
|
|
||||||
class TracerSessionV2Create(TracerSessionBase):
|
class TracerSessionV2Create(TracerSessionV2Base):
|
||||||
"""A creation class for TracerSessionV2."""
|
"""A creation class for TracerSessionV2."""
|
||||||
|
|
||||||
|
id: Optional[UUID]
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -100,9 +102,10 @@ class RunTypeEnum(str, Enum):
|
|||||||
llm = "llm"
|
llm = "llm"
|
||||||
|
|
||||||
|
|
||||||
class Run(BaseModel):
|
class RunBase(BaseModel):
|
||||||
|
"""Base Run schema."""
|
||||||
|
|
||||||
id: Optional[UUID]
|
id: Optional[UUID]
|
||||||
name: str
|
|
||||||
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||||
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||||
extra: dict
|
extra: dict
|
||||||
@ -112,10 +115,22 @@ class Run(BaseModel):
|
|||||||
inputs: dict
|
inputs: dict
|
||||||
outputs: Optional[dict]
|
outputs: Optional[dict]
|
||||||
session_id: UUID
|
session_id: UUID
|
||||||
parent_run_id: Optional[UUID]
|
|
||||||
reference_example_id: Optional[UUID]
|
reference_example_id: Optional[UUID]
|
||||||
run_type: RunTypeEnum
|
run_type: RunTypeEnum
|
||||||
child_runs: List[Run] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
class RunCreate(RunBase):
|
||||||
|
"""Schema to create a run in the DB."""
|
||||||
|
|
||||||
|
name: Optional[str]
|
||||||
|
child_runs: List[RunCreate] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class Run(RunBase):
|
||||||
|
"""Run schema when loading from the DB."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
parent_run_id: Optional[UUID]
|
||||||
|
|
||||||
|
|
||||||
ChainRun.update_forward_refs()
|
ChainRun.update_forward_refs()
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from requests import HTTPError, Response
|
||||||
|
|
||||||
|
|
||||||
def get_from_dict_or_env(
|
def get_from_dict_or_env(
|
||||||
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
|
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
|
||||||
@ -52,6 +54,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def raise_for_status_with_text(response: Response) -> None:
|
||||||
|
"""Raise an error with the response text."""
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except HTTPError as e:
|
||||||
|
raise ValueError(response.text) from e
|
||||||
|
|
||||||
|
|
||||||
def stringify_value(val: Any) -> str:
|
def stringify_value(val: Any) -> str:
|
||||||
if isinstance(val, str):
|
if isinstance(val, str):
|
||||||
return val
|
return val
|
||||||
|
@ -18,7 +18,12 @@ from langchain.callbacks.tracers.base import (
|
|||||||
TracerSession,
|
TracerSession,
|
||||||
)
|
)
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
||||||
from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2
|
from langchain.callbacks.tracers.schemas import (
|
||||||
|
RunCreate,
|
||||||
|
TracerSessionBase,
|
||||||
|
TracerSessionV2,
|
||||||
|
TracerSessionV2Create,
|
||||||
|
)
|
||||||
from langchain.schema import LLMResult
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
TEST_SESSION_ID = 2023
|
TEST_SESSION_ID = 2023
|
||||||
@ -541,14 +546,12 @@ def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
|
|||||||
return llm_run, chain_run, tool_run
|
return llm_run, chain_run, tool_run
|
||||||
|
|
||||||
|
|
||||||
# Test _get_default_query_params method
|
|
||||||
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||||
expected = {"tenant_id": "test-tenant-id"}
|
expected = {"tenant_id": "test-tenant-id"}
|
||||||
result = lang_chain_tracer_v2._get_default_query_params()
|
result = lang_chain_tracer_v2._get_default_query_params()
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
# Test load_session method
|
|
||||||
@patch("langchain.callbacks.tracers.langchain.requests.get")
|
@patch("langchain.callbacks.tracers.langchain.requests.get")
|
||||||
def test_load_session(
|
def test_load_session(
|
||||||
mock_requests_get: Mock,
|
mock_requests_get: Mock,
|
||||||
@ -577,23 +580,65 @@ def test_convert_run(
|
|||||||
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
|
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
|
||||||
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
|
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
|
||||||
|
|
||||||
assert isinstance(converted_llm_run, Run)
|
assert isinstance(converted_llm_run, RunCreate)
|
||||||
assert isinstance(converted_chain_run, Run)
|
assert isinstance(converted_chain_run, RunCreate)
|
||||||
assert isinstance(converted_tool_run, Run)
|
assert isinstance(converted_tool_run, RunCreate)
|
||||||
|
|
||||||
|
|
||||||
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
|
||||||
def test_persist_run(
|
def test_persist_run(
|
||||||
mock_requests_post: Mock,
|
|
||||||
lang_chain_tracer_v2: LangChainTracerV2,
|
lang_chain_tracer_v2: LangChainTracerV2,
|
||||||
sample_tracer_session_v2: TracerSessionV2,
|
sample_tracer_session_v2: TracerSessionV2,
|
||||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||||
) -> None:
|
) -> None:
|
||||||
mock_requests_post.return_value.raise_for_status.return_value = None
|
"""Test that persist_run method calls requests.post once per method call."""
|
||||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
||||||
llm_run, chain_run, tool_run = sample_runs
|
"langchain.callbacks.tracers.langchain.requests.get"
|
||||||
lang_chain_tracer_v2._persist_run(llm_run)
|
) as get:
|
||||||
lang_chain_tracer_v2._persist_run(chain_run)
|
post.return_value.raise_for_status.return_value = None
|
||||||
lang_chain_tracer_v2._persist_run(tool_run)
|
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||||
|
llm_run, chain_run, tool_run = sample_runs
|
||||||
|
lang_chain_tracer_v2._persist_run(llm_run)
|
||||||
|
lang_chain_tracer_v2._persist_run(chain_run)
|
||||||
|
lang_chain_tracer_v2._persist_run(tool_run)
|
||||||
|
|
||||||
assert mock_requests_post.call_count == 3
|
assert post.call_count == 3
|
||||||
|
assert get.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||||
|
"""Test creating the 'SessionCreate' object."""
|
||||||
|
lang_chain_tracer_v2.tenant_id = str(_TENANT_ID)
|
||||||
|
session_create = lang_chain_tracer_v2._get_session_create(name="test")
|
||||||
|
assert isinstance(session_create, TracerSessionV2Create)
|
||||||
|
assert session_create.name == "test"
|
||||||
|
assert session_create.tenant_id == _TENANT_ID
|
||||||
|
|
||||||
|
|
||||||
|
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
||||||
|
def test_persist_session(
|
||||||
|
mock_requests_post: Mock,
|
||||||
|
lang_chain_tracer_v2: LangChainTracerV2,
|
||||||
|
sample_tracer_session_v2: TracerSessionV2,
|
||||||
|
) -> None:
|
||||||
|
"""Test persist_session returns a TracerSessionV2 with the updated ID."""
|
||||||
|
session_create = TracerSessionV2Create(**sample_tracer_session_v2.dict())
|
||||||
|
new_id = str(uuid4())
|
||||||
|
mock_requests_post.return_value.json.return_value = {"id": new_id}
|
||||||
|
result = lang_chain_tracer_v2._persist_session(session_create)
|
||||||
|
assert isinstance(result, TracerSessionV2)
|
||||||
|
res = sample_tracer_session_v2.dict()
|
||||||
|
res["id"] = UUID(new_id)
|
||||||
|
assert result.dict() == res
|
||||||
|
|
||||||
|
|
||||||
|
@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session")
|
||||||
|
def test_load_default_session(
|
||||||
|
mock_load_session: Mock,
|
||||||
|
lang_chain_tracer_v2: LangChainTracerV2,
|
||||||
|
sample_tracer_session_v2: TracerSessionV2,
|
||||||
|
) -> None:
|
||||||
|
"""Test load_default_session attempts to load with the default name."""
|
||||||
|
mock_load_session.return_value = sample_tracer_session_v2
|
||||||
|
result = lang_chain_tracer_v2.load_default_session()
|
||||||
|
assert result == sample_tracer_session_v2
|
||||||
|
mock_load_session.assert_called_with("default")
|
||||||
|
Loading…
Reference in New Issue
Block a user