[Core] Unified Enable/Disable Tracing (#22576)

This commit is contained in:
William FH
2024-06-06 16:54:35 -07:00
committed by GitHub
parent 57c1239643
commit be79ce9336
4 changed files with 183 additions and 169 deletions

View File

@@ -1,10 +1,10 @@
import json
import sys
import time
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
from langsmith import Client, traceable
from langsmith.run_helpers import tracing_context
from langchain_core.runnables.base import RunnableLambda
from langchain_core.tracers.langchain import LangChainTracer
@@ -20,14 +20,20 @@ def _get_posts(client: Client) -> list:
assert call.args[0] == "POST"
assert call.args[1].startswith("https://api.smith.langchain.com")
body = json.loads(call.kwargs["data"])
assert body["post"]
posts.extend(body["post"])
if "post" in body:
# Batch request
assert body["post"]
posts.extend(body["post"])
else:
posts.append(body)
return posts
def test_config_traceable_handoff() -> None:
mock_session = MagicMock()
mock_client_ = Client(session=mock_session, api_key="test")
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
)
tracer = LangChainTracer(client=mock_client_)
@traceable
@@ -56,11 +62,7 @@ def test_config_traceable_handoff() -> None:
my_parent_runnable = RunnableLambda(my_parent_function)
assert my_parent_runnable.invoke(1, {"callbacks": [tracer]}) == 6
for _ in range(15):
time.sleep(0.1)
posts = _get_posts(mock_client_)
if len(posts) == 6:
break
posts = _get_posts(mock_client_)
# There should have been 6 runs created,
# one for each function invocation
assert len(posts) == 6
@@ -101,7 +103,9 @@ def test_config_traceable_handoff() -> None:
)
async def test_config_traceable_async_handoff() -> None:
mock_session = MagicMock()
mock_client_ = Client(session=mock_session, api_key="test")
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
)
tracer = LangChainTracer(client=mock_client_)
@traceable
@@ -130,11 +134,7 @@ async def test_config_traceable_async_handoff() -> None:
my_parent_runnable = RunnableLambda(my_parent_function) # type: ignore
result = await my_parent_runnable.ainvoke(1, {"callbacks": [tracer]})
assert result == 6
for _ in range(15):
time.sleep(0.1)
posts = _get_posts(mock_client_)
if len(posts) == 6:
break
posts = _get_posts(mock_client_)
# There should have been 6 runs created,
# one for each function invocation
assert len(posts) == 6
@@ -168,3 +168,34 @@ async def test_config_traceable_async_handoff() -> None:
)
last_dotted_order = dotted_order
parent_run_id = id_
@patch("langchain_core.tracers.langchain.get_client")
@pytest.mark.parametrize("enabled", [None, True, False])
@pytest.mark.parametrize("env", ["", "true"])
def test_tracing_enable_disable(
mock_get_client: MagicMock, enabled: bool, env: str
) -> None:
mock_session = MagicMock()
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
)
mock_get_client.return_value = mock_client_
def my_func(a: int) -> int:
return a + 1
env_on = env == "true"
with patch.dict("os.environ", {"LANGSMITH_TRACING": env}):
with tracing_context(enabled=enabled):
RunnableLambda(my_func).invoke(1)
mock_posts = _get_posts(mock_client_)
if enabled is True:
assert len(mock_posts) == 1
elif enabled is False:
assert not mock_posts
elif env_on:
assert len(mock_posts) == 1
else:
assert not mock_posts