From 520ff50adca9143c80ee0f04b007871ab746cee3 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 5 Apr 2024 15:17:51 -0400 Subject: [PATCH] community[patch]: Improve import callbacks to make it IDE friendly (#20050) * declares __all__ as a list of strings (instead of dynamically computing it) * import type definitions when TYPE_CHECKING is true --- .../langchain_community/callbacks/__init__.py | 100 +++++++++++++++++- .../unit_tests/callbacks/test_imports.py | 4 +- 2 files changed, 101 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/callbacks/__init__.py b/libs/community/langchain_community/callbacks/__init__.py index f4cc9fc2037..adc0c9750f2 100644 --- a/libs/community/langchain_community/callbacks/__init__.py +++ b/libs/community/langchain_community/callbacks/__init__.py @@ -7,7 +7,78 @@ BaseCallbackHandler --> CallbackHandler # Example: AimCallbackHandler """ import importlib -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from langchain_community.callbacks.aim_callback import ( + AimCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.argilla_callback import ( + ArgillaCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.arize_callback import ( + ArizeCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.arthur_callback import ( + ArthurCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.clearml_callback import ( + ClearMLCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.comet_ml_callback import ( + CometCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.context_callback import ( + ContextCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.fiddler_callback import ( + FiddlerCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.flyte_callback import ( + FlyteCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.human import ( + HumanApprovalCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.infino_callback import ( + InfinoCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.labelstudio_callback import ( + LabelStudioCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.llmonitor_callback import ( + LLMonitorCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.manager import ( # noqa: F401 + get_openai_callback, + wandb_tracing_enabled, + ) + from langchain_community.callbacks.mlflow_callback import ( + MlflowCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.openai_info import ( + OpenAICallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.promptlayer_callback import ( + PromptLayerCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.sagemaker_callback import ( + SageMakerCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.streamlit import ( # noqa: F401 + LLMThoughtLabeler, + StreamlitCallbackHandler, + ) + from langchain_community.callbacks.trubrics_callback import ( + TrubricsCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.wandb_callback import ( + WandbCallbackHandler, # noqa: F401 + ) + from langchain_community.callbacks.whylabs_callback import ( + WhyLabsCallbackHandler, # noqa: F401 + ) + _module_lookup = { "AimCallbackHandler": "langchain_community.callbacks.aim_callback", @@ -44,4 +115,29 @@ def __getattr__(name: str) -> Any: raise AttributeError(f"module {__name__} has no attribute {name}") -__all__ = list(_module_lookup.keys()) +__all__ = [ + "AimCallbackHandler", + "ArgillaCallbackHandler", + "ArizeCallbackHandler", + "ArthurCallbackHandler", + "ClearMLCallbackHandler", + "CometCallbackHandler", + "ContextCallbackHandler", + "FiddlerCallbackHandler", + "FlyteCallbackHandler", + "HumanApprovalCallbackHandler", + "InfinoCallbackHandler", + "LLMThoughtLabeler", + "LLMonitorCallbackHandler", + "LabelStudioCallbackHandler", + "MlflowCallbackHandler", + "OpenAICallbackHandler", + "PromptLayerCallbackHandler", + "SageMakerCallbackHandler", + "StreamlitCallbackHandler", + "TrubricsCallbackHandler", + "WandbCallbackHandler", + "WhyLabsCallbackHandler", + "get_openai_callback", + "wandb_tracing_enabled", +] diff --git a/libs/community/tests/unit_tests/callbacks/test_imports.py b/libs/community/tests/unit_tests/callbacks/test_imports.py index 9b2a11a80a2..648e198c2c9 100644 --- a/libs/community/tests/unit_tests/callbacks/test_imports.py +++ b/libs/community/tests/unit_tests/callbacks/test_imports.py @@ -1,4 +1,4 @@ -from langchain_community.callbacks import __all__ +from langchain_community.callbacks import __all__, _module_lookup EXPECTED_ALL = [ "AimCallbackHandler", @@ -29,4 +29,6 @@ EXPECTED_ALL = [ def test_all_imports() -> None: + """Test that __all__ is correctly set.""" assert set(__all__) == set(EXPECTED_ALL) + assert set(__all__) == set(_module_lookup.keys())