mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00
community[patch]: Add unit test to verify that init is correctly defined (#22030)
Fix some __init__ files and add a unit test
This commit is contained in:
parent
ef53ccf54b
commit
58360a1e53
@ -120,6 +120,7 @@ if TYPE_CHECKING:
|
|||||||
from langchain_community.chat_models.mlx import (
|
from langchain_community.chat_models.mlx import (
|
||||||
ChatMLX,
|
ChatMLX,
|
||||||
)
|
)
|
||||||
|
from langchain_community.chat_models.octoai import ChatOctoAI
|
||||||
from langchain_community.chat_models.ollama import (
|
from langchain_community.chat_models.ollama import (
|
||||||
ChatOllama,
|
ChatOllama,
|
||||||
)
|
)
|
||||||
@ -171,6 +172,7 @@ __all__ = [
|
|||||||
"ChatBaichuan",
|
"ChatBaichuan",
|
||||||
"ChatCohere",
|
"ChatCohere",
|
||||||
"ChatCoze",
|
"ChatCoze",
|
||||||
|
"ChatOctoAI",
|
||||||
"ChatDatabricks",
|
"ChatDatabricks",
|
||||||
"ChatDeepInfra",
|
"ChatDeepInfra",
|
||||||
"ChatEverlyAI",
|
"ChatEverlyAI",
|
||||||
@ -271,6 +273,3 @@ def __getattr__(name: str) -> Any:
|
|||||||
module = importlib.import_module(_module_lookup[name])
|
module = importlib.import_module(_module_lookup[name])
|
||||||
return getattr(module, name)
|
return getattr(module, name)
|
||||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = list(_module_lookup.keys())
|
|
||||||
|
@ -29,6 +29,8 @@ if TYPE_CHECKING:
|
|||||||
Neo4jGraph,
|
Neo4jGraph,
|
||||||
)
|
)
|
||||||
from langchain_community.graphs.neptune_graph import (
|
from langchain_community.graphs.neptune_graph import (
|
||||||
|
BaseNeptuneGraph,
|
||||||
|
NeptuneAnalyticsGraph,
|
||||||
NeptuneGraph,
|
NeptuneGraph,
|
||||||
)
|
)
|
||||||
from langchain_community.graphs.neptune_rdf_graph import (
|
from langchain_community.graphs.neptune_rdf_graph import (
|
||||||
@ -53,11 +55,13 @@ __all__ = [
|
|||||||
"GremlinGraph",
|
"GremlinGraph",
|
||||||
"HugeGraph",
|
"HugeGraph",
|
||||||
"KuzuGraph",
|
"KuzuGraph",
|
||||||
|
"BaseNeptuneGraph",
|
||||||
"MemgraphGraph",
|
"MemgraphGraph",
|
||||||
"NebulaGraph",
|
"NebulaGraph",
|
||||||
"Neo4jGraph",
|
"Neo4jGraph",
|
||||||
"NeptuneGraph",
|
"NeptuneGraph",
|
||||||
"NeptuneRdfGraph",
|
"NeptuneRdfGraph",
|
||||||
|
"NeptuneAnalyticsGraph",
|
||||||
"NetworkxEntityGraph",
|
"NetworkxEntityGraph",
|
||||||
"OntotextGraphDBGraph",
|
"OntotextGraphDBGraph",
|
||||||
"RdfGraph",
|
"RdfGraph",
|
||||||
@ -89,6 +93,3 @@ def __getattr__(name: str) -> Any:
|
|||||||
module = importlib.import_module(_module_lookup[name])
|
module = importlib.import_module(_module_lookup[name])
|
||||||
return getattr(module, name)
|
return getattr(module, name)
|
||||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = list(_module_lookup.keys())
|
|
||||||
|
@ -90,6 +90,7 @@ if TYPE_CHECKING:
|
|||||||
from langchain_community.tools.convert_to_openai import (
|
from langchain_community.tools.convert_to_openai import (
|
||||||
format_tool_to_openai_function,
|
format_tool_to_openai_function,
|
||||||
)
|
)
|
||||||
|
from langchain_community.tools.dataherald import DataheraldTextToSQL
|
||||||
from langchain_community.tools.ddg_search.tool import (
|
from langchain_community.tools.ddg_search.tool import (
|
||||||
DuckDuckGoSearchResults,
|
DuckDuckGoSearchResults,
|
||||||
DuckDuckGoSearchRun,
|
DuckDuckGoSearchRun,
|
||||||
@ -356,6 +357,7 @@ __all__ = [
|
|||||||
"CopyFileTool",
|
"CopyFileTool",
|
||||||
"CurrentWebPageTool",
|
"CurrentWebPageTool",
|
||||||
"DeleteFileTool",
|
"DeleteFileTool",
|
||||||
|
"DataheraldTextToSQL",
|
||||||
"DuckDuckGoSearchResults",
|
"DuckDuckGoSearchResults",
|
||||||
"DuckDuckGoSearchRun",
|
"DuckDuckGoSearchRun",
|
||||||
"E2BDataAnalysisTool",
|
"E2BDataAnalysisTool",
|
||||||
@ -610,6 +612,3 @@ def __getattr__(name: str) -> Any:
|
|||||||
module = importlib.import_module(_module_lookup[name])
|
module = importlib.import_module(_module_lookup[name])
|
||||||
return getattr(module, name)
|
return getattr(module, name)
|
||||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = list(_module_lookup.keys())
|
|
||||||
|
@ -35,6 +35,7 @@ if TYPE_CHECKING:
|
|||||||
from langchain_community.utilities.brave_search import (
|
from langchain_community.utilities.brave_search import (
|
||||||
BraveSearchWrapper,
|
BraveSearchWrapper,
|
||||||
)
|
)
|
||||||
|
from langchain_community.utilities.dataherald import DataheraldAPIWrapper
|
||||||
from langchain_community.utilities.dria_index import (
|
from langchain_community.utilities.dria_index import (
|
||||||
DriaAPIWrapper,
|
DriaAPIWrapper,
|
||||||
)
|
)
|
||||||
@ -124,6 +125,7 @@ if TYPE_CHECKING:
|
|||||||
from langchain_community.utilities.python import (
|
from langchain_community.utilities.python import (
|
||||||
PythonREPL,
|
PythonREPL,
|
||||||
)
|
)
|
||||||
|
from langchain_community.utilities.rememberizer import RememberizerAPIWrapper
|
||||||
from langchain_community.utilities.requests import (
|
from langchain_community.utilities.requests import (
|
||||||
Requests,
|
Requests,
|
||||||
RequestsWrapper,
|
RequestsWrapper,
|
||||||
@ -182,6 +184,7 @@ __all__ = [
|
|||||||
"BibtexparserWrapper",
|
"BibtexparserWrapper",
|
||||||
"BingSearchAPIWrapper",
|
"BingSearchAPIWrapper",
|
||||||
"BraveSearchWrapper",
|
"BraveSearchWrapper",
|
||||||
|
"DataheraldAPIWrapper",
|
||||||
"DriaAPIWrapper",
|
"DriaAPIWrapper",
|
||||||
"DuckDuckGoSearchAPIWrapper",
|
"DuckDuckGoSearchAPIWrapper",
|
||||||
"GoldenQueryAPIWrapper",
|
"GoldenQueryAPIWrapper",
|
||||||
@ -213,13 +216,14 @@ __all__ = [
|
|||||||
"PowerBIDataset",
|
"PowerBIDataset",
|
||||||
"PubMedAPIWrapper",
|
"PubMedAPIWrapper",
|
||||||
"PythonREPL",
|
"PythonREPL",
|
||||||
|
"RememberizerAPIWrapper",
|
||||||
"Requests",
|
"Requests",
|
||||||
"RequestsWrapper",
|
"RequestsWrapper",
|
||||||
"RivaASR",
|
"RivaASR",
|
||||||
"RivaTTS",
|
"RivaTTS",
|
||||||
"SQLDatabase",
|
|
||||||
"SceneXplainAPIWrapper",
|
"SceneXplainAPIWrapper",
|
||||||
"SearchApiAPIWrapper",
|
"SearchApiAPIWrapper",
|
||||||
|
"SQLDatabase",
|
||||||
"SearxSearchWrapper",
|
"SearxSearchWrapper",
|
||||||
"SerpAPIWrapper",
|
"SerpAPIWrapper",
|
||||||
"SparkSQL",
|
"SparkSQL",
|
||||||
@ -304,6 +308,3 @@ def __getattr__(name: str) -> Any:
|
|||||||
module = importlib.import_module(_module_lookup[name])
|
module = importlib.import_module(_module_lookup[name])
|
||||||
return getattr(module, name)
|
return getattr(module, name)
|
||||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = list(_module_lookup.keys())
|
|
||||||
|
@ -2,6 +2,7 @@ import ast
|
|||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -91,3 +92,82 @@ def test_no_dynamic__all__() -> None:
|
|||||||
f"__all__ is not correctly defined in the "
|
f"__all__ is not correctly defined in the "
|
||||||
f"following files: {sorted(bad_definitions)}"
|
f"following files: {sorted(bad_definitions)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_type_checking_imports(code: str) -> List[Tuple[str, str]]:
|
||||||
|
"""Extract all TYPE CHECKING imports that import from langchain_community."""
|
||||||
|
imports: List[Tuple[str, str]] = []
|
||||||
|
|
||||||
|
tree = ast.parse(code)
|
||||||
|
|
||||||
|
class TypeCheckingVisitor(ast.NodeVisitor):
|
||||||
|
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||||
|
if node.module:
|
||||||
|
for alias in node.names:
|
||||||
|
imports.append((node.module, alias.name))
|
||||||
|
|
||||||
|
class GlobalScopeVisitor(ast.NodeVisitor):
|
||||||
|
def visit_If(self, node: ast.If) -> None:
|
||||||
|
if (
|
||||||
|
isinstance(node.test, ast.Name)
|
||||||
|
and node.test.id == "TYPE_CHECKING"
|
||||||
|
and isinstance(node.test.ctx, ast.Load)
|
||||||
|
):
|
||||||
|
TypeCheckingVisitor().visit(node)
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
GlobalScopeVisitor().visit(tree)
|
||||||
|
return imports
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_files_properly_defined() -> None:
|
||||||
|
"""This is part of a set of tests that verify that init files are properly
|
||||||
|
|
||||||
|
defined if they're using dynamic imports.
|
||||||
|
"""
|
||||||
|
# Please never ever add more modules to this list.
|
||||||
|
# Do feel free to fix the underlying issues and remove exceptions
|
||||||
|
# from the list.
|
||||||
|
excepted_modules = {"llms"} # NEVER ADD MORE MODULES TO THIS LIST
|
||||||
|
for path in glob.glob(ALL_COMMUNITY_GLOB):
|
||||||
|
# Relative to community root
|
||||||
|
relative_path = Path(path).relative_to(COMMUNITY_ROOT)
|
||||||
|
str_path = str(relative_path)
|
||||||
|
|
||||||
|
if not str_path.endswith("__init__.py"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
module_name = str(relative_path.parent).replace("/", ".")
|
||||||
|
|
||||||
|
if module_name in excepted_modules:
|
||||||
|
continue
|
||||||
|
|
||||||
|
code = Path(path).read_text()
|
||||||
|
|
||||||
|
# Check for dynamic __getattr__ definition in the __init__ file
|
||||||
|
if "__getattr__" not in code:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = importlib.import_module("langchain_community." + module_name)
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
f"Could not import `{module_name}`. Defined in path: {path}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if not hasattr(module, "__all__"):
|
||||||
|
raise AssertionError(
|
||||||
|
f"__all__ not defined in {module_name}. This is required "
|
||||||
|
f"if __getattr__ is defined."
|
||||||
|
)
|
||||||
|
|
||||||
|
imports = _extract_type_checking_imports(code)
|
||||||
|
|
||||||
|
# Get the names of all the TYPE CHECKING imports
|
||||||
|
names = [name for _, name in imports]
|
||||||
|
|
||||||
|
missing_imports = set(module.__all__) - set(names)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
not missing_imports
|
||||||
|
), f"Missing imports: {missing_imports} in file path: {path}"
|
||||||
|
@ -23,7 +23,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
|
wrapper = RememberizerAPIWrapper(
|
||||||
|
rememberizer_api_key="dummy_key", top_k_results=10
|
||||||
|
)
|
||||||
result = wrapper.search("test")
|
result = wrapper.search("test")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result,
|
result,
|
||||||
@ -44,7 +46,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
|
|||||||
status=400,
|
status=400,
|
||||||
json={"detail": "Incorrect authentication credentials."},
|
json={"detail": "Incorrect authentication credentials."},
|
||||||
)
|
)
|
||||||
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
|
wrapper = RememberizerAPIWrapper(
|
||||||
|
rememberizer_api_key="dummy_key", top_k_results=10
|
||||||
|
)
|
||||||
with self.assertRaises(ValueError) as e:
|
with self.assertRaises(ValueError) as e:
|
||||||
wrapper.search("test")
|
wrapper.search("test")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -66,7 +70,9 @@ class TestRememberizerAPIWrapper(unittest.TestCase):
|
|||||||
"document": {"id": "id2", "name": "name2"},
|
"document": {"id": "id2", "name": "name2"},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", n=10)
|
wrapper = RememberizerAPIWrapper(
|
||||||
|
rememberizer_api_key="dummy_key", top_k_results=10
|
||||||
|
)
|
||||||
result = wrapper.load("test")
|
result = wrapper.load("test")
|
||||||
self.assertEqual(len(result), 2)
|
self.assertEqual(len(result), 2)
|
||||||
self.assertEqual(result[0].page_content, "content1")
|
self.assertEqual(result[0].page_content, "content1")
|
||||||
|
Loading…
Reference in New Issue
Block a user