From 3ef8b242771268647e889c6a60507fd139a44ad0 Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Fri, 3 May 2024 14:21:36 -0700 Subject: [PATCH] core[patch]: `utils.guard_import` fix (#21133) Issues (nit): 1. `utils.guard_import` prints wrong error message when there is an import `error.` It prints the whole `module_name` but should be only the first part as the pip package name. E.i. `langchain_core.utils` -> print not `langchain-core` but `langchain_core.utils`. Also replace '_' with '-' in the pip package name. 2. it does not handle the `ModuleNotFoundError` which raised if `guard_import("wrong_module")` Fixed issues; added ut-s. Controversial: I've reraised `ModuleNotFoundError` as `ImportError`, since in case of the error, the proposed action is the same - we need to install a missed package. --- libs/core/langchain_core/utils/utils.py | 6 +- .../core/tests/unit_tests/utils/test_utils.py | 62 ++++++++++++++++++- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 0738b1ddb5f..d989d94cfa0 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -1,4 +1,5 @@ """Generic utility functions.""" + import contextlib import datetime import functools @@ -88,10 +89,11 @@ def guard_import( installed.""" try: module = importlib.import_module(module_name, package) - except ImportError: + except (ImportError, ModuleNotFoundError): + pip_name = pip_name or module_name.split(".")[0].replace("_", "-") raise ImportError( f"Could not import {module_name} python package. " - f"Please install it with `pip install {pip_name or module_name}`." + f"Please install it with `pip install {pip_name}`." ) return module diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 9cffdf80196..2524aec5f9c 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -1,11 +1,12 @@ import re from contextlib import AbstractContextManager, nullcontext -from typing import Dict, Optional, Tuple, Type, Union +from typing import Any, Dict, Optional, Tuple, Type, Union from unittest.mock import patch import pytest -from langchain_core.utils import check_package_version +from langchain_core import utils +from langchain_core.utils import check_package_version, guard_import from langchain_core.utils._merge import merge_dicts @@ -113,3 +114,60 @@ def test_merge_dicts( with err: actual = merge_dicts(left, right) assert actual == expected + + +@pytest.mark.parametrize( + ("module_name", "pip_name", "package", "expected"), + [ + ("langchain_core.utils", None, None, utils), + ("langchain_core.utils", "langchain-core", None, utils), + ("langchain_core.utils", None, "langchain-core", utils), + ("langchain_core.utils", "langchain-core", "langchain-core", utils), + ], +) +def test_guard_import( + module_name: str, pip_name: Optional[str], package: Optional[str], expected: Any +) -> None: + if package is None and pip_name is None: + ret = guard_import(module_name) + elif package is None and pip_name is not None: + ret = guard_import(module_name, pip_name=pip_name) + elif package is not None and pip_name is None: + ret = guard_import(module_name, package=package) + elif package is not None and pip_name is not None: + ret = guard_import(module_name, pip_name=pip_name, package=package) + else: + raise ValueError("Invalid test case") + assert ret == expected + + +@pytest.mark.parametrize( + ("module_name", "pip_name", "package"), + [ + ("langchain_core.utilsW", None, None), + ("langchain_core.utilsW", "langchain-core-2", None), + ("langchain_core.utilsW", None, "langchain-coreWX"), + ("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"), + ("langchain_coreW", None, None), # ModuleNotFoundError + ], +) +def test_guard_import_failure( + module_name: str, pip_name: Optional[str], package: Optional[str] +) -> None: + with pytest.raises(ImportError) as exc_info: + if package is None and pip_name is None: + guard_import(module_name) + elif package is None and pip_name is not None: + guard_import(module_name, pip_name=pip_name) + elif package is not None and pip_name is None: + guard_import(module_name, package=package) + elif package is not None and pip_name is not None: + guard_import(module_name, pip_name=pip_name, package=package) + else: + raise ValueError("Invalid test case") + pip_name = pip_name or module_name.split(".")[0].replace("_", "-") + err_msg = ( + f"Could not import {module_name} python package. " + f"Please install it with `pip install {pip_name}`." + ) + assert exc_info.value.msg == err_msg