[Tracer] add project name to run from tracer (#26736)

This commit is contained in:
William FH 2024-09-20 16:48:37 -07:00 committed by GitHub
parent 2d21274bf6
commit 864020e592
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 2 deletions

View File

@ -110,6 +110,14 @@ class LangChainTracer(BaseTracer):
self.latest_run: Optional[Run] = None
def _start_trace(self, run: Run) -> None:
if self.project_name:
run.session_name = self.project_name
if self.tags is not None:
if run.tags:
run.tags = sorted(set(run.tags + self.tags))
else:
run.tags = self.tags.copy()
super()._start_trace(run)
if run._client is None:
run._client = self.client

View File

@ -6,7 +6,7 @@ from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from langsmith import Client, traceable
from langsmith import Client, get_current_run_tree, traceable
from langsmith.run_helpers import tracing_context
from langsmith.run_trees import RunTree
from langsmith.utils import get_env_var
@ -40,10 +40,15 @@ def test_config_traceable_handoff() -> None:
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
)
tracer = LangChainTracer(client=mock_client_)
tracer = LangChainTracer(
client=mock_client_, project_name="another-flippin-project", tags=["such-a-tag"]
)
@traceable
def my_great_great_grandchild_function(a: int) -> int:
rt = get_current_run_tree()
assert rt
assert rt.session_name == "another-flippin-project"
return a + 1
@RunnableLambda
@ -60,19 +65,28 @@ def test_config_traceable_handoff() -> None:
@traceable()
def my_function(a: int) -> int:
rt = get_current_run_tree()
assert rt
assert rt.session_name == "another-flippin-project"
assert rt.parent_run and rt.parent_run.name == "my_parent_function"
return my_child_function(a)
def my_parent_function(a: int) -> int:
rt = get_current_run_tree()
assert rt
assert rt.session_name == "another-flippin-project"
return my_function(a)
my_parent_runnable = RunnableLambda(my_parent_function)
assert my_parent_runnable.invoke(1, {"callbacks": [tracer]}) == 6
posts = _get_posts(mock_client_)
assert all(post["session_name"] == "another-flippin-project" for post in posts)
# There should have been 6 runs created,
# one for each function invocation
assert len(posts) == 6
name_to_body = {post["name"]: post for post in posts}
ordered_names = [
"my_parent_function",
"my_function",
@ -102,6 +116,7 @@ def test_config_traceable_handoff() -> None:
)
last_dotted_order = dotted_order
parent_run_id = id_
assert "such-a-tag" in name_to_body["my_parent_function"]["tags"]
@pytest.mark.skipif(