core: Use parametrized test in test_correct_get_tracer_project (#31513)

This commit is contained in:
Christophe Bornet 2025-06-24 00:55:57 +02:00 committed by GitHub
parent 8a0782c46c
commit c7e82ad95d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -94,60 +94,54 @@ def test_log_lock() -> None:
tracer.wait_for_futures() tracer.wait_for_futures()
class LangChainProjectNameTest(unittest.TestCase): @pytest.mark.parametrize(
"""Test that the project name is set correctly for runs.""" ("envvars", "expected_project_name"),
[
(
{},
"default",
),
(
{"LANGCHAIN_SESSION": "old_timey_session"},
"old_timey_session",
),
(
{
"LANGCHAIN_SESSION": "old_timey_session",
"LANGCHAIN_PROJECT": "modern_session",
},
"modern_session",
),
],
ids=[
"default to 'default' when no project provided",
"use session_name for legacy tracers",
"use LANGCHAIN_PROJECT over SESSION_NAME",
],
)
def test_correct_get_tracer_project(
envvars: dict[str, str], expected_project_name: str
) -> None:
get_env_var.cache_clear()
get_tracer_project.cache_clear()
with pytest.MonkeyPatch.context() as mp:
for k, v in envvars.items():
mp.setenv(k, v)
class SetProperTracerProjectTestCase: client = unittest.mock.MagicMock(spec=Client)
def __init__( tracer = LangChainTracer(client=client)
self, test_name: str, envvars: dict[str, str], expected_project_name: str projects = []
):
self.test_name = test_name
self.envvars = envvars
self.expected_project_name = expected_project_name
def test_correct_get_tracer_project(self) -> None: def mock_create_run(**kwargs: Any) -> Any:
cases = [ projects.append(kwargs.get("session_name"))
self.SetProperTracerProjectTestCase( return unittest.mock.MagicMock()
test_name="default to 'default' when no project provided",
envvars={},
expected_project_name="default",
),
self.SetProperTracerProjectTestCase(
test_name="use session_name for legacy tracers",
envvars={"LANGCHAIN_SESSION": "old_timey_session"},
expected_project_name="old_timey_session",
),
self.SetProperTracerProjectTestCase(
test_name="use LANGCHAIN_PROJECT over SESSION_NAME",
envvars={
"LANGCHAIN_SESSION": "old_timey_session",
"LANGCHAIN_PROJECT": "modern_session",
},
expected_project_name="modern_session",
),
]
for case in cases: client.create_run = mock_create_run
get_env_var.cache_clear()
get_tracer_project.cache_clear()
with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp:
for k, v in case.envvars.items():
mp.setenv(k, v)
client = unittest.mock.MagicMock(spec=Client) tracer.on_llm_start(
tracer = LangChainTracer(client=client) {"name": "example_1"},
projects = [] ["foo"],
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
def mock_create_run(**kwargs: Any) -> Any: )
projects.append(kwargs.get("session_name")) # noqa: B023 tracer.wait_for_futures()
return unittest.mock.MagicMock() assert projects == [expected_project_name]
client.create_run = mock_create_run
tracer.on_llm_start(
{"name": "example_1"},
["foo"],
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
)
tracer.wait_for_futures()
assert projects == [case.expected_project_name]