mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
core[patch]: utils.guard_import
fix (#21133)
Issues (nit): 1. `utils.guard_import` prints wrong error message when there is an import `error.` It prints the whole `module_name` but should be only the first part as the pip package name. E.i. `langchain_core.utils` -> print not `langchain-core` but `langchain_core.utils`. Also replace '_' with '-' in the pip package name. 2. it does not handle the `ModuleNotFoundError` which raised if `guard_import("wrong_module")` Fixed issues; added ut-s. Controversial: I've reraised `ModuleNotFoundError` as `ImportError`, since in case of the error, the proposed action is the same - we need to install a missed package.
This commit is contained in:
parent
36c2ca3c8b
commit
3ef8b24277
@ -1,4 +1,5 @@
|
||||
"""Generic utility functions."""
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import functools
|
||||
@ -88,10 +89,11 @@ def guard_import(
|
||||
installed."""
|
||||
try:
|
||||
module = importlib.import_module(module_name, package)
|
||||
except ImportError:
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||
raise ImportError(
|
||||
f"Could not import {module_name} python package. "
|
||||
f"Please install it with `pip install {pip_name or module_name}`."
|
||||
f"Please install it with `pip install {pip_name}`."
|
||||
)
|
||||
return module
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
import re
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import Dict, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.utils import check_package_version
|
||||
from langchain_core import utils
|
||||
from langchain_core.utils import check_package_version, guard_import
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
|
||||
@ -113,3 +114,60 @@ def test_merge_dicts(
|
||||
with err:
|
||||
actual = merge_dicts(left, right)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("module_name", "pip_name", "package", "expected"),
|
||||
[
|
||||
("langchain_core.utils", None, None, utils),
|
||||
("langchain_core.utils", "langchain-core", None, utils),
|
||||
("langchain_core.utils", None, "langchain-core", utils),
|
||||
("langchain_core.utils", "langchain-core", "langchain-core", utils),
|
||||
],
|
||||
)
|
||||
def test_guard_import(
|
||||
module_name: str, pip_name: Optional[str], package: Optional[str], expected: Any
|
||||
) -> None:
|
||||
if package is None and pip_name is None:
|
||||
ret = guard_import(module_name)
|
||||
elif package is None and pip_name is not None:
|
||||
ret = guard_import(module_name, pip_name=pip_name)
|
||||
elif package is not None and pip_name is None:
|
||||
ret = guard_import(module_name, package=package)
|
||||
elif package is not None and pip_name is not None:
|
||||
ret = guard_import(module_name, pip_name=pip_name, package=package)
|
||||
else:
|
||||
raise ValueError("Invalid test case")
|
||||
assert ret == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("module_name", "pip_name", "package"),
|
||||
[
|
||||
("langchain_core.utilsW", None, None),
|
||||
("langchain_core.utilsW", "langchain-core-2", None),
|
||||
("langchain_core.utilsW", None, "langchain-coreWX"),
|
||||
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"),
|
||||
("langchain_coreW", None, None), # ModuleNotFoundError
|
||||
],
|
||||
)
|
||||
def test_guard_import_failure(
|
||||
module_name: str, pip_name: Optional[str], package: Optional[str]
|
||||
) -> None:
|
||||
with pytest.raises(ImportError) as exc_info:
|
||||
if package is None and pip_name is None:
|
||||
guard_import(module_name)
|
||||
elif package is None and pip_name is not None:
|
||||
guard_import(module_name, pip_name=pip_name)
|
||||
elif package is not None and pip_name is None:
|
||||
guard_import(module_name, package=package)
|
||||
elif package is not None and pip_name is not None:
|
||||
guard_import(module_name, pip_name=pip_name, package=package)
|
||||
else:
|
||||
raise ValueError("Invalid test case")
|
||||
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||
err_msg = (
|
||||
f"Could not import {module_name} python package. "
|
||||
f"Please install it with `pip install {pip_name}`."
|
||||
)
|
||||
assert exc_info.value.msg == err_msg
|
||||
|
Loading…
Reference in New Issue
Block a user