mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
core[patch]: utils.guard_import
fix (#21133)
Issues (nit): 1. `utils.guard_import` prints wrong error message when there is an import `error.` It prints the whole `module_name` but should be only the first part as the pip package name. E.i. `langchain_core.utils` -> print not `langchain-core` but `langchain_core.utils`. Also replace '_' with '-' in the pip package name. 2. it does not handle the `ModuleNotFoundError` which raised if `guard_import("wrong_module")` Fixed issues; added ut-s. Controversial: I've reraised `ModuleNotFoundError` as `ImportError`, since in case of the error, the proposed action is the same - we need to install a missed package.
This commit is contained in:
parent
36c2ca3c8b
commit
3ef8b24277
@ -1,4 +1,5 @@
|
|||||||
"""Generic utility functions."""
|
"""Generic utility functions."""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import datetime
|
import datetime
|
||||||
import functools
|
import functools
|
||||||
@ -88,10 +89,11 @@ def guard_import(
|
|||||||
installed."""
|
installed."""
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(module_name, package)
|
module = importlib.import_module(module_name, package)
|
||||||
except ImportError:
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"Could not import {module_name} python package. "
|
f"Could not import {module_name} python package. "
|
||||||
f"Please install it with `pip install {pip_name or module_name}`."
|
f"Please install it with `pip install {pip_name}`."
|
||||||
)
|
)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import re
|
import re
|
||||||
from contextlib import AbstractContextManager, nullcontext
|
from contextlib import AbstractContextManager, nullcontext
|
||||||
from typing import Dict, Optional, Tuple, Type, Union
|
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.utils import check_package_version
|
from langchain_core import utils
|
||||||
|
from langchain_core.utils import check_package_version, guard_import
|
||||||
from langchain_core.utils._merge import merge_dicts
|
from langchain_core.utils._merge import merge_dicts
|
||||||
|
|
||||||
|
|
||||||
@ -113,3 +114,60 @@ def test_merge_dicts(
|
|||||||
with err:
|
with err:
|
||||||
actual = merge_dicts(left, right)
|
actual = merge_dicts(left, right)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("module_name", "pip_name", "package", "expected"),
|
||||||
|
[
|
||||||
|
("langchain_core.utils", None, None, utils),
|
||||||
|
("langchain_core.utils", "langchain-core", None, utils),
|
||||||
|
("langchain_core.utils", None, "langchain-core", utils),
|
||||||
|
("langchain_core.utils", "langchain-core", "langchain-core", utils),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_guard_import(
|
||||||
|
module_name: str, pip_name: Optional[str], package: Optional[str], expected: Any
|
||||||
|
) -> None:
|
||||||
|
if package is None and pip_name is None:
|
||||||
|
ret = guard_import(module_name)
|
||||||
|
elif package is None and pip_name is not None:
|
||||||
|
ret = guard_import(module_name, pip_name=pip_name)
|
||||||
|
elif package is not None and pip_name is None:
|
||||||
|
ret = guard_import(module_name, package=package)
|
||||||
|
elif package is not None and pip_name is not None:
|
||||||
|
ret = guard_import(module_name, pip_name=pip_name, package=package)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid test case")
|
||||||
|
assert ret == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("module_name", "pip_name", "package"),
|
||||||
|
[
|
||||||
|
("langchain_core.utilsW", None, None),
|
||||||
|
("langchain_core.utilsW", "langchain-core-2", None),
|
||||||
|
("langchain_core.utilsW", None, "langchain-coreWX"),
|
||||||
|
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"),
|
||||||
|
("langchain_coreW", None, None), # ModuleNotFoundError
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_guard_import_failure(
|
||||||
|
module_name: str, pip_name: Optional[str], package: Optional[str]
|
||||||
|
) -> None:
|
||||||
|
with pytest.raises(ImportError) as exc_info:
|
||||||
|
if package is None and pip_name is None:
|
||||||
|
guard_import(module_name)
|
||||||
|
elif package is None and pip_name is not None:
|
||||||
|
guard_import(module_name, pip_name=pip_name)
|
||||||
|
elif package is not None and pip_name is None:
|
||||||
|
guard_import(module_name, package=package)
|
||||||
|
elif package is not None and pip_name is not None:
|
||||||
|
guard_import(module_name, pip_name=pip_name, package=package)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid test case")
|
||||||
|
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||||
|
err_msg = (
|
||||||
|
f"Could not import {module_name} python package. "
|
||||||
|
f"Please install it with `pip install {pip_name}`."
|
||||||
|
)
|
||||||
|
assert exc_info.value.msg == err_msg
|
||||||
|
Loading…
Reference in New Issue
Block a user