community[patch]: Add unit test to verify that init is correctly defined (#22030)

Fix some __init__ files and add a unit test
This commit is contained in:
Eugene Yurtsev 2024-05-22 13:19:00 -04:00 committed by GitHub
parent ef53ccf54b
commit 58360a1e53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 102 additions and 16 deletions

View File

@ -120,6 +120,7 @@ if TYPE_CHECKING:
from langchain_community.chat_models.mlx import (
ChatMLX,
)
from langchain_community.chat_models.octoai import ChatOctoAI
from langchain_community.chat_models.ollama import (
ChatOllama,
)
@ -171,6 +172,7 @@ __all__ = [
"ChatBaichuan",
"ChatCohere",
"ChatCoze",
"ChatOctoAI",
"ChatDatabricks",
"ChatDeepInfra",
"ChatEverlyAI",
@ -271,6 +273,3 @@ def __getattr__(name: str) -> Any:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())

View File

@ -29,6 +29,8 @@ if TYPE_CHECKING:
Neo4jGraph,
)
from langchain_community.graphs.neptune_graph import (
BaseNeptuneGraph,
NeptuneAnalyticsGraph,
NeptuneGraph,
)
from langchain_community.graphs.neptune_rdf_graph import (
@ -53,11 +55,13 @@ __all__ = [
"GremlinGraph",
"HugeGraph",
"KuzuGraph",
"BaseNeptuneGraph",
"MemgraphGraph",
"NebulaGraph",
"Neo4jGraph",
"NeptuneGraph",
"NeptuneRdfGraph",
"NeptuneAnalyticsGraph",
"NetworkxEntityGraph",
"OntotextGraphDBGraph",
"RdfGraph",
@ -89,6 +93,3 @@ def __getattr__(name: str) -> Any:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())

View File

@ -90,6 +90,7 @@ if TYPE_CHECKING:
from langchain_community.tools.convert_to_openai import (
format_tool_to_openai_function,
)
from langchain_community.tools.dataherald import DataheraldTextToSQL
from langchain_community.tools.ddg_search.tool import (
DuckDuckGoSearchResults,
DuckDuckGoSearchRun,
@ -356,6 +357,7 @@ __all__ = [
"CopyFileTool",
"CurrentWebPageTool",
"DeleteFileTool",
"DataheraldTextToSQL",
"DuckDuckGoSearchResults",
"DuckDuckGoSearchRun",
"E2BDataAnalysisTool",
@ -610,6 +612,3 @@ def __getattr__(name: str) -> Any:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())

View File

@ -35,6 +35,7 @@ if TYPE_CHECKING:
from langchain_community.utilities.brave_search import (
BraveSearchWrapper,
)
from langchain_community.utilities.dataherald import DataheraldAPIWrapper
from langchain_community.utilities.dria_index import (
DriaAPIWrapper,
)
@ -124,6 +125,7 @@ if TYPE_CHECKING:
from langchain_community.utilities.python import (
PythonREPL,
)
from langchain_community.utilities.rememberizer import RememberizerAPIWrapper
from langchain_community.utilities.requests import (
Requests,
RequestsWrapper,
@ -182,6 +184,7 @@ __all__ = [
"BibtexparserWrapper",
"BingSearchAPIWrapper",
"BraveSearchWrapper",
"DataheraldAPIWrapper",
"DriaAPIWrapper",
"DuckDuckGoSearchAPIWrapper",
"GoldenQueryAPIWrapper",
@ -213,13 +216,14 @@ __all__ = [
"PowerBIDataset",
"PubMedAPIWrapper",
"PythonREPL",
"RememberizerAPIWrapper",
"Requests",
"RequestsWrapper",
"RivaASR",
"RivaTTS",
"SQLDatabase",
"SceneXplainAPIWrapper",
"SearchApiAPIWrapper",
"SQLDatabase",
"SearxSearchWrapper",
"SerpAPIWrapper",
"SparkSQL",
@ -304,6 +308,3 @@ def __getattr__(name: str) -> Any:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())

View File

@ -2,6 +2,7 @@ import ast
import glob
import importlib
from pathlib import Path
from typing import List, Tuple
import pytest
@ -91,3 +92,82 @@ def test_no_dynamic__all__() -> None:
f"__all__ is not correctly defined in the "
f"following files: {sorted(bad_definitions)}"
)
def _extract_type_checking_imports(code: str) -> List[Tuple[str, str]]:
"""Extract all TYPE CHECKING imports that import from langchain_community."""
imports: List[Tuple[str, str]] = []
tree = ast.parse(code)
class TypeCheckingVisitor(ast.NodeVisitor):
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module:
for alias in node.names:
imports.append((node.module, alias.name))
class GlobalScopeVisitor(ast.NodeVisitor):
def visit_If(self, node: ast.If) -> None:
if (
isinstance(node.test, ast.Name)
and node.test.id == "TYPE_CHECKING"
and isinstance(node.test.ctx, ast.Load)
):
TypeCheckingVisitor().visit(node)
self.generic_visit(node)
GlobalScopeVisitor().visit(tree)
return imports
def test_init_files_properly_defined() -> None:
"""This is part of a set of tests that verify that init files are properly
defined if they're using dynamic imports.
"""
# Please never ever add more modules to this list.
# Do feel free to fix the underlying issues and remove exceptions
# from the list.
excepted_modules = {"llms"} # NEVER ADD MORE MODULES TO THIS LIST
for path in glob.glob(ALL_COMMUNITY_GLOB):
# Relative to community root
relative_path = Path(path).relative_to(COMMUNITY_ROOT)
str_path = str(relative_path)
if not str_path.endswith("__init__.py"):
continue
module_name = str(relative_path.parent).replace("/", ".")
if module_name in excepted_modules:
continue
code = Path(path).read_text()
# Check for dynamic __getattr__ definition in the __init__ file
if "__getattr__" not in code:
continue
try:
module = importlib.import_module("langchain_community." + module_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Could not import `{module_name}`. Defined in path: {path}"
) from e
if not hasattr(module, "__all__"):
raise AssertionError(
f"__all__ not defined in {module_name}. This is required "
f"if __getattr__ is defined."
)
imports = _extract_type_checking_imports(code)
# Get the names of all the TYPE CHECKING imports
names = [name for _, name in imports]
missing_imports = set(module.__all__) - set(names)
assert (
not missing_imports
), f"Missing imports: {missing_imports} in file path: {path}"

View File

@ -23,7 +23,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
]
},
)
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
wrapper = RememberizerAPIWrapper(
rememberizer_api_key="dummy_key", top_k_results=10
)
result = wrapper.search("test")
self.assertEqual(
result,
@ -44,7 +46,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
status=400,
json={"detail": "Incorrect authentication credentials."},
)
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
wrapper = RememberizerAPIWrapper(
rememberizer_api_key="dummy_key", top_k_results=10
)
with self.assertRaises(ValueError) as e:
wrapper.search("test")
self.assertEqual(
@ -66,7 +70,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
"document": {"id": "id2", "name": "name2"},
},
]
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
wrapper = RememberizerAPIWrapper(
rememberizer_api_key="dummy_key", top_k_results=10
)
result = wrapper.load("test")
self.assertEqual(len(result), 2)
self.assertEqual(result[0].page_content, "content1")