core[patch]: utils.guard_import fix (#21133)

Issues (nit): 
1. `utils.guard_import` prints wrong error message when there is an
import `error.` It prints the whole `module_name` but should be only the
first part as the pip package name. E.i. `langchain_core.utils` -> print
not `langchain-core` but `langchain_core.utils`. Also replace '_' with
'-' in the pip package name.
2. it does not handle the `ModuleNotFoundError` which raised if
`guard_import("wrong_module")`

Fixed issues; added ut-s. Controversial: I've reraised
`ModuleNotFoundError` as `ImportError`, since in case of the error, the
proposed action is the same - we need to install a missed package.
This commit is contained in:
Leonid Ganeline 2024-05-03 14:21:36 -07:00 committed by GitHub
parent 36c2ca3c8b
commit 3ef8b24277
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 4 deletions

View File

@ -1,4 +1,5 @@
"""Generic utility functions."""
import contextlib
import datetime
import functools
@ -88,10 +89,11 @@ def guard_import(
installed."""
try:
module = importlib.import_module(module_name, package)
except ImportError:
except (ImportError, ModuleNotFoundError):
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
raise ImportError(
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name or module_name}`."
f"Please install it with `pip install {pip_name}`."
)
return module

View File

@ -1,11 +1,12 @@
import re
from contextlib import AbstractContextManager, nullcontext
from typing import Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Tuple, Type, Union
from unittest.mock import patch
import pytest
from langchain_core.utils import check_package_version
from langchain_core import utils
from langchain_core.utils import check_package_version, guard_import
from langchain_core.utils._merge import merge_dicts
@ -113,3 +114,60 @@ def test_merge_dicts(
with err:
actual = merge_dicts(left, right)
assert actual == expected
@pytest.mark.parametrize(
("module_name", "pip_name", "package", "expected"),
[
("langchain_core.utils", None, None, utils),
("langchain_core.utils", "langchain-core", None, utils),
("langchain_core.utils", None, "langchain-core", utils),
("langchain_core.utils", "langchain-core", "langchain-core", utils),
],
)
def test_guard_import(
module_name: str, pip_name: Optional[str], package: Optional[str], expected: Any
) -> None:
if package is None and pip_name is None:
ret = guard_import(module_name)
elif package is None and pip_name is not None:
ret = guard_import(module_name, pip_name=pip_name)
elif package is not None and pip_name is None:
ret = guard_import(module_name, package=package)
elif package is not None and pip_name is not None:
ret = guard_import(module_name, pip_name=pip_name, package=package)
else:
raise ValueError("Invalid test case")
assert ret == expected
@pytest.mark.parametrize(
("module_name", "pip_name", "package"),
[
("langchain_core.utilsW", None, None),
("langchain_core.utilsW", "langchain-core-2", None),
("langchain_core.utilsW", None, "langchain-coreWX"),
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"),
("langchain_coreW", None, None), # ModuleNotFoundError
],
)
def test_guard_import_failure(
module_name: str, pip_name: Optional[str], package: Optional[str]
) -> None:
with pytest.raises(ImportError) as exc_info:
if package is None and pip_name is None:
guard_import(module_name)
elif package is None and pip_name is not None:
guard_import(module_name, pip_name=pip_name)
elif package is not None and pip_name is None:
guard_import(module_name, package=package)
elif package is not None and pip_name is not None:
guard_import(module_name, pip_name=pip_name, package=package)
else:
raise ValueError("Invalid test case")
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
err_msg = (
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name}`."
)
assert exc_info.value.msg == err_msg