community[patch]: Add unit test to verify that init is correctly defined (#22030)

Fix some __init__ files and add a unit test
Eugene Yurtsev 2024-05-22 13:19:00 -04:00 committed by GitHub
parent ef53ccf54b
commit 58360a1e53
6 changed files with 102 additions and 16 deletions
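
For context, every `__init__.py` touched below follows the same lazy-import convention, sketched here in simplified form (the class and module names are taken from the diff; everything else is paraphrased and not the actual file contents): imports under `TYPE_CHECKING` exist only for static type checkers, `_module_lookup` plus a module-level `__getattr__` performs the real import on first attribute access, and `__all__` now has to be a literal list so the new unit test can compare it against the `TYPE_CHECKING` imports.

import importlib
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime.
    from langchain_community.chat_models.octoai import ChatOctoAI

# Maps each exported name to the module that actually defines it.
_module_lookup = {
    "ChatOctoAI": "langchain_community.chat_models.octoai",
}


def __getattr__(name: str) -> Any:
    """Lazily import the requested attribute on first access (PEP 562)."""
    if name in _module_lookup:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")


# Written out as a literal list (rather than list(_module_lookup.keys()))
# so the test suite can statically check it against the TYPE_CHECKING imports.
__all__ = ["ChatOctoAI"]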

View File

@@ -120,6 +120,7 @@ if TYPE_CHECKING:
    from langchain_community.chat_models.mlx import (
        ChatMLX,
    )
    from langchain_community.chat_models.octoai import ChatOctoAI
    from langchain_community.chat_models.ollama import (
        ChatOllama,
    )
@@ -171,6 +172,7 @@ __all__ = [
    "ChatBaichuan",
    "ChatCohere",
    "ChatCoze",
    "ChatOctoAI",
    "ChatDatabricks",
    "ChatDeepInfra",
    "ChatEverlyAI",
@@ -271,6 +273,3 @@ def __getattr__(name: str) -> Any:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())

View File

@@ -29,6 +29,8 @@ if TYPE_CHECKING:
        Neo4jGraph,
    )
    from langchain_community.graphs.neptune_graph import (
        BaseNeptuneGraph,
        NeptuneAnalyticsGraph,
        NeptuneGraph,
    )
    from langchain_community.graphs.neptune_rdf_graph import (
@@ -53,11 +55,13 @@ __all__ = [
    "GremlinGraph",
    "HugeGraph",
    "KuzuGraph",
    "BaseNeptuneGraph",
    "MemgraphGraph",
    "NebulaGraph",
    "Neo4jGraph",
    "NeptuneGraph",
    "NeptuneRdfGraph",
    "NeptuneAnalyticsGraph",
    "NetworkxEntityGraph",
    "OntotextGraphDBGraph",
    "RdfGraph",
@@ -89,6 +93,3 @@ def __getattr__(name: str) -> Any:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())

View File

@@ -90,6 +90,7 @@ if TYPE_CHECKING:
    from langchain_community.tools.convert_to_openai import (
        format_tool_to_openai_function,
    )
    from langchain_community.tools.dataherald import DataheraldTextToSQL
    from langchain_community.tools.ddg_search.tool import (
        DuckDuckGoSearchResults,
        DuckDuckGoSearchRun,
@@ -356,6 +357,7 @@ __all__ = [
    "CopyFileTool",
    "CurrentWebPageTool",
    "DeleteFileTool",
    "DataheraldTextToSQL",
    "DuckDuckGoSearchResults",
    "DuckDuckGoSearchRun",
    "E2BDataAnalysisTool",
@@ -610,6 +612,3 @@ def __getattr__(name: str) -> Any:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())

View File

@@ -35,6 +35,7 @@ if TYPE_CHECKING:
    from langchain_community.utilities.brave_search import (
        BraveSearchWrapper,
    )
    from langchain_community.utilities.dataherald import DataheraldAPIWrapper
    from langchain_community.utilities.dria_index import (
        DriaAPIWrapper,
    )
@@ -124,6 +125,7 @@ if TYPE_CHECKING:
    from langchain_community.utilities.python import (
        PythonREPL,
    )
    from langchain_community.utilities.rememberizer import RememberizerAPIWrapper
    from langchain_community.utilities.requests import (
        Requests,
        RequestsWrapper,
@@ -182,6 +184,7 @@ __all__ = [
    "BibtexparserWrapper",
    "BingSearchAPIWrapper",
    "BraveSearchWrapper",
    "DataheraldAPIWrapper",
    "DriaAPIWrapper",
    "DuckDuckGoSearchAPIWrapper",
    "GoldenQueryAPIWrapper",
@@ -213,13 +216,14 @@ __all__ = [
    "PowerBIDataset",
    "PubMedAPIWrapper",
    "PythonREPL",
    "RememberizerAPIWrapper",
    "Requests",
    "RequestsWrapper",
    "RivaASR",
    "RivaTTS",
    "SQLDatabase",
    "SceneXplainAPIWrapper",
    "SearchApiAPIWrapper",
    "SQLDatabase",
    "SearxSearchWrapper",
    "SerpAPIWrapper",
    "SparkSQL",
@@ -304,6 +308,3 @@ def __getattr__(name: str) -> Any:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())

View File

@@ -2,6 +2,7 @@ import ast
import glob
import importlib
from pathlib import Path
from typing import List, Tuple

import pytest
@@ -91,3 +92,82 @@ def test_no_dynamic__all__() -> None:
            f"__all__ is not correctly defined in the "
            f"following files: {sorted(bad_definitions)}"
        )


def _extract_type_checking_imports(code: str) -> List[Tuple[str, str]]:
    """Extract all TYPE CHECKING imports that import from langchain_community."""
    imports: List[Tuple[str, str]] = []

    tree = ast.parse(code)

    class TypeCheckingVisitor(ast.NodeVisitor):
        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
            if node.module:
                for alias in node.names:
                    imports.append((node.module, alias.name))

    class GlobalScopeVisitor(ast.NodeVisitor):
        def visit_If(self, node: ast.If) -> None:
            if (
                isinstance(node.test, ast.Name)
                and node.test.id == "TYPE_CHECKING"
                and isinstance(node.test.ctx, ast.Load)
            ):
                TypeCheckingVisitor().visit(node)
            self.generic_visit(node)

    GlobalScopeVisitor().visit(tree)
    return imports


def test_init_files_properly_defined() -> None:
    """This is part of a set of tests that verify that init files are properly
    defined if they're using dynamic imports.
    """
    # Please never ever add more modules to this list.
    # Do feel free to fix the underlying issues and remove exceptions
    # from the list.
    excepted_modules = {"llms"}  # NEVER ADD MORE MODULES TO THIS LIST

    for path in glob.glob(ALL_COMMUNITY_GLOB):
        # Relative to community root
        relative_path = Path(path).relative_to(COMMUNITY_ROOT)
        str_path = str(relative_path)
        if not str_path.endswith("__init__.py"):
            continue

        module_name = str(relative_path.parent).replace("/", ".")

        if module_name in excepted_modules:
            continue

        code = Path(path).read_text()

        # Check for dynamic __getattr__ definition in the __init__ file
        if "__getattr__" not in code:
            continue

        try:
            module = importlib.import_module("langchain_community." + module_name)
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"Could not import `{module_name}`. Defined in path: {path}"
            ) from e

        if not hasattr(module, "__all__"):
            raise AssertionError(
                f"__all__ not defined in {module_name}. This is required "
                f"if __getattr__ is defined."
            )

        imports = _extract_type_checking_imports(code)

        # Get the names of all the TYPE CHECKING imports
        names = [name for _, name in imports]

        missing_imports = set(module.__all__) - set(names)

        assert (
            not missing_imports
        ), f"Missing imports: {missing_imports} in file path: {path}"

View File

@@ -23,7 +23,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
                ]
            },
        )
        wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
        wrapper = RememberizerAPIWrapper(
            rememberizer_api_key="dummy_key", top_k_results=10
        )
        result = wrapper.search("test")
        self.assertEqual(
            result,
@@ -44,7 +46,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
            status=400,
            json={"detail": "Incorrect authentication credentials."},
        )
        wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
        wrapper = RememberizerAPIWrapper(
            rememberizer_api_key="dummy_key", top_k_results=10
        )
        with self.assertRaises(ValueError) as e:
            wrapper.search("test")
        self.assertEqual(
@@ -66,7 +70,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
                "document": {"id": "id2", "name": "name2"},
            },
        ]
        wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
        wrapper = RememberizerAPIWrapper(
            rememberizer_api_key="dummy_key", top_k_results=10
        )
        result = wrapper.load("test")
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].page_content, "content1")
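
For anyone constructing this wrapper directly, the change these tests track is just the constructor keyword; based on this diff it is top_k_results rather than n (a usage sketch with a placeholder key and no mocked HTTP layer, so real calls need valid credentials):

from langchain_community.utilities.rememberizer import RememberizerAPIWrapper

# Keyword per this diff: top_k_results (the tests previously passed n=10).
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", top_k_results=10)

# search() returns the raw matched chunks; load() wraps them as Documents,
# as exercised by the assertions in the tests above.
results = wrapper.search("test")
docs = wrapper.load("test")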