From 75e50a3efd75d5e9fad6803df0c7d3a73a64f3f4 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 17 Apr 2025 14:15:28 -0400 Subject: [PATCH] core[patch]: Raise `AttributeError` (instead of `ModuleNotFoundError`) in custom `__getattr__` (#30905) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow up to https://github.com/langchain-ai/langchain/pull/30769, fixing the regression reported [here](https://github.com/langchain-ai/langchain/pull/30769#issuecomment-2807483610), thanks @krassowski for the report! Fix inspired by https://github.com/PrefectHQ/prefect/pull/16172/files Other changes: * Using tuples for `__all__`, except in `output_parsers` bc of a list namespace conflict * Using a helper function for imports due to repeated logic across `__init__.py` files becoming hard to maintain. Co-authored-by: MichaƂ Krassowski < krassowski 5832902+krassowski@users.noreply.github.com>" --- libs/core/langchain_core/_api/__init__.py | 14 +++----- libs/core/langchain_core/_import_utils.py | 34 +++++++++++++++++++ .../core/langchain_core/callbacks/__init__.py | 14 +++----- .../document_loaders/__init__.py | 14 +++----- .../core/langchain_core/documents/__init__.py | 12 +++---- .../langchain_core/embeddings/__init__.py | 12 +++---- .../example_selectors/__init__.py | 14 +++----- libs/core/langchain_core/indexing/__init__.py | 14 +++----- .../language_models/__init__.py | 14 +++----- libs/core/langchain_core/load/__init__.py | 12 +++---- libs/core/langchain_core/messages/__init__.py | 14 +++----- .../langchain_core/output_parsers/__init__.py | 10 ++---- libs/core/langchain_core/outputs/__init__.py | 14 +++----- libs/core/langchain_core/prompts/__init__.py | 14 +++----- .../core/langchain_core/runnables/__init__.py | 14 +++----- libs/core/langchain_core/tools/__init__.py | 14 +++----- libs/core/langchain_core/tracers/__init__.py | 14 +++----- libs/core/langchain_core/utils/__init__.py | 14 +++----- .../langchain_core/vectorstores/__init__.py | 31 ++++++++++++++--- .../unit_tests/indexing/test_public_api.py | 2 +- 20 files changed, 142 insertions(+), 153 deletions(-) create mode 100644 libs/core/langchain_core/_import_utils.py diff --git a/libs/core/langchain_core/_api/__init__.py b/libs/core/langchain_core/_api/__init__.py index 7e22d2d3064..01b52e8615c 100644 --- a/libs/core/langchain_core/_api/__init__.py +++ b/libs/core/langchain_core/_api/__init__.py @@ -9,9 +9,10 @@ This module is only relevant for LangChain developers, not for users. """ -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from .beta_decorator import ( LangChainBetaWarning, @@ -28,7 +29,7 @@ if TYPE_CHECKING: ) from .path import as_import_path, get_relative_path -__all__ = [ +__all__ = ( "as_import_path", "beta", "deprecated", @@ -40,7 +41,7 @@ __all__ = [ "suppress_langchain_deprecation_warning", "surface_langchain_deprecation_warnings", "warn_deprecated", -] +) _dynamic_imports = { "LangChainBetaWarning": "beta_decorator", @@ -59,12 +60,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/_import_utils.py b/libs/core/langchain_core/_import_utils.py new file mode 100644 index 00000000000..6c1d99a3f4d --- /dev/null +++ b/libs/core/langchain_core/_import_utils.py @@ -0,0 +1,34 @@ +from importlib import import_module +from typing import Union + + +def import_attr( + attr_name: str, + module_name: Union[str, None], + package: Union[str, None], +) -> object: + """Import an attribute from a module located in a package. + + This utility function is used in custom __getattr__ methods within __init__.py + files to dynamically import attributes. + + Args: + attr_name: The name of the attribute to import. + module_name: The name of the module to import from. If None, the attribute + is imported from the package itself. + package: The name of the package where the module is located. + """ + if module_name == "__module__" or module_name is None: + try: + result = import_module(f".{attr_name}", package=package) + except ModuleNotFoundError: + msg = f"module '{package!r}' has no attribute {attr_name!r}" + raise AttributeError(msg) from None + else: + try: + module = import_module(f".{module_name}", package=package) + except ModuleNotFoundError: + msg = f"module '{package!r}.{module_name!r}' not found" + raise ImportError(msg) from None + result = getattr(module, attr_name) + return result diff --git a/libs/core/langchain_core/callbacks/__init__.py b/libs/core/langchain_core/callbacks/__init__.py index 487f5c259c3..6e842b96598 100644 --- a/libs/core/langchain_core/callbacks/__init__.py +++ b/libs/core/langchain_core/callbacks/__init__.py @@ -7,9 +7,10 @@ BaseCallbackHandler --> CallbackHandler # Example: AimCallbackHandler """ -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.callbacks.base import ( AsyncCallbackHandler, @@ -52,7 +53,7 @@ if TYPE_CHECKING: get_usage_metadata_callback, ) -__all__ = [ +__all__ = ( "dispatch_custom_event", "adispatch_custom_event", "RetrieverManagerMixin", @@ -87,7 +88,7 @@ __all__ = [ "FileCallbackHandler", "UsageMetadataCallbackHandler", "get_usage_metadata_callback", -] +) _dynamic_imports = { "AsyncCallbackHandler": "base", @@ -129,12 +130,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/document_loaders/__init__.py b/libs/core/langchain_core/document_loaders/__init__.py index 903667c4db0..f4225154a6f 100644 --- a/libs/core/langchain_core/document_loaders/__init__.py +++ b/libs/core/langchain_core/document_loaders/__init__.py @@ -1,21 +1,22 @@ """Document loaders.""" -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader from langchain_core.document_loaders.blob_loaders import Blob, BlobLoader, PathLike from langchain_core.document_loaders.langsmith import LangSmithLoader -__all__ = [ +__all__ = ( "BaseBlobParser", "BaseLoader", "Blob", "BlobLoader", "PathLike", "LangSmithLoader", -] +) _dynamic_imports = { "BaseBlobParser": "base", @@ -29,12 +30,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/documents/__init__.py b/libs/core/langchain_core/documents/__init__.py index 68fda191a0d..bc79a7a0dc2 100644 --- a/libs/core/langchain_core/documents/__init__.py +++ b/libs/core/langchain_core/documents/__init__.py @@ -5,15 +5,16 @@ and their transformations. """ -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from .base import Document from .compressor import BaseDocumentCompressor from .transformers import BaseDocumentTransformer -__all__ = ["Document", "BaseDocumentTransformer", "BaseDocumentCompressor"] +__all__ = ("Document", "BaseDocumentTransformer", "BaseDocumentCompressor") _dynamic_imports = { "Document": "base", @@ -24,12 +25,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/embeddings/__init__.py b/libs/core/langchain_core/embeddings/__init__.py index c492330b92d..66acae126fc 100644 --- a/libs/core/langchain_core/embeddings/__init__.py +++ b/libs/core/langchain_core/embeddings/__init__.py @@ -1,8 +1,9 @@ """Embeddings.""" -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.embeddings.embeddings import Embeddings from langchain_core.embeddings.fake import ( @@ -10,7 +11,7 @@ if TYPE_CHECKING: FakeEmbeddings, ) -__all__ = ["DeterministicFakeEmbedding", "Embeddings", "FakeEmbeddings"] +__all__ = ("DeterministicFakeEmbedding", "Embeddings", "FakeEmbeddings") _dynamic_imports = { "Embeddings": "embeddings", @@ -21,12 +22,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/example_selectors/__init__.py b/libs/core/langchain_core/example_selectors/__init__.py index ba3d55f509f..db079c9f91f 100644 --- a/libs/core/langchain_core/example_selectors/__init__.py +++ b/libs/core/langchain_core/example_selectors/__init__.py @@ -4,9 +4,10 @@ This allows us to select examples that are most relevant to the input. """ -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.example_selectors.base import BaseExampleSelector from langchain_core.example_selectors.length_based import ( @@ -18,13 +19,13 @@ if TYPE_CHECKING: sorted_values, ) -__all__ = [ +__all__ = ( "BaseExampleSelector", "LengthBasedExampleSelector", "MaxMarginalRelevanceExampleSelector", "SemanticSimilarityExampleSelector", "sorted_values", -] +) _dynamic_imports = { "BaseExampleSelector": "base", @@ -37,12 +38,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/indexing/__init__.py b/libs/core/langchain_core/indexing/__init__.py index 50545cd2e54..2a64cc2510e 100644 --- a/libs/core/langchain_core/indexing/__init__.py +++ b/libs/core/langchain_core/indexing/__init__.py @@ -5,9 +5,10 @@ a vectorstore while avoiding duplicated content and over-writing content if it's unchanged. """ -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.indexing.api import IndexingResult, aindex, index from langchain_core.indexing.base import ( @@ -18,7 +19,7 @@ if TYPE_CHECKING: UpsertResponse, ) -__all__ = [ +__all__ = ( "aindex", "DeleteResponse", "DocumentIndex", @@ -27,7 +28,7 @@ __all__ = [ "InMemoryRecordManager", "RecordManager", "UpsertResponse", -] +) _dynamic_imports = { "aindex": "api", @@ -43,12 +44,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/language_models/__init__.py b/libs/core/langchain_core/language_models/__init__.py index e99da7f1d4b..8ed64aeaa08 100644 --- a/libs/core/langchain_core/language_models/__init__.py +++ b/libs/core/langchain_core/language_models/__init__.py @@ -41,9 +41,10 @@ https://python.langchain.com/docs/how_to/custom_llm/ """ # noqa: E501 -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.language_models.base import ( BaseLanguageModel, @@ -66,7 +67,7 @@ if TYPE_CHECKING: ) from langchain_core.language_models.llms import LLM, BaseLLM -__all__ = [ +__all__ = ( "BaseLanguageModel", "BaseChatModel", "SimpleChatModel", @@ -83,7 +84,7 @@ __all__ = [ "FakeMessagesListChatModel", "GenericFakeChatModel", "ParrotFakeChatModel", -] +) _dynamic_imports = { "BaseLanguageModel": "base", @@ -107,12 +108,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/load/__init__.py b/libs/core/langchain_core/load/__init__.py index f5a0e6086eb..a87a380f5cf 100644 --- a/libs/core/langchain_core/load/__init__.py +++ b/libs/core/langchain_core/load/__init__.py @@ -1,8 +1,9 @@ """**Load** module helps with serialization and deserialization.""" -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.load.dump import dumpd, dumps from langchain_core.load.load import loads @@ -14,7 +15,7 @@ if TYPE_CHECKING: # the `from langchain_core.load.load import load` absolute import should also work. from langchain_core.load.load import load -__all__ = ["dumpd", "dumps", "load", "loads", "Serializable"] +__all__ = ("dumpd", "dumps", "load", "loads", "Serializable") _dynamic_imports = { "dumpd": "dump", @@ -26,12 +27,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/messages/__init__.py b/libs/core/langchain_core/messages/__init__.py index a7f8db60f62..d4e22138cef 100644 --- a/libs/core/langchain_core/messages/__init__.py +++ b/libs/core/langchain_core/messages/__init__.py @@ -15,9 +15,10 @@ """ # noqa: E501 -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.messages.ai import ( AIMessage, @@ -60,7 +61,7 @@ if TYPE_CHECKING: trim_messages, ) -__all__ = [ +__all__ = ( "AIMessage", "AIMessageChunk", "AnyMessage", @@ -95,7 +96,7 @@ __all__ = [ "merge_message_runs", "trim_messages", "convert_to_openai_messages", -] +) _dynamic_imports = { "AIMessage": "ai", @@ -137,12 +138,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/output_parsers/__init__.py b/libs/core/langchain_core/output_parsers/__init__.py index c72c7f8506c..1c81be2ade5 100644 --- a/libs/core/langchain_core/output_parsers/__init__.py +++ b/libs/core/langchain_core/output_parsers/__init__.py @@ -13,9 +13,10 @@ Serializable, Generation, PromptValue """ # noqa: E501 -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.output_parsers.base import ( BaseGenerationOutputParser, @@ -88,12 +89,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/outputs/__init__.py b/libs/core/langchain_core/outputs/__init__.py index 0f20be74280..b9072b9c929 100644 --- a/libs/core/langchain_core/outputs/__init__.py +++ b/libs/core/langchain_core/outputs/__init__.py @@ -21,9 +21,10 @@ in the AIMessage object, it is recommended to access it from there rather than from the `LLMResult` object. """ -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.outputs.chat_generation import ( ChatGeneration, @@ -34,7 +35,7 @@ if TYPE_CHECKING: from langchain_core.outputs.llm_result import LLMResult from langchain_core.outputs.run_info import RunInfo -__all__ = [ +__all__ = ( "ChatGeneration", "ChatGenerationChunk", "ChatResult", @@ -42,7 +43,7 @@ __all__ = [ "GenerationChunk", "LLMResult", "RunInfo", -] +) _dynamic_imports = { "ChatGeneration": "chat_generation", @@ -57,12 +58,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/prompts/__init__.py b/libs/core/langchain_core/prompts/__init__.py index bc661e53ad1..706b3b6f8fd 100644 --- a/libs/core/langchain_core/prompts/__init__.py +++ b/libs/core/langchain_core/prompts/__init__.py @@ -25,9 +25,10 @@ from multiple components and prompt values. Prompt classes and functions make co """ # noqa: E501 -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.prompts.base import ( BasePromptTemplate, @@ -61,7 +62,7 @@ if TYPE_CHECKING: validate_jinja2, ) -__all__ = [ +__all__ = ( "AIMessagePromptTemplate", "BaseChatPromptTemplate", "BasePromptTemplate", @@ -83,7 +84,7 @@ __all__ = [ "get_template_variables", "jinja2_formatter", "validate_jinja2", -] +) _dynamic_imports = { "BasePromptTemplate": "base", @@ -112,12 +113,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/runnables/__init__.py b/libs/core/langchain_core/runnables/__init__.py index d8a04ef5d19..12825773d80 100644 --- a/libs/core/langchain_core/runnables/__init__.py +++ b/libs/core/langchain_core/runnables/__init__.py @@ -17,9 +17,10 @@ creating more responsive UX. This module contains schema and implementation of LangChain Runnables primitives. """ -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.runnables.base import ( Runnable, @@ -58,7 +59,7 @@ if TYPE_CHECKING: add, ) -__all__ = [ +__all__ = ( "chain", "AddableDict", "ConfigurableField", @@ -88,7 +89,7 @@ __all__ = [ "get_config_list", "aadd", "add", -] +) _dynamic_imports = { "chain": "base", @@ -125,12 +126,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/tools/__init__.py b/libs/core/langchain_core/tools/__init__.py index 7409647a6f3..f13b3167f0c 100644 --- a/libs/core/langchain_core/tools/__init__.py +++ b/libs/core/langchain_core/tools/__init__.py @@ -19,9 +19,10 @@ tool for the job. from __future__ import annotations -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.tools.base import ( FILTERED_ARGS, @@ -51,7 +52,7 @@ if TYPE_CHECKING: from langchain_core.tools.simple import Tool from langchain_core.tools.structured import StructuredTool -__all__ = [ +__all__ = ( "ArgsSchema", "BaseTool", "BaseToolkit", @@ -71,7 +72,7 @@ __all__ = [ "create_retriever_tool", "Tool", "StructuredTool", -] +) _dynamic_imports = { "FILTERED_ARGS": "base", @@ -98,12 +99,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/tracers/__init__.py b/libs/core/langchain_core/tracers/__init__.py index cca451516f1..db8e828caac 100644 --- a/libs/core/langchain_core/tracers/__init__.py +++ b/libs/core/langchain_core/tracers/__init__.py @@ -8,9 +8,10 @@ --> # Examples: LogStreamCallbackHandler """ # noqa: E501 -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.evaluation import EvaluatorCallbackHandler @@ -23,7 +24,7 @@ if TYPE_CHECKING: from langchain_core.tracers.schemas import Run from langchain_core.tracers.stdout import ConsoleCallbackHandler -__all__ = [ +__all__ = ( "BaseTracer", "EvaluatorCallbackHandler", "LangChainTracer", @@ -32,7 +33,7 @@ __all__ = [ "RunLog", "RunLogPatch", "LogStreamCallbackHandler", -] +) _dynamic_imports = { "BaseTracer": "base", @@ -48,12 +49,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/utils/__init__.py b/libs/core/langchain_core/utils/__init__.py index 0469e5b68a0..52bd5c57609 100644 --- a/libs/core/langchain_core/utils/__init__.py +++ b/libs/core/langchain_core/utils/__init__.py @@ -3,9 +3,10 @@ These functions do not depend on any other LangChain module. """ -from importlib import import_module from typing import TYPE_CHECKING +from langchain_core._import_utils import import_attr + if TYPE_CHECKING: # for type checking and IDE support, we include the imports here # but we don't want to eagerly import them at runtime @@ -36,7 +37,7 @@ if TYPE_CHECKING: xor_args, ) -__all__ = [ +__all__ = ( "build_extra_kwargs", "StrictFormatter", "check_package_version", @@ -63,7 +64,7 @@ __all__ = [ "abatch_iterate", "from_env", "secret_from_env", -] +) _dynamic_imports = { "image": "__module__", @@ -97,12 +98,7 @@ _dynamic_imports = { def __getattr__(attr_name: str) -> object: module_name = _dynamic_imports.get(attr_name) - package = __spec__.parent - if module_name == "__module__" or module_name is None: - result = import_module(f".{attr_name}", package=package) - else: - module = import_module(f".{module_name}", package=package) - result = getattr(module, attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) globals()[attr_name] = result return result diff --git a/libs/core/langchain_core/vectorstores/__init__.py b/libs/core/langchain_core/vectorstores/__init__.py index 632d5309624..3881feb98e5 100644 --- a/libs/core/langchain_core/vectorstores/__init__.py +++ b/libs/core/langchain_core/vectorstores/__init__.py @@ -1,11 +1,34 @@ """Vector stores.""" -from langchain_core.vectorstores.base import VST, VectorStore, VectorStoreRetriever -from langchain_core.vectorstores.in_memory import InMemoryVectorStore +from typing import TYPE_CHECKING -__all__ = [ +from langchain_core._import_utils import import_attr + +if TYPE_CHECKING: + from langchain_core.vectorstores.base import VST, VectorStore, VectorStoreRetriever + from langchain_core.vectorstores.in_memory import InMemoryVectorStore + +__all__ = ( "VectorStore", "VST", "VectorStoreRetriever", "InMemoryVectorStore", -] +) + +_dynamic_imports = { + "VectorStore": "base", + "VST": "base", + "VectorStoreRetriever": "base", + "InMemoryVectorStore": "in_memory", +} + + +def __getattr__(attr_name: str) -> object: + module_name = _dynamic_imports.get(attr_name) + result = import_attr(attr_name, module_name, __spec__.parent) + globals()[attr_name] = result + return result + + +def __dir__() -> list[str]: + return list(__all__) diff --git a/libs/core/tests/unit_tests/indexing/test_public_api.py b/libs/core/tests/unit_tests/indexing/test_public_api.py index 24ac092eb2d..e737c79e381 100644 --- a/libs/core/tests/unit_tests/indexing/test_public_api.py +++ b/libs/core/tests/unit_tests/indexing/test_public_api.py @@ -3,7 +3,7 @@ from langchain_core.indexing import __all__ def test_all() -> None: """Use to catch obvious breaking changes.""" - assert __all__ == sorted(__all__, key=str.lower) + assert list(__all__) == sorted(__all__, key=str.lower) assert set(__all__) == { "aindex", "DeleteResponse",