From a208abe6b7d9ca262e94bb7ebff059471bd39634 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:28:49 -0800 Subject: [PATCH] add callback import test (#13689) --- .../tests/unit_tests/callbacks/test_base.py | 18 ++++++++++ .../unit_tests/callbacks/test_manager.py | 34 +++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 libs/langchain/tests/unit_tests/callbacks/test_base.py create mode 100644 libs/langchain/tests/unit_tests/callbacks/test_manager.py diff --git a/libs/langchain/tests/unit_tests/callbacks/test_base.py b/libs/langchain/tests/unit_tests/callbacks/test_base.py new file mode 100644 index 00000000000..62760e032df --- /dev/null +++ b/libs/langchain/tests/unit_tests/callbacks/test_base.py @@ -0,0 +1,18 @@ +from langchain.callbacks.base import __all__ + +EXPECTED_ALL = [ + "RetrieverManagerMixin", + "LLMManagerMixin", + "ChainManagerMixin", + "ToolManagerMixin", + "CallbackManagerMixin", + "RunManagerMixin", + "BaseCallbackHandler", + "AsyncCallbackHandler", + "BaseCallbackManager", + "Callbacks", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/callbacks/test_manager.py b/libs/langchain/tests/unit_tests/callbacks/test_manager.py new file mode 100644 index 00000000000..3ddfb4f0c7b --- /dev/null +++ b/libs/langchain/tests/unit_tests/callbacks/test_manager.py @@ -0,0 +1,34 @@ +from langchain.callbacks.manager import __all__ + +EXPECTED_ALL = [ + "BaseRunManager", + "RunManager", + "ParentRunManager", + "AsyncRunManager", + "AsyncParentRunManager", + "CallbackManagerForLLMRun", + "AsyncCallbackManagerForLLMRun", + "CallbackManagerForChainRun", + "AsyncCallbackManagerForChainRun", + "CallbackManagerForToolRun", + "AsyncCallbackManagerForToolRun", + "CallbackManagerForRetrieverRun", + "AsyncCallbackManagerForRetrieverRun", + "CallbackManager", + "CallbackManagerForChainGroup", + "AsyncCallbackManager", + "AsyncCallbackManagerForChainGroup", + "tracing_enabled", + "tracing_v2_enabled", + "collect_runs", + "atrace_as_chain_group", + "trace_as_chain_group", + "handle_event", + "ahandle_event", + "env_var_is_set", + "Callbacks", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL)