diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index bfd52359962..ba9ee7c108d 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -120,6 +120,7 @@ if TYPE_CHECKING: from langchain_community.chat_models.mlx import ( ChatMLX, ) + from langchain_community.chat_models.octoai import ChatOctoAI from langchain_community.chat_models.ollama import ( ChatOllama, ) @@ -171,6 +172,7 @@ __all__ = [ "ChatBaichuan", "ChatCohere", "ChatCoze", + "ChatOctoAI", "ChatDatabricks", "ChatDeepInfra", "ChatEverlyAI", @@ -271,6 +273,3 @@ def __getattr__(name: str) -> Any: module = importlib.import_module(_module_lookup[name]) return getattr(module, name) raise AttributeError(f"module {__name__} has no attribute {name}") - - -__all__ = list(_module_lookup.keys()) diff --git a/libs/community/langchain_community/graphs/__init__.py b/libs/community/langchain_community/graphs/__init__.py index fd9fec8ef41..37bbf71b040 100644 --- a/libs/community/langchain_community/graphs/__init__.py +++ b/libs/community/langchain_community/graphs/__init__.py @@ -29,6 +29,8 @@ if TYPE_CHECKING: Neo4jGraph, ) from langchain_community.graphs.neptune_graph import ( + BaseNeptuneGraph, + NeptuneAnalyticsGraph, NeptuneGraph, ) from langchain_community.graphs.neptune_rdf_graph import ( @@ -53,11 +55,13 @@ __all__ = [ "GremlinGraph", "HugeGraph", "KuzuGraph", + "BaseNeptuneGraph", "MemgraphGraph", "NebulaGraph", "Neo4jGraph", "NeptuneGraph", "NeptuneRdfGraph", + "NeptuneAnalyticsGraph", "NetworkxEntityGraph", "OntotextGraphDBGraph", "RdfGraph", @@ -89,6 +93,3 @@ def __getattr__(name: str) -> Any: module = importlib.import_module(_module_lookup[name]) return getattr(module, name) raise AttributeError(f"module {__name__} has no attribute {name}") - - -__all__ = list(_module_lookup.keys()) diff --git a/libs/community/langchain_community/tools/__init__.py b/libs/community/langchain_community/tools/__init__.py index f426b8f0ec9..71445fa6cc9 100644 --- a/libs/community/langchain_community/tools/__init__.py +++ b/libs/community/langchain_community/tools/__init__.py @@ -90,6 +90,7 @@ if TYPE_CHECKING: from langchain_community.tools.convert_to_openai import ( format_tool_to_openai_function, ) + from langchain_community.tools.dataherald import DataheraldTextToSQL from langchain_community.tools.ddg_search.tool import ( DuckDuckGoSearchResults, DuckDuckGoSearchRun, @@ -356,6 +357,7 @@ __all__ = [ "CopyFileTool", "CurrentWebPageTool", "DeleteFileTool", + "DataheraldTextToSQL", "DuckDuckGoSearchResults", "DuckDuckGoSearchRun", "E2BDataAnalysisTool", @@ -610,6 +612,3 @@ def __getattr__(name: str) -> Any: module = importlib.import_module(_module_lookup[name]) return getattr(module, name) raise AttributeError(f"module {__name__} has no attribute {name}") - - -__all__ = list(_module_lookup.keys()) diff --git a/libs/community/langchain_community/utilities/__init__.py b/libs/community/langchain_community/utilities/__init__.py index 6a069557b71..5aea1de0623 100644 --- a/libs/community/langchain_community/utilities/__init__.py +++ b/libs/community/langchain_community/utilities/__init__.py @@ -35,6 +35,7 @@ if TYPE_CHECKING: from langchain_community.utilities.brave_search import ( BraveSearchWrapper, ) + from langchain_community.utilities.dataherald import DataheraldAPIWrapper from langchain_community.utilities.dria_index import ( DriaAPIWrapper, ) @@ -124,6 +125,7 @@ if TYPE_CHECKING: from langchain_community.utilities.python import ( PythonREPL, ) + from langchain_community.utilities.rememberizer import RememberizerAPIWrapper from langchain_community.utilities.requests import ( Requests, RequestsWrapper, @@ -182,6 +184,7 @@ __all__ = [ "BibtexparserWrapper", "BingSearchAPIWrapper", "BraveSearchWrapper", + "DataheraldAPIWrapper", "DriaAPIWrapper", "DuckDuckGoSearchAPIWrapper", "GoldenQueryAPIWrapper", @@ -213,13 +216,14 @@ __all__ = [ "PowerBIDataset", "PubMedAPIWrapper", "PythonREPL", + "RememberizerAPIWrapper", "Requests", "RequestsWrapper", "RivaASR", "RivaTTS", - "SQLDatabase", "SceneXplainAPIWrapper", "SearchApiAPIWrapper", + "SQLDatabase", "SearxSearchWrapper", "SerpAPIWrapper", "SparkSQL", @@ -304,6 +308,3 @@ def __getattr__(name: str) -> Any: module = importlib.import_module(_module_lookup[name]) return getattr(module, name) raise AttributeError(f"module {__name__} has no attribute {name}") - - -__all__ = list(_module_lookup.keys()) diff --git a/libs/community/tests/unit_tests/test_imports.py b/libs/community/tests/unit_tests/test_imports.py index 0819c2de498..a32e44d9d21 100644 --- a/libs/community/tests/unit_tests/test_imports.py +++ b/libs/community/tests/unit_tests/test_imports.py @@ -2,6 +2,7 @@ import ast import glob import importlib from pathlib import Path +from typing import List, Tuple import pytest @@ -91,3 +92,82 @@ def test_no_dynamic__all__() -> None: f"__all__ is not correctly defined in the " f"following files: {sorted(bad_definitions)}" ) + + +def _extract_type_checking_imports(code: str) -> List[Tuple[str, str]]: + """Extract all TYPE CHECKING imports that import from langchain_community.""" + imports: List[Tuple[str, str]] = [] + + tree = ast.parse(code) + + class TypeCheckingVisitor(ast.NodeVisitor): + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if node.module: + for alias in node.names: + imports.append((node.module, alias.name)) + + class GlobalScopeVisitor(ast.NodeVisitor): + def visit_If(self, node: ast.If) -> None: + if ( + isinstance(node.test, ast.Name) + and node.test.id == "TYPE_CHECKING" + and isinstance(node.test.ctx, ast.Load) + ): + TypeCheckingVisitor().visit(node) + self.generic_visit(node) + + GlobalScopeVisitor().visit(tree) + return imports + + +def test_init_files_properly_defined() -> None: + """This is part of a set of tests that verify that init files are properly + + defined if they're using dynamic imports. + """ + # Please never ever add more modules to this list. + # Do feel free to fix the underlying issues and remove exceptions + # from the list. + excepted_modules = {"llms"} # NEVER ADD MORE MODULES TO THIS LIST + for path in glob.glob(ALL_COMMUNITY_GLOB): + # Relative to community root + relative_path = Path(path).relative_to(COMMUNITY_ROOT) + str_path = str(relative_path) + + if not str_path.endswith("__init__.py"): + continue + + module_name = str(relative_path.parent).replace("/", ".") + + if module_name in excepted_modules: + continue + + code = Path(path).read_text() + + # Check for dynamic __getattr__ definition in the __init__ file + if "__getattr__" not in code: + continue + + try: + module = importlib.import_module("langchain_community." + module_name) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Could not import `{module_name}`. Defined in path: {path}" + ) from e + + if not hasattr(module, "__all__"): + raise AssertionError( + f"__all__ not defined in {module_name}. This is required " + f"if __getattr__ is defined." + ) + + imports = _extract_type_checking_imports(code) + + # Get the names of all the TYPE CHECKING imports + names = [name for _, name in imports] + + missing_imports = set(module.__all__) - set(names) + + assert ( + not missing_imports + ), f"Missing imports: {missing_imports} in file path: {path}" diff --git a/libs/community/tests/unit_tests/utilities/test_rememberizer.py b/libs/community/tests/unit_tests/utilities/test_rememberizer.py index 3b288a107f9..f6fe63b03f4 100644 --- a/libs/community/tests/unit_tests/utilities/test_rememberizer.py +++ b/libs/community/tests/unit_tests/utilities/test_rememberizer.py @@ -23,7 +23,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase): ] }, ) - wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10) + wrapper = RememberizerAPIWrapper( + rememberizer_api_key="dummy_key", top_k_results=10 + ) result = wrapper.search("test") self.assertEqual( result, @@ -44,7 +46,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase): status=400, json={"detail": "Incorrect authentication credentials."}, ) - wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10) + wrapper = RememberizerAPIWrapper( + rememberizer_api_key="dummy_key", top_k_results=10 + ) with self.assertRaises(ValueError) as e: wrapper.search("test") self.assertEqual( @@ -66,7 +70,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase): "document": {"id": "id2", "name": "name2"}, }, ] - wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10) + wrapper = RememberizerAPIWrapper( + rememberizer_api_key="dummy_key", top_k_results=10 + ) result = wrapper.load("test") self.assertEqual(len(result), 2) self.assertEqual(result[0].page_content, "content1")