From 2914abd747caa4817cbeaab80ab7c581649a967f Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 1 May 2024 15:56:11 -0400 Subject: [PATCH] langchain[patch]: Fix how the serializable test identifies serializable objects (#21165) dir() will not work if we're using optional imports. The only way to do this is by using contents of __all__ --- .../unit_tests/load/test_serializable.py | 83 ++++++++++++++----- 1 file changed, 60 insertions(+), 23 deletions(-) diff --git a/libs/langchain/tests/unit_tests/load/test_serializable.py b/libs/langchain/tests/unit_tests/load/test_serializable.py index a8a35a4fe7c..ea21e709acc 100644 --- a/libs/langchain/tests/unit_tests/load/test_serializable.py +++ b/libs/langchain/tests/unit_tests/load/test_serializable.py @@ -1,5 +1,7 @@ import importlib +import inspect import pkgutil +from types import ModuleType from langchain_core.load.mapping import SERIALIZABLE_MAPPING @@ -8,37 +10,71 @@ def import_all_modules(package_name: str) -> dict: package = importlib.import_module(package_name) classes: dict = {} - for attribute_name in dir(package): - attribute = getattr(package, attribute_name) - if hasattr(attribute, "is_lc_serializable") and isinstance(attribute, type): + def _handle_module(module: ModuleType) -> None: + # Iterate over all members of the module + + names = dir(module) + + if hasattr(module, "__all__"): + names += list(module.__all__) + + names = sorted(set(names)) + + for name in names: + # Check if it's a class or function + attr = getattr(module, name) + + if not inspect.isclass(attr): + continue + + if not hasattr(attr, "is_lc_serializable") or not isinstance(attr, type): + continue + if ( - isinstance(attribute.is_lc_serializable(), bool) # type: ignore - and attribute.is_lc_serializable() # type: ignore + isinstance(attr.is_lc_serializable(), bool) # type: ignore + and attr.is_lc_serializable() # type: ignore ): - key = tuple(attribute.lc_id()) # type: ignore - value = tuple(attribute.__module__.split(".") + [attribute.__name__]) + key = tuple(attr.lc_id()) # type: ignore + value = tuple(attr.__module__.split(".") + [attr.__name__]) if key in classes and classes[key] != value: raise ValueError classes[key] = value - if hasattr(package, "__path__"): - for loader, module_name, is_pkg in pkgutil.walk_packages( - package.__path__, package_name + "." - ): - if module_name not in ( - "langchain.chains.llm_bash", - "langchain.chains.llm_symbolic_math", - "langchain_community.tools.python", - "langchain_community.vectorstores._pgvector_data_models", - ): - importlib.import_module(module_name) - new_classes = import_all_modules(module_name) - for k, v in new_classes.items(): - if k in classes and classes[k] != v: - raise ValueError - classes[k] = v + + _handle_module(package) + + for importer, modname, ispkg in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + try: + module = importlib.import_module(modname) + except ModuleNotFoundError: + continue + _handle_module(module) + return classes +def test_import_all_modules() -> None: + """Test import all modules works as expected""" + all_modules = import_all_modules("langchain") + filtered_modules = [ + k + for k in all_modules + if len(k) == 4 and tuple(k[:2]) == ("langchain", "chat_models") + ] + # This test will need to be updated if new serializable classes are added + # to community + assert filtered_modules == [ + ("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"), + ("langchain", "chat_models", "bedrock", "BedrockChat"), + ("langchain", "chat_models", "anthropic", "ChatAnthropic"), + ("langchain", "chat_models", "fireworks", "ChatFireworks"), + ("langchain", "chat_models", "google_palm", "ChatGooglePalm"), + ("langchain", "chat_models", "openai", "ChatOpenAI"), + ("langchain", "chat_models", "vertexai", "ChatVertexAI"), + ] + + def test_serializable_mapping() -> None: to_skip = { # This should have had a different namespace, as it was never @@ -59,6 +95,7 @@ def test_serializable_mapping() -> None: ), } serializable_modules = import_all_modules("langchain") + missing = set(SERIALIZABLE_MAPPING).difference( set(serializable_modules).union(to_skip) )