mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 17:13:22 +00:00
community[patch]: Add unit test to verify that init is correctly defined (#22030)
Fix some __init__ files and add a unit test
This commit is contained in:
parent
ef53ccf54b
commit
58360a1e53
@ -120,6 +120,7 @@ if TYPE_CHECKING:
|
||||
from langchain_community.chat_models.mlx import (
|
||||
ChatMLX,
|
||||
)
|
||||
from langchain_community.chat_models.octoai import ChatOctoAI
|
||||
from langchain_community.chat_models.ollama import (
|
||||
ChatOllama,
|
||||
)
|
||||
@ -171,6 +172,7 @@ __all__ = [
|
||||
"ChatBaichuan",
|
||||
"ChatCohere",
|
||||
"ChatCoze",
|
||||
"ChatOctoAI",
|
||||
"ChatDatabricks",
|
||||
"ChatDeepInfra",
|
||||
"ChatEverlyAI",
|
||||
@ -271,6 +273,3 @@ def __getattr__(name: str) -> Any:
|
||||
module = importlib.import_module(_module_lookup[name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
||||
|
||||
__all__ = list(_module_lookup.keys())
|
||||
|
@ -29,6 +29,8 @@ if TYPE_CHECKING:
|
||||
Neo4jGraph,
|
||||
)
|
||||
from langchain_community.graphs.neptune_graph import (
|
||||
BaseNeptuneGraph,
|
||||
NeptuneAnalyticsGraph,
|
||||
NeptuneGraph,
|
||||
)
|
||||
from langchain_community.graphs.neptune_rdf_graph import (
|
||||
@ -53,11 +55,13 @@ __all__ = [
|
||||
"GremlinGraph",
|
||||
"HugeGraph",
|
||||
"KuzuGraph",
|
||||
"BaseNeptuneGraph",
|
||||
"MemgraphGraph",
|
||||
"NebulaGraph",
|
||||
"Neo4jGraph",
|
||||
"NeptuneGraph",
|
||||
"NeptuneRdfGraph",
|
||||
"NeptuneAnalyticsGraph",
|
||||
"NetworkxEntityGraph",
|
||||
"OntotextGraphDBGraph",
|
||||
"RdfGraph",
|
||||
@ -89,6 +93,3 @@ def __getattr__(name: str) -> Any:
|
||||
module = importlib.import_module(_module_lookup[name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
||||
|
||||
__all__ = list(_module_lookup.keys())
|
||||
|
@ -90,6 +90,7 @@ if TYPE_CHECKING:
|
||||
from langchain_community.tools.convert_to_openai import (
|
||||
format_tool_to_openai_function,
|
||||
)
|
||||
from langchain_community.tools.dataherald import DataheraldTextToSQL
|
||||
from langchain_community.tools.ddg_search.tool import (
|
||||
DuckDuckGoSearchResults,
|
||||
DuckDuckGoSearchRun,
|
||||
@ -356,6 +357,7 @@ __all__ = [
|
||||
"CopyFileTool",
|
||||
"CurrentWebPageTool",
|
||||
"DeleteFileTool",
|
||||
"DataheraldTextToSQL",
|
||||
"DuckDuckGoSearchResults",
|
||||
"DuckDuckGoSearchRun",
|
||||
"E2BDataAnalysisTool",
|
||||
@ -610,6 +612,3 @@ def __getattr__(name: str) -> Any:
|
||||
module = importlib.import_module(_module_lookup[name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
||||
|
||||
__all__ = list(_module_lookup.keys())
|
||||
|
@ -35,6 +35,7 @@ if TYPE_CHECKING:
|
||||
from langchain_community.utilities.brave_search import (
|
||||
BraveSearchWrapper,
|
||||
)
|
||||
from langchain_community.utilities.dataherald import DataheraldAPIWrapper
|
||||
from langchain_community.utilities.dria_index import (
|
||||
DriaAPIWrapper,
|
||||
)
|
||||
@ -124,6 +125,7 @@ if TYPE_CHECKING:
|
||||
from langchain_community.utilities.python import (
|
||||
PythonREPL,
|
||||
)
|
||||
from langchain_community.utilities.rememberizer import RememberizerAPIWrapper
|
||||
from langchain_community.utilities.requests import (
|
||||
Requests,
|
||||
RequestsWrapper,
|
||||
@ -182,6 +184,7 @@ __all__ = [
|
||||
"BibtexparserWrapper",
|
||||
"BingSearchAPIWrapper",
|
||||
"BraveSearchWrapper",
|
||||
"DataheraldAPIWrapper",
|
||||
"DriaAPIWrapper",
|
||||
"DuckDuckGoSearchAPIWrapper",
|
||||
"GoldenQueryAPIWrapper",
|
||||
@ -213,13 +216,14 @@ __all__ = [
|
||||
"PowerBIDataset",
|
||||
"PubMedAPIWrapper",
|
||||
"PythonREPL",
|
||||
"RememberizerAPIWrapper",
|
||||
"Requests",
|
||||
"RequestsWrapper",
|
||||
"RivaASR",
|
||||
"RivaTTS",
|
||||
"SQLDatabase",
|
||||
"SceneXplainAPIWrapper",
|
||||
"SearchApiAPIWrapper",
|
||||
"SQLDatabase",
|
||||
"SearxSearchWrapper",
|
||||
"SerpAPIWrapper",
|
||||
"SparkSQL",
|
||||
@ -304,6 +308,3 @@ def __getattr__(name: str) -> Any:
|
||||
module = importlib.import_module(_module_lookup[name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
||||
|
||||
__all__ = list(_module_lookup.keys())
|
||||
|
@ -2,6 +2,7 @@ import ast
|
||||
import glob
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
@ -91,3 +92,82 @@ def test_no_dynamic__all__() -> None:
|
||||
f"__all__ is not correctly defined in the "
|
||||
f"following files: {sorted(bad_definitions)}"
|
||||
)
|
||||
|
||||
|
||||
def _extract_type_checking_imports(code: str) -> List[Tuple[str, str]]:
|
||||
"""Extract all TYPE CHECKING imports that import from langchain_community."""
|
||||
imports: List[Tuple[str, str]] = []
|
||||
|
||||
tree = ast.parse(code)
|
||||
|
||||
class TypeCheckingVisitor(ast.NodeVisitor):
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
if node.module:
|
||||
for alias in node.names:
|
||||
imports.append((node.module, alias.name))
|
||||
|
||||
class GlobalScopeVisitor(ast.NodeVisitor):
|
||||
def visit_If(self, node: ast.If) -> None:
|
||||
if (
|
||||
isinstance(node.test, ast.Name)
|
||||
and node.test.id == "TYPE_CHECKING"
|
||||
and isinstance(node.test.ctx, ast.Load)
|
||||
):
|
||||
TypeCheckingVisitor().visit(node)
|
||||
self.generic_visit(node)
|
||||
|
||||
GlobalScopeVisitor().visit(tree)
|
||||
return imports
|
||||
|
||||
|
||||
def test_init_files_properly_defined() -> None:
|
||||
"""This is part of a set of tests that verify that init files are properly
|
||||
|
||||
defined if they're using dynamic imports.
|
||||
"""
|
||||
# Please never ever add more modules to this list.
|
||||
# Do feel free to fix the underlying issues and remove exceptions
|
||||
# from the list.
|
||||
excepted_modules = {"llms"} # NEVER ADD MORE MODULES TO THIS LIST
|
||||
for path in glob.glob(ALL_COMMUNITY_GLOB):
|
||||
# Relative to community root
|
||||
relative_path = Path(path).relative_to(COMMUNITY_ROOT)
|
||||
str_path = str(relative_path)
|
||||
|
||||
if not str_path.endswith("__init__.py"):
|
||||
continue
|
||||
|
||||
module_name = str(relative_path.parent).replace("/", ".")
|
||||
|
||||
if module_name in excepted_modules:
|
||||
continue
|
||||
|
||||
code = Path(path).read_text()
|
||||
|
||||
# Check for dynamic __getattr__ definition in the __init__ file
|
||||
if "__getattr__" not in code:
|
||||
continue
|
||||
|
||||
try:
|
||||
module = importlib.import_module("langchain_community." + module_name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Could not import `{module_name}`. Defined in path: {path}"
|
||||
) from e
|
||||
|
||||
if not hasattr(module, "__all__"):
|
||||
raise AssertionError(
|
||||
f"__all__ not defined in {module_name}. This is required "
|
||||
f"if __getattr__ is defined."
|
||||
)
|
||||
|
||||
imports = _extract_type_checking_imports(code)
|
||||
|
||||
# Get the names of all the TYPE CHECKING imports
|
||||
names = [name for _, name in imports]
|
||||
|
||||
missing_imports = set(module.__all__) - set(names)
|
||||
|
||||
assert (
|
||||
not missing_imports
|
||||
), f"Missing imports: {missing_imports} in file path: {path}"
|
||||
|
@ -23,7 +23,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
|
||||
]
|
||||
},
|
||||
)
|
||||
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
|
||||
wrapper = RememberizerAPIWrapper(
|
||||
rememberizer_api_key="dummy_key", top_k_results=10
|
||||
)
|
||||
result = wrapper.search("test")
|
||||
self.assertEqual(
|
||||
result,
|
||||
@ -44,7 +46,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
|
||||
status=400,
|
||||
json={"detail": "Incorrect authentication credentials."},
|
||||
)
|
||||
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
|
||||
wrapper = RememberizerAPIWrapper(
|
||||
rememberizer_api_key="dummy_key", top_k_results=10
|
||||
)
|
||||
with self.assertRaises(ValueError) as e:
|
||||
wrapper.search("test")
|
||||
self.assertEqual(
|
||||
@ -66,7 +70,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
|
||||
"document": {"id": "id2", "name": "name2"},
|
||||
},
|
||||
]
|
||||
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
|
||||
wrapper = RememberizerAPIWrapper(
|
||||
rememberizer_api_key="dummy_key", top_k_results=10
|
||||
)
|
||||
result = wrapper.load("test")
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0].page_content, "content1")
|
||||
|
Loading…
Reference in New Issue
Block a user