community[patch]: Improve import callbacks to make it IDE friendly (#20050)

* declares __all__ as a list of strings (instead of dynamically
computing it)
* import type definitions when TYPE_CHECKING is true
This commit is contained in:
Eugene Yurtsev 2024-04-05 15:17:51 -04:00 committed by GitHub
parent 5a76087965
commit 520ff50adc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 101 additions and 3 deletions

View File

@ -7,7 +7,78 @@
BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler
"""
import importlib
from typing import Any
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.callbacks.aim_callback import (
AimCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.argilla_callback import (
ArgillaCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.arize_callback import (
ArizeCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.arthur_callback import (
ArthurCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.clearml_callback import (
ClearMLCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.comet_ml_callback import (
CometCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.context_callback import (
ContextCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.fiddler_callback import (
FiddlerCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.flyte_callback import (
FlyteCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.human import (
HumanApprovalCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.infino_callback import (
InfinoCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.labelstudio_callback import (
LabelStudioCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.llmonitor_callback import (
LLMonitorCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.manager import ( # noqa: F401
get_openai_callback,
wandb_tracing_enabled,
)
from langchain_community.callbacks.mlflow_callback import (
MlflowCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.openai_info import (
OpenAICallbackHandler, # noqa: F401
)
from langchain_community.callbacks.promptlayer_callback import (
PromptLayerCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.sagemaker_callback import (
SageMakerCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.streamlit import ( # noqa: F401
LLMThoughtLabeler,
StreamlitCallbackHandler,
)
from langchain_community.callbacks.trubrics_callback import (
TrubricsCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.wandb_callback import (
WandbCallbackHandler, # noqa: F401
)
from langchain_community.callbacks.whylabs_callback import (
WhyLabsCallbackHandler, # noqa: F401
)
_module_lookup = {
"AimCallbackHandler": "langchain_community.callbacks.aim_callback",
@ -44,4 +115,29 @@ def __getattr__(name: str) -> Any:
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())
__all__ = [
"AimCallbackHandler",
"ArgillaCallbackHandler",
"ArizeCallbackHandler",
"ArthurCallbackHandler",
"ClearMLCallbackHandler",
"CometCallbackHandler",
"ContextCallbackHandler",
"FiddlerCallbackHandler",
"FlyteCallbackHandler",
"HumanApprovalCallbackHandler",
"InfinoCallbackHandler",
"LLMThoughtLabeler",
"LLMonitorCallbackHandler",
"LabelStudioCallbackHandler",
"MlflowCallbackHandler",
"OpenAICallbackHandler",
"PromptLayerCallbackHandler",
"SageMakerCallbackHandler",
"StreamlitCallbackHandler",
"TrubricsCallbackHandler",
"WandbCallbackHandler",
"WhyLabsCallbackHandler",
"get_openai_callback",
"wandb_tracing_enabled",
]

View File

@ -1,4 +1,4 @@
from langchain_community.callbacks import __all__
from langchain_community.callbacks import __all__, _module_lookup
EXPECTED_ALL = [
"AimCallbackHandler",
@ -29,4 +29,6 @@ EXPECTED_ALL = [
def test_all_imports() -> None:
"""Test that __all__ is correctly set."""
assert set(__all__) == set(EXPECTED_ALL)
assert set(__all__) == set(_module_lookup.keys())