core: Use parametrized test in test_correct_get_tracer_project (#31513)

This commit is contained in:
Christophe Bornet 2025-06-24 00:55:57 +02:00 committed by GitHub
parent 8a0782c46c
commit c7e82ad95d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -94,44 +94,38 @@ def test_log_lock() -> None:
tracer.wait_for_futures() tracer.wait_for_futures()
class LangChainProjectNameTest(unittest.TestCase): @pytest.mark.parametrize(
"""Test that the project name is set correctly for runs.""" ("envvars", "expected_project_name"),
[
class SetProperTracerProjectTestCase: (
def __init__( {},
self, test_name: str, envvars: dict[str, str], expected_project_name: str "default",
):
self.test_name = test_name
self.envvars = envvars
self.expected_project_name = expected_project_name
def test_correct_get_tracer_project(self) -> None:
cases = [
self.SetProperTracerProjectTestCase(
test_name="default to 'default' when no project provided",
envvars={},
expected_project_name="default",
), ),
self.SetProperTracerProjectTestCase( (
test_name="use session_name for legacy tracers", {"LANGCHAIN_SESSION": "old_timey_session"},
envvars={"LANGCHAIN_SESSION": "old_timey_session"}, "old_timey_session",
expected_project_name="old_timey_session",
), ),
self.SetProperTracerProjectTestCase( (
test_name="use LANGCHAIN_PROJECT over SESSION_NAME", {
envvars={
"LANGCHAIN_SESSION": "old_timey_session", "LANGCHAIN_SESSION": "old_timey_session",
"LANGCHAIN_PROJECT": "modern_session", "LANGCHAIN_PROJECT": "modern_session",
}, },
expected_project_name="modern_session", "modern_session",
), ),
] ],
ids=[
for case in cases: "default to 'default' when no project provided",
"use session_name for legacy tracers",
"use LANGCHAIN_PROJECT over SESSION_NAME",
],
)
def test_correct_get_tracer_project(
envvars: dict[str, str], expected_project_name: str
) -> None:
get_env_var.cache_clear() get_env_var.cache_clear()
get_tracer_project.cache_clear() get_tracer_project.cache_clear()
with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
for k, v in case.envvars.items(): for k, v in envvars.items():
mp.setenv(k, v) mp.setenv(k, v)
client = unittest.mock.MagicMock(spec=Client) client = unittest.mock.MagicMock(spec=Client)
@ -139,7 +133,7 @@ class LangChainProjectNameTest(unittest.TestCase):
projects = [] projects = []
def mock_create_run(**kwargs: Any) -> Any: def mock_create_run(**kwargs: Any) -> Any:
projects.append(kwargs.get("session_name")) # noqa: B023 projects.append(kwargs.get("session_name"))
return unittest.mock.MagicMock() return unittest.mock.MagicMock()
client.create_run = mock_create_run client.create_run = mock_create_run
@ -150,4 +144,4 @@ class LangChainProjectNameTest(unittest.TestCase):
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"), run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
) )
tracer.wait_for_futures() tracer.wait_for_futures()
assert projects == [case.expected_project_name] assert projects == [expected_project_name]