mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 13:00:34 +00:00
core: Use parametrized test in test_correct_get_tracer_project (#31513)
This commit is contained in:
parent
8a0782c46c
commit
c7e82ad95d
@ -94,44 +94,38 @@ def test_log_lock() -> None:
|
||||
tracer.wait_for_futures()
|
||||
|
||||
|
||||
class LangChainProjectNameTest(unittest.TestCase):
|
||||
"""Test that the project name is set correctly for runs."""
|
||||
|
||||
class SetProperTracerProjectTestCase:
|
||||
def __init__(
|
||||
self, test_name: str, envvars: dict[str, str], expected_project_name: str
|
||||
):
|
||||
self.test_name = test_name
|
||||
self.envvars = envvars
|
||||
self.expected_project_name = expected_project_name
|
||||
|
||||
def test_correct_get_tracer_project(self) -> None:
|
||||
cases = [
|
||||
self.SetProperTracerProjectTestCase(
|
||||
test_name="default to 'default' when no project provided",
|
||||
envvars={},
|
||||
expected_project_name="default",
|
||||
@pytest.mark.parametrize(
|
||||
("envvars", "expected_project_name"),
|
||||
[
|
||||
(
|
||||
{},
|
||||
"default",
|
||||
),
|
||||
self.SetProperTracerProjectTestCase(
|
||||
test_name="use session_name for legacy tracers",
|
||||
envvars={"LANGCHAIN_SESSION": "old_timey_session"},
|
||||
expected_project_name="old_timey_session",
|
||||
(
|
||||
{"LANGCHAIN_SESSION": "old_timey_session"},
|
||||
"old_timey_session",
|
||||
),
|
||||
self.SetProperTracerProjectTestCase(
|
||||
test_name="use LANGCHAIN_PROJECT over SESSION_NAME",
|
||||
envvars={
|
||||
(
|
||||
{
|
||||
"LANGCHAIN_SESSION": "old_timey_session",
|
||||
"LANGCHAIN_PROJECT": "modern_session",
|
||||
},
|
||||
expected_project_name="modern_session",
|
||||
"modern_session",
|
||||
),
|
||||
]
|
||||
|
||||
for case in cases:
|
||||
],
|
||||
ids=[
|
||||
"default to 'default' when no project provided",
|
||||
"use session_name for legacy tracers",
|
||||
"use LANGCHAIN_PROJECT over SESSION_NAME",
|
||||
],
|
||||
)
|
||||
def test_correct_get_tracer_project(
|
||||
envvars: dict[str, str], expected_project_name: str
|
||||
) -> None:
|
||||
get_env_var.cache_clear()
|
||||
get_tracer_project.cache_clear()
|
||||
with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp:
|
||||
for k, v in case.envvars.items():
|
||||
with pytest.MonkeyPatch.context() as mp:
|
||||
for k, v in envvars.items():
|
||||
mp.setenv(k, v)
|
||||
|
||||
client = unittest.mock.MagicMock(spec=Client)
|
||||
@ -139,7 +133,7 @@ class LangChainProjectNameTest(unittest.TestCase):
|
||||
projects = []
|
||||
|
||||
def mock_create_run(**kwargs: Any) -> Any:
|
||||
projects.append(kwargs.get("session_name")) # noqa: B023
|
||||
projects.append(kwargs.get("session_name"))
|
||||
return unittest.mock.MagicMock()
|
||||
|
||||
client.create_run = mock_create_run
|
||||
@ -150,4 +144,4 @@ class LangChainProjectNameTest(unittest.TestCase):
|
||||
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
|
||||
)
|
||||
tracer.wait_for_futures()
|
||||
assert projects == [case.expected_project_name]
|
||||
assert projects == [expected_project_name]
|
||||
|
Loading…
Reference in New Issue
Block a user