core: Use parametrized test in test_correct_get_tracer_project (#31513)

This commit is contained in:
Christophe Bornet 2025-06-24 00:55:57 +02:00 committed by GitHub
parent 8a0782c46c
commit c7e82ad95d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -94,60 +94,54 @@ def test_log_lock() -> None:
tracer.wait_for_futures()
class LangChainProjectNameTest(unittest.TestCase):
"""Test that the project name is set correctly for runs."""
@pytest.mark.parametrize(
("envvars", "expected_project_name"),
[
(
{},
"default",
),
(
{"LANGCHAIN_SESSION": "old_timey_session"},
"old_timey_session",
),
(
{
"LANGCHAIN_SESSION": "old_timey_session",
"LANGCHAIN_PROJECT": "modern_session",
},
"modern_session",
),
],
ids=[
"default to 'default' when no project provided",
"use session_name for legacy tracers",
"use LANGCHAIN_PROJECT over SESSION_NAME",
],
)
def test_correct_get_tracer_project(
envvars: dict[str, str], expected_project_name: str
) -> None:
get_env_var.cache_clear()
get_tracer_project.cache_clear()
with pytest.MonkeyPatch.context() as mp:
for k, v in envvars.items():
mp.setenv(k, v)
class SetProperTracerProjectTestCase:
def __init__(
self, test_name: str, envvars: dict[str, str], expected_project_name: str
):
self.test_name = test_name
self.envvars = envvars
self.expected_project_name = expected_project_name
client = unittest.mock.MagicMock(spec=Client)
tracer = LangChainTracer(client=client)
projects = []
def test_correct_get_tracer_project(self) -> None:
cases = [
self.SetProperTracerProjectTestCase(
test_name="default to 'default' when no project provided",
envvars={},
expected_project_name="default",
),
self.SetProperTracerProjectTestCase(
test_name="use session_name for legacy tracers",
envvars={"LANGCHAIN_SESSION": "old_timey_session"},
expected_project_name="old_timey_session",
),
self.SetProperTracerProjectTestCase(
test_name="use LANGCHAIN_PROJECT over SESSION_NAME",
envvars={
"LANGCHAIN_SESSION": "old_timey_session",
"LANGCHAIN_PROJECT": "modern_session",
},
expected_project_name="modern_session",
),
]
def mock_create_run(**kwargs: Any) -> Any:
projects.append(kwargs.get("session_name"))
return unittest.mock.MagicMock()
for case in cases:
get_env_var.cache_clear()
get_tracer_project.cache_clear()
with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp:
for k, v in case.envvars.items():
mp.setenv(k, v)
client.create_run = mock_create_run
client = unittest.mock.MagicMock(spec=Client)
tracer = LangChainTracer(client=client)
projects = []
def mock_create_run(**kwargs: Any) -> Any:
projects.append(kwargs.get("session_name")) # noqa: B023
return unittest.mock.MagicMock()
client.create_run = mock_create_run
tracer.on_llm_start(
{"name": "example_1"},
["foo"],
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
)
tracer.wait_for_futures()
assert projects == [case.expected_project_name]
tracer.on_llm_start(
{"name": "example_1"},
["foo"],
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
)
tracer.wait_for_futures()
assert projects == [expected_project_name]