diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py
index 4851a1c1c1d..676bb15e525 100644
--- a/libs/core/langchain_core/callbacks/manager.py
+++ b/libs/core/langchain_core/callbacks/manager.py
@@ -28,8 +28,18 @@ from langchain_core.callbacks.base import (
     ToolManagerMixin,
 )
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
+from langchain_core.globals import get_debug
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.tracers.context import (
+    _configure_hooks,
+    _get_trace_callbacks,
+    _get_tracer_project,
+    _tracing_v2_is_enabled,
+    tracing_v2_callback_var,
+)
+from langchain_core.tracers.langchain import LangChainTracer
 from langchain_core.tracers.schemas import Run
+from langchain_core.tracers.stdout import ConsoleCallbackHandler
 from langchain_core.utils.env import env_var_is_set
 
 if TYPE_CHECKING:
@@ -46,8 +56,6 @@ logger = logging.getLogger(__name__)
 
 
 def _get_debug() -> bool:
-    from langchain_core.globals import get_debug
-
     return get_debug()
 
 
@@ -103,8 +111,6 @@ def trace_as_chain_group(
             manager.on_chain_end({"output": res})
 
     """
-    from langchain_core.tracers.context import _get_trace_callbacks
-
     cb = _get_trace_callbacks(
         project_name, example_id, callback_manager=callback_manager
     )
@@ -189,8 +195,6 @@ async def atrace_as_chain_group(
             await manager.on_chain_end({"output": res})
 
     """
-    from langchain_core.tracers.context import _get_trace_callbacks
-
     cb = _get_trace_callbacks(
         project_name, example_id, callback_manager=callback_manager
    )
@@ -2376,13 +2380,6 @@ def _configure(
     Returns:
         T: The configured callback manager.
     """
-    from langchain_core.tracers.context import (
-        _configure_hooks,
-        _get_tracer_project,
-        _tracing_v2_is_enabled,
-        tracing_v2_callback_var,
-    )
-
     tracing_context = get_tracing_context()
     tracing_metadata = tracing_context["metadata"]
     tracing_tags = tracing_context["tags"]
@@ -2459,9 +2456,6 @@ def _configure(
     tracer_project = _get_tracer_project()
     debug = _get_debug()
     if verbose or debug or tracing_v2_enabled_:
-        from langchain_core.tracers.langchain import LangChainTracer
-        from langchain_core.tracers.stdout import ConsoleCallbackHandler
-
         if verbose and not any(
             isinstance(handler, StdOutCallbackHandler)
             for handler in callback_manager.handlers
@@ -2630,7 +2624,8 @@ async def adispatch_custom_event(
     .. versionadded:: 0.2.15
 
     """
-    from langchain_core.runnables.config import (
+    # Import locally to prevent circular imports.
+    from langchain_core.runnables.config import (  # noqa: PLC0415
         ensure_config,
         get_async_callback_manager_for_config,
     )
@@ -2705,7 +2700,8 @@ def dispatch_custom_event(
     .. versionadded:: 0.2.15
 
     """
-    from langchain_core.runnables.config import (
+    # Import locally to prevent circular imports.
+    from langchain_core.runnables.config import (  # noqa: PLC0415
         ensure_config,
         get_callback_manager_for_config,
     )
diff --git a/libs/core/langchain_core/callbacks/usage.py b/libs/core/langchain_core/callbacks/usage.py
index 497b6ff4912..90fa7573997 100644
--- a/libs/core/langchain_core/callbacks/usage.py
+++ b/libs/core/langchain_core/callbacks/usage.py
@@ -12,6 +12,7 @@ from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.messages import AIMessage
 from langchain_core.messages.ai import UsageMetadata, add_usage
 from langchain_core.outputs import ChatGeneration, LLMResult
+from langchain_core.tracers.context import register_configure_hook
 
 
 class UsageMetadataCallbackHandler(BaseCallbackHandler):
@@ -133,8 +134,6 @@ def get_usage_metadata_callback(
 
     .. versionadded:: 0.3.49
     """
-    from langchain_core.tracers.context import register_configure_hook
-
     usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
         ContextVar(name, default=None)
     )
diff --git a/libs/core/langchain_core/chat_history.py b/libs/core/langchain_core/chat_history.py
index 0091b240b69..6f66b7aba75 100644
--- a/libs/core/langchain_core/chat_history.py
+++ b/libs/core/langchain_core/chat_history.py
@@ -27,6 +27,7 @@ from langchain_core.messages import (
     HumanMessage,
     get_buffer_string,
 )
+from langchain_core.runnables.config import run_in_executor
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -113,8 +114,6 @@ class BaseChatMessageHistory(ABC):
         Returns:
             The messages.
         """
-        from langchain_core.runnables.config import run_in_executor
-
         return await run_in_executor(None, lambda: self.messages)
 
     def add_user_message(self, message: Union[HumanMessage, str]) -> None:
@@ -190,8 +189,6 @@ class BaseChatMessageHistory(ABC):
         Args:
             messages: A sequence of BaseMessage objects to store.
         """
-        from langchain_core.runnables.config import run_in_executor
-
         await run_in_executor(None, self.add_messages, messages)
 
     @abstractmethod
@@ -200,8 +197,6 @@ class BaseChatMessageHistory(ABC):
 
     async def aclear(self) -> None:
         """Async remove all messages from the store."""
-        from langchain_core.runnables.config import run_in_executor
-
         await run_in_executor(None, self.clear)
 
     def __str__(self) -> str:
diff --git a/libs/core/langchain_core/document_loaders/base.py b/libs/core/langchain_core/document_loaders/base.py
index 7208529da17..deeb9569ad2 100644
--- a/libs/core/langchain_core/document_loaders/base.py
+++ b/libs/core/langchain_core/document_loaders/base.py
@@ -15,6 +15,13 @@ if TYPE_CHECKING:
     from langchain_core.documents import Document
     from langchain_core.documents.base import Blob
 
+try:
+    from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+    _HAS_TEXT_SPLITTERS = True
+except ImportError:
+    _HAS_TEXT_SPLITTERS = False
+
 
 class BaseLoader(ABC):  # noqa: B024
     """Interface for Document Loader.
@@ -62,15 +69,13 @@ class BaseLoader(ABC):  # noqa: B024
             List of Documents.
         """
         if text_splitter is None:
-            try:
-                from langchain_text_splitters import RecursiveCharacterTextSplitter
-            except ImportError as e:
+            if not _HAS_TEXT_SPLITTERS:
                 msg = (
                     "Unable to import from langchain_text_splitters. Please specify "
                     "text_splitter or install langchain_text_splitters with "
                     "`pip install -U langchain-text-splitters`."
                 )
-                raise ImportError(msg) from e
+                raise ImportError(msg)
 
             text_splitter_: TextSplitter = RecursiveCharacterTextSplitter()
         else:
diff --git a/libs/core/langchain_core/embeddings/fake.py b/libs/core/langchain_core/embeddings/fake.py
index cd644bf8c69..c33281c484a 100644
--- a/libs/core/langchain_core/embeddings/fake.py
+++ b/libs/core/langchain_core/embeddings/fake.py
@@ -1,6 +1,7 @@
 """Module contains a few fake embedding models for testing purposes."""
 
 # Please do not add additional fake embedding model implementations here.
+import contextlib
 import hashlib
 
 from pydantic import BaseModel
@@ -8,6 +9,9 @@ from typing_extensions import override
 
 from langchain_core.embeddings import Embeddings
 
+with contextlib.suppress(ImportError):
+    import numpy as np
+
 
 class FakeEmbeddings(Embeddings, BaseModel):
     """Fake embedding model for unit testing purposes.
@@ -54,8 +58,6 @@ class FakeEmbeddings(Embeddings, BaseModel):
     """The size of the embedding vector."""
 
     def _get_embedding(self) -> list[float]:
-        import numpy as np
-
         return list(np.random.default_rng().normal(size=self.size))
 
     @override
@@ -113,8 +115,6 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
     """The size of the embedding vector."""
 
     def _get_embedding(self, seed: int) -> list[float]:
-        import numpy as np
-
         # set the seed for the random generator
         rng = np.random.default_rng(seed)
         return list(rng.normal(size=self.size))
diff --git a/libs/core/langchain_core/env.py b/libs/core/langchain_core/env.py
index fd619bc1d94..240e62a8e6e 100644
--- a/libs/core/langchain_core/env.py
+++ b/libs/core/langchain_core/env.py
@@ -3,6 +3,8 @@
 import platform
 from functools import lru_cache
 
+from langchain_core import __version__
+
 
 @lru_cache(maxsize=1)
 def get_runtime_environment() -> dict:
@@ -11,9 +13,6 @@ def get_runtime_environment() -> dict:
     Returns:
         A dictionary with information about the runtime environment.
     """
-    # Lazy import to avoid circular imports
-    from langchain_core import __version__
-
     return {
         "library_version": __version__,
         "library": "langchain-core",
diff --git a/libs/core/langchain_core/globals.py b/libs/core/langchain_core/globals.py
index faf26578f35..01e851f942e 100644
--- a/libs/core/langchain_core/globals.py
+++ b/libs/core/langchain_core/globals.py
@@ -6,6 +6,13 @@ from typing import TYPE_CHECKING, Optional
 if TYPE_CHECKING:
     from langchain_core.caches import BaseCache
 
+try:
+    import langchain  # type: ignore[import-not-found]
+
+    _HAS_LANGCHAIN = True
+except ImportError:
+    _HAS_LANGCHAIN = False
+
 
 # DO NOT USE THESE VALUES DIRECTLY!
 # Use them only via `get_<X>()` and `set_<X>()` below,
@@ -22,9 +29,7 @@ def set_verbose(value: bool) -> None:  # noqa: FBT001
     Args:
         value: The new value for the `verbose` global setting.
     """
-    try:
-        import langchain  # type: ignore[import-not-found]
-
+    if _HAS_LANGCHAIN:
         # We're about to run some deprecated code, don't report warnings from it.
         # The user called the correct (non-deprecated) code path and shouldn't get
         # warnings.
@@ -43,8 +48,6 @@ def set_verbose(value: bool) -> None:  # noqa: FBT001
         # Remove it once `langchain.verbose` is no longer supported, and once all
         # users have migrated to using `set_verbose()` here.
         langchain.verbose = value
-    except ImportError:
-        pass
 
     global _verbose  # noqa: PLW0603
     _verbose = value
@@ -56,9 +59,7 @@ def get_verbose() -> bool:
     Returns:
         The value of the `verbose` global setting.
     """
-    try:
-        import langchain
-
+    if _HAS_LANGCHAIN:
         # We're about to run some deprecated code, don't report warnings from it.
         # The user called the correct (non-deprecated) code path and shouldn't get
         # warnings.
@@ -83,7 +84,7 @@ def get_verbose() -> bool:
         # deprecation warnings directing them to use `set_verbose()` when they
         # import `langchain.verbose`.
         old_verbose = langchain.verbose
-    except ImportError:
+    else:
         old_verbose = False
 
     return _verbose or old_verbose
@@ -95,9 +96,7 @@ def set_debug(value: bool) -> None:  # noqa: FBT001
     Args:
         value: The new value for the `debug` global setting.
     """
-    try:
-        import langchain
-
+    if _HAS_LANGCHAIN:
         # We're about to run some deprecated code, don't report warnings from it.
         # The user called the correct (non-deprecated) code path and shouldn't get
         # warnings.
@@ -114,8 +113,6 @@ def set_debug(value: bool) -> None:  # noqa: FBT001
         # Remove it once `langchain.debug` is no longer supported, and once all
         # users have migrated to using `set_debug()` here.
         langchain.debug = value
-    except ImportError:
-        pass
 
     global _debug  # noqa: PLW0603
     _debug = value
@@ -127,9 +124,7 @@ def get_debug() -> bool:
     Returns:
         The value of the `debug` global setting.
     """
-    try:
-        import langchain
-
+    if _HAS_LANGCHAIN:
         # We're about to run some deprecated code, don't report warnings from it.
         # The user called the correct (non-deprecated) code path and shouldn't get
         # warnings.
@@ -151,7 +146,7 @@ def get_debug() -> bool:
         # to using `set_debug()` yet. Those users are getting deprecation warnings
         # directing them to use `set_debug()` when they import `langchain.debug`.
         old_debug = langchain.debug
-    except ImportError:
+    else:
         old_debug = False
 
     return _debug or old_debug
@@ -163,9 +158,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
     Args:
         value: The new LLM cache to use. If `None`, the LLM cache is disabled.
     """
-    try:
-        import langchain
-
+    if _HAS_LANGCHAIN:
         # We're about to run some deprecated code, don't report warnings from it.
         # The user called the correct (non-deprecated) code path and shouldn't get
         # warnings.
@@ -184,8 +177,6 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
         # Remove it once `langchain.llm_cache` is no longer supported, and
         # once all users have migrated to using `set_llm_cache()` here.
         langchain.llm_cache = value
-    except ImportError:
-        pass
 
     global _llm_cache  # noqa: PLW0603
     _llm_cache = value
@@ -197,9 +188,7 @@ def get_llm_cache() -> Optional["BaseCache"]:
     Returns:
         The value of the `llm_cache` global setting.
     """
-    try:
-        import langchain
-
+    if _HAS_LANGCHAIN:
         # We're about to run some deprecated code, don't report warnings from it.
         # The user called the correct (non-deprecated) code path and shouldn't get
         # warnings.
@@ -225,7 +214,7 @@ def get_llm_cache() -> Optional["BaseCache"]:
         # Those users are getting deprecation warnings directing them
         # to use `set_llm_cache()` when they import `langchain.llm_cache`.
         old_llm_cache = langchain.llm_cache
-    except ImportError:
+    else:
         old_llm_cache = None
 
     return _llm_cache or old_llm_cache
diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py
index 46bb954e95b..f0e444b5859 100644
--- a/libs/core/langchain_core/language_models/base.py
+++ b/libs/core/langchain_core/language_models/base.py
@@ -22,19 +22,31 @@ from typing_extensions import TypeAlias, TypedDict, override
 from langchain_core._api import deprecated
 from langchain_core.caches import BaseCache
 from langchain_core.callbacks import Callbacks
+from langchain_core.globals import get_verbose
 from langchain_core.messages import (
     AnyMessage,
     BaseMessage,
     MessageLikeRepresentation,
     get_buffer_string,
 )
-from langchain_core.prompt_values import PromptValue
+from langchain_core.prompt_values import (
+    ChatPromptValueConcrete,
+    PromptValue,
+    StringPromptValue,
+)
 from langchain_core.runnables import Runnable, RunnableSerializable
 from langchain_core.utils import get_pydantic_field_names
 
 if TYPE_CHECKING:
     from langchain_core.outputs import LLMResult
 
+try:
+    from transformers import GPT2TokenizerFast  # type: ignore[import-not-found]
+
+    _HAS_TRANSFORMERS = True
+except ImportError:
+    _HAS_TRANSFORMERS = False
+
 
 class LangSmithParams(TypedDict, total=False):
     """LangSmith parameters for tracing."""
@@ -66,15 +78,13 @@ def get_tokenizer() -> Any:
 
         The GPT-2 tokenizer instance.
     """
-    try:
-        from transformers import GPT2TokenizerFast  # type: ignore[import-not-found]
-    except ImportError as e:
+    if not _HAS_TRANSFORMERS:
         msg = (
             "Could not import transformers python package. "
             "This is needed in order to calculate get_token_ids. "
            "Please install it with `pip install transformers`."
         )
-        raise ImportError(msg) from e
+        raise ImportError(msg)
 
     # create a GPT-2 tokenizer instance
     return GPT2TokenizerFast.from_pretrained("gpt2")
@@ -95,8 +105,6 @@ LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
 
 
 def _get_verbosity() -> bool:
-    from langchain_core.globals import get_verbose
-
     return get_verbose()
@@ -158,11 +166,6 @@ class BaseLanguageModel(
     @override
     def InputType(self) -> TypeAlias:
         """Get the input type for this runnable."""
-        from langchain_core.prompt_values import (
-            ChatPromptValueConcrete,
-            StringPromptValue,
-        )
-
         # This is a version of LanguageModelInput which replaces the abstract
         # base class BaseMessage with a union of its subclasses, which makes
         # for a much better schema.
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 1d931689a3d..81ce5cbc505 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -46,6 +46,10 @@ from langchain_core.messages import (
     message_chunk_to_message,
 )
 from langchain_core.messages.ai import _LC_ID_PREFIX
+from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    PydanticToolsParser,
+)
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -1590,11 +1594,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             msg = f"Received unsupported arguments {kwargs}"
             raise ValueError(msg)
 
-        from langchain_core.output_parsers.openai_tools import (
-            JsonOutputKeyToolsParser,
-            PydanticToolsParser,
-        )
-
         if type(self).bind_tools is BaseChatModel.bind_tools:
             msg = "with_structured_output is not implemented for this model."
             raise NotImplementedError(msg)
diff --git a/libs/core/langchain_core/load/dump.py b/libs/core/langchain_core/load/dump.py
index f8ab13a7f97..c7cd819a4ae 100644
--- a/libs/core/langchain_core/load/dump.py
+++ b/libs/core/langchain_core/load/dump.py
@@ -6,6 +6,8 @@ from typing import Any
 from pydantic import BaseModel
 
 from langchain_core.load.serializable import Serializable, to_json_not_implemented
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration
 
 
 def default(obj: Any) -> Any:
@@ -23,9 +25,6 @@ def default(obj: Any) -> Any:
 
 
 def _dump_pydantic_models(obj: Any) -> Any:
-    from langchain_core.messages import AIMessage
-    from langchain_core.outputs import ChatGeneration
-
     if (
         isinstance(obj, ChatGeneration)
         and isinstance(obj.message, AIMessage)
diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py
index 8810d1ecf13..4b5b296614b 100644
--- a/libs/core/langchain_core/messages/base.py
+++ b/libs/core/langchain_core/messages/base.py
@@ -118,7 +118,8 @@ class BaseMessage(Serializable):
         Returns:
             A ChatPromptTemplate containing both messages.
         """
-        from langchain_core.prompts.chat import ChatPromptTemplate
+        # Import locally to prevent circular imports.
+        from langchain_core.prompts.chat import ChatPromptTemplate  # noqa: PLC0415
 
         prompt = ChatPromptTemplate(messages=[self])
         return prompt + other
diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py
index 65723ddc3a9..d6971e2ebd0 100644
--- a/libs/core/langchain_core/messages/utils.py
+++ b/libs/core/langchain_core/messages/utils.py
@@ -42,12 +42,17 @@ from langchain_core.messages.system import SystemMessage, SystemMessageChunk
 from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
 
 if TYPE_CHECKING:
-    from langchain_text_splitters import TextSplitter
-
     from langchain_core.language_models import BaseLanguageModel
     from langchain_core.prompt_values import PromptValue
     from langchain_core.runnables.base import Runnable
 
+try:
+    from langchain_text_splitters import TextSplitter
+
+    _HAS_LANGCHAIN_TEXT_SPLITTERS = True
+except ImportError:
+    _HAS_LANGCHAIN_TEXT_SPLITTERS = False
+
 logger = logging.getLogger(__name__)
@@ -361,7 +366,7 @@ def convert_to_messages(
         list of messages (BaseMessages).
     """
     # Import here to avoid circular imports
-    from langchain_core.prompt_values import PromptValue
+    from langchain_core.prompt_values import PromptValue  # noqa: PLC0415
 
     if isinstance(messages, PromptValue):
         return messages.to_messages()
@@ -386,7 +391,8 @@ def _runnable_support(func: Callable) -> Callable:
         list[BaseMessage],
         Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]],
     ]:
-        from langchain_core.runnables.base import RunnableLambda
+        # Import locally to prevent circular import.
+        from langchain_core.runnables.base import RunnableLambda  # noqa: PLC0415
 
         if messages is not None:
             return func(messages, **kwargs)
@@ -989,17 +995,12 @@ def trim_messages(
         )
         raise ValueError(msg)
 
-    try:
-        from langchain_text_splitters import TextSplitter
-    except ImportError:
-        text_splitter_fn: Optional[Callable] = cast("Optional[Callable]", text_splitter)
+    if _HAS_LANGCHAIN_TEXT_SPLITTERS and isinstance(text_splitter, TextSplitter):
+        text_splitter_fn = text_splitter.split_text
+    elif text_splitter:
+        text_splitter_fn = cast("Callable", text_splitter)
     else:
-        if isinstance(text_splitter, TextSplitter):
-            text_splitter_fn = text_splitter.split_text
-        else:
-            text_splitter_fn = text_splitter
-
-    text_splitter_fn = text_splitter_fn or _default_text_splitter
+        text_splitter_fn = _default_text_splitter
 
     if strategy == "first":
         return _first_max_tokens(
diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py
index 0d7ef8b1e79..be12760fade 100644
--- a/libs/core/langchain_core/output_parsers/xml.py
+++ b/libs/core/langchain_core/output_parsers/xml.py
@@ -15,6 +15,14 @@ from langchain_core.messages import BaseMessage
 from langchain_core.output_parsers.transform import BaseTransformOutputParser
 from langchain_core.runnables.utils import AddableDict
 
+try:
+    from defusedxml import ElementTree  # type: ignore[import-untyped]
+    from defusedxml.ElementTree import XMLParser  # type: ignore[import-untyped]
+
+    _HAS_DEFUSEDXML = True
+except ImportError:
+    _HAS_DEFUSEDXML = False
+
 XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
 1. Output should conform to the tags below.
 2. If tags are not given, make them on your own.
@@ -50,17 +58,13 @@ class _StreamingParser:
                 parser is requested.
         """
         if parser == "defusedxml":
-            try:
-                from defusedxml.ElementTree import (  # type: ignore[import-untyped]
-                    XMLParser,
-                )
-            except ImportError as e:
+            if not _HAS_DEFUSEDXML:
                 msg = (
                     "defusedxml is not installed. "
                     "Please install it to use the defusedxml parser."
                     "You can install it with `pip install defusedxml` "
                 )
-                raise ImportError(msg) from e
+                raise ImportError(msg)
             parser_ = XMLParser(target=TreeBuilder())
         else:
             parser_ = None
@@ -207,16 +211,14 @@ class XMLOutputParser(BaseTransformOutputParser):
         # Imports are temporarily placed here to avoid issue with caching on CI
         # likely if you're reading this you can move them to the top of the file
         if self.parser == "defusedxml":
-            try:
-                from defusedxml import ElementTree  # type: ignore[import-untyped]
-            except ImportError as e:
+            if not _HAS_DEFUSEDXML:
                 msg = (
                     "defusedxml is not installed. "
                     "Please install it to use the defusedxml parser."
                     "You can install it with `pip install defusedxml`"
                     "See https://github.com/tiran/defusedxml for more details"
                 )
-                raise ImportError(msg) from e
+                raise ImportError(msg)
             et = ElementTree  # Use the defusedxml parser
         else:
             et = ET  # Use the standard library parser
diff --git a/libs/core/langchain_core/prompts/message.py b/libs/core/langchain_core/prompts/message.py
index 409f54320b1..b8642a53bca 100644
--- a/libs/core/langchain_core/prompts/message.py
+++ b/libs/core/langchain_core/prompts/message.py
@@ -88,7 +88,8 @@ class BaseMessagePromptTemplate(Serializable, ABC):
         Returns:
             Combined prompt template.
         """
-        from langchain_core.prompts.chat import ChatPromptTemplate
+        # Import locally to avoid circular import.
+        from langchain_core.prompts.chat import ChatPromptTemplate  # noqa: PLC0415
 
         prompt = ChatPromptTemplate(messages=[self])
         return prompt + other
diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py
index b6f1cdae28f..8b147c3c0bc 100644
--- a/libs/core/langchain_core/prompts/string.py
+++ b/libs/core/langchain_core/prompts/string.py
@@ -15,6 +15,14 @@ from langchain_core.utils import get_colored_text, mustache
 from langchain_core.utils.formatting import formatter
 from langchain_core.utils.interactive_env import is_interactive_env
 
+try:
+    from jinja2 import Environment, meta
+    from jinja2.sandbox import SandboxedEnvironment
+
+    _HAS_JINJA2 = True
+except ImportError:
+    _HAS_JINJA2 = False
+
 PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]
@@ -40,9 +48,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
     Raises:
         ImportError: If jinja2 is not installed.
     """
-    try:
-        from jinja2.sandbox import SandboxedEnvironment
-    except ImportError as e:
+    if not _HAS_JINJA2:
         msg = (
             "jinja2 not installed, which is needed to use the jinja2_formatter. "
             "Please install it with `pip install jinja2`."
@@ -50,7 +56,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
             "Do not expand jinja2 templates using unverified or user-controlled "
             "inputs as that can result in arbitrary Python code execution."
         )
-        raise ImportError(msg) from e
+        raise ImportError(msg)
 
     # This uses a sandboxed environment to prevent arbitrary code execution.
     # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
@@ -88,14 +94,12 @@ def validate_jinja2(template: str, input_variables: list[str]) -> None:
 
 
 def _get_jinja2_variables_from_template(template: str) -> set[str]:
-    try:
-        from jinja2 import Environment, meta
-    except ImportError as e:
+    if not _HAS_JINJA2:
         msg = (
             "jinja2 not installed, which is needed to use the jinja2_formatter. "
             "Please install it with `pip install jinja2`."
         )
-        raise ImportError(msg) from e
+        raise ImportError(msg)
     env = Environment()  # noqa: S701
     ast = env.parse(template)
     return meta.find_undeclared_variables(ast)
diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py
index 8e1acf56c16..c015747f138 100644
--- a/libs/core/langchain_core/retrievers.py
+++ b/libs/core/langchain_core/retrievers.py
@@ -31,6 +31,7 @@ from typing_extensions import Self, TypedDict, override
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import Callbacks
+from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
 from langchain_core.documents import Document
 from langchain_core.runnables import (
     Runnable,
@@ -236,8 +237,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
             retriever.invoke("query")
 
         """
-        from langchain_core.callbacks.manager import CallbackManager
-
         config = ensure_config(config)
         inheritable_metadata = {
             **(config.get("metadata") or {}),
@@ -301,8 +300,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
             await retriever.ainvoke("query")
 
         """
-        from langchain_core.callbacks.manager import AsyncCallbackManager
-
         config = ensure_config(config)
         inheritable_metadata = {
             **(config.get("metadata") or {}),
diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index 1f90b2f608c..74f2a9049a2 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -41,6 +41,7 @@ from pydantic import BaseModel, ConfigDict, Field, RootModel
 from typing_extensions import Literal, get_args, override
 
 from langchain_core._api import beta_decorator
+from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
 from langchain_core.load.serializable import (
     Serializable,
     SerializedConstructor,
@@ -60,7 +61,6 @@ from langchain_core.runnables.config import (
     run_in_executor,
     set_config_context,
 )
-from langchain_core.runnables.graph import Graph
 from langchain_core.runnables.utils import (
     AddableDict,
     AnyConfigurableField,
@@ -81,6 +81,19 @@ from langchain_core.runnables.utils import (
     is_async_callable,
     is_async_generator,
 )
+from langchain_core.tracers._streaming import _StreamingCallbackHandler
+from langchain_core.tracers.event_stream import (
+    _astream_events_implementation_v1,
+    _astream_events_implementation_v2,
+)
+from langchain_core.tracers.log_stream import (
+    LogStreamCallbackHandler,
+    _astream_log_implementation,
+)
+from langchain_core.tracers.root_listeners import (
+    AsyncRootListenersTracer,
+    RootListenersTracer,
+)
 from langchain_core.utils.aiter import aclosing, atee, py_anext
 from langchain_core.utils.iter import safetee
 from langchain_core.utils.pydantic import create_model_v2
@@ -94,6 +107,7 @@ if TYPE_CHECKING:
     from langchain_core.runnables.fallbacks import (
         RunnableWithFallbacks as RunnableWithFallbacksT,
     )
+    from langchain_core.runnables.graph import Graph
     from langchain_core.runnables.retry import ExponentialJitterParams
     from langchain_core.runnables.schema import StreamEvent
     from langchain_core.tools import BaseTool
@@ -565,6 +579,9 @@ class Runnable(ABC, Generic[Input, Output]):
 
     def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
         """Return a graph representation of this ``Runnable``."""
+        # Import locally to prevent circular import
+        from langchain_core.runnables.graph import Graph  # noqa: PLC0415
+
         graph = Graph()
         try:
             input_node = graph.add_node(self.get_input_schema(config))
@@ -585,7 +602,8 @@ class Runnable(ABC, Generic[Input, Output]):
         self, config: Optional[RunnableConfig] = None
     ) -> list[BasePromptTemplate]:
         """Return a list of prompts used by this ``Runnable``."""
-        from langchain_core.prompts.base import BasePromptTemplate
+        # Import locally to prevent circular import
+        from langchain_core.prompts.base import BasePromptTemplate  # noqa: PLC0415
 
         return [
             node.data
@@ -747,7 +765,8 @@ class Runnable(ABC, Generic[Input, Output]):
 
             a new ``Runnable``.
         """
-        from langchain_core.runnables.passthrough import RunnablePick
+        # Import locally to prevent circular import
+        from langchain_core.runnables.passthrough import RunnablePick  # noqa: PLC0415
 
         return self | RunnablePick(keys)
@@ -798,7 +817,8 @@ class Runnable(ABC, Generic[Input, Output]):
 
             A new ``Runnable``.
         """
-        from langchain_core.runnables.passthrough import RunnableAssign
+        # Import locally to prevent circular import
+        from langchain_core.runnables.passthrough import RunnableAssign  # noqa: PLC0415
 
         return self | RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
@@ -1231,11 +1251,6 @@ class Runnable(ABC, Generic[Input, Output]):
 
             A ``RunLogPatch`` or ``RunLog`` object.
         """
-        from langchain_core.tracers.log_stream import (
-            LogStreamCallbackHandler,
-            _astream_log_implementation,
-        )
-
         stream = LogStreamCallbackHandler(
             auto_close=False,
             include_names=include_names,
@@ -1489,11 +1504,6 @@ class Runnable(ABC, Generic[Input, Output]):
 
             NotImplementedError: If the version is not ``'v1'`` or ``'v2'``.
         """  # noqa: E501
-        from langchain_core.tracers.event_stream import (
-            _astream_events_implementation_v1,
-            _astream_events_implementation_v2,
-        )
-
         if version == "v2":
             event_stream = _astream_events_implementation_v2(
                 self,
@@ -1740,8 +1750,6 @@ class Runnable(ABC, Generic[Input, Output]):
 
             chain.invoke(2)
         """
-        from langchain_core.tracers.root_listeners import RootListenersTracer
-
         return RunnableBinding(
             bound=self,
             config_factories=[
@@ -1834,8 +1842,6 @@ class Runnable(ABC, Generic[Input, Output]):
             on end callback ends at 2025-03-01T07:05:30.884831+00:00
 
         """
-        from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
-
         return RunnableBinding(
             bound=self,
             config_factories=[
@@ -1928,7 +1934,8 @@ class Runnable(ABC, Generic[Input, Output]):
 
             assert count == 2
         """
-        from langchain_core.runnables.retry import RunnableRetry
+        # Import locally to prevent circular import
+        from langchain_core.runnables.retry import RunnableRetry  # noqa: PLC0415
 
         return RunnableRetry(
             bound=self,
@@ -2030,7 +2037,10 @@ class Runnable(ABC, Generic[Input, Output]):
 
             fallback in order, upon failures.
         """
-        from langchain_core.runnables.fallbacks import RunnableWithFallbacks
+        # Import locally to prevent circular import
+        from langchain_core.runnables.fallbacks import (  # noqa: PLC0415
+            RunnableWithFallbacks,
+        )
 
         return RunnableWithFallbacks(
             runnable=self,
@@ -2316,9 +2326,6 @@ class Runnable(ABC, Generic[Input, Output]):
 
         Use this to implement ``stream`` or ``transform`` in ``Runnable`` subclasses.
""" - # Mixin that is used by both astream log and astream events implementation - from langchain_core.tracers._streaming import _StreamingCallbackHandler - # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = tee(inputs, 2) # Start the input iterator to ensure the input Runnable starts before this one @@ -2422,9 +2429,6 @@ class Runnable(ABC, Generic[Input, Output]): Use this to implement ``astream`` or ``atransform`` in ``Runnable`` subclasses. """ - # Mixin that is used by both astream log and astream events implementation - from langchain_core.tracers._streaming import _StreamingCallbackHandler - # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = atee(inputs, 2) # Start the input iterator to ensure the input Runnable starts before this one @@ -2614,7 +2618,7 @@ class Runnable(ABC, Generic[Input, Output]): """ # Avoid circular import - from langchain_core.tools import convert_runnable_to_tool + from langchain_core.tools import convert_runnable_to_tool # noqa: PLC0415 return convert_runnable_to_tool( self, @@ -2690,7 +2694,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): ) """ - from langchain_core.runnables.configurable import RunnableConfigurableFields + # Import locally to prevent circular import + from langchain_core.runnables.configurable import ( # noqa: PLC0415 + RunnableConfigurableFields, + ) model_fields = type(self).model_fields for key in kwargs: @@ -2751,7 +2758,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): ) """ - from langchain_core.runnables.configurable import ( + # Import locally to prevent circular import + from langchain_core.runnables.configurable import ( # noqa: PLC0415 RunnableConfigurableAlternatives, ) @@ -2767,7 +2775,11 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): def _seq_input_schema( steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig] ) -> type[BaseModel]: - from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick + # Import locally to prevent circular import + from langchain_core.runnables.passthrough import ( # noqa: PLC0415 + RunnableAssign, + RunnablePick, + ) first = steps[0] if len(steps) == 1: @@ -2793,7 +2805,11 @@ def _seq_input_schema( def _seq_output_schema( steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig] ) -> type[BaseModel]: - from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick + # Import locally to prevent circular import + from langchain_core.runnables.passthrough import ( # noqa: PLC0415 + RunnableAssign, + RunnablePick, + ) last = steps[-1] if len(steps) == 1: @@ -3050,7 +3066,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): The config specs of the ``Runnable``. """ - from langchain_core.beta.runnables.context import ( + # Import locally to prevent circular import + from langchain_core.beta.runnables.context import ( # noqa: PLC0415 CONTEXT_CONFIG_PREFIX, _key_from_id, ) @@ -3108,7 +3125,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): ValueError: If a ``Runnable`` has no first or last node. 
""" - from langchain_core.runnables.graph import Graph + # Import locally to prevent circular import + from langchain_core.runnables.graph import Graph # noqa: PLC0415 graph = Graph() for step in self.steps: @@ -3196,7 +3214,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - from langchain_core.beta.runnables.context import config_with_context + # Import locally to prevent circular import + from langchain_core.beta.runnables.context import ( # noqa: PLC0415 + config_with_context, + ) # setup callbacks and context config = config_with_context(ensure_config(config), self.steps) @@ -3237,7 +3258,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - from langchain_core.beta.runnables.context import aconfig_with_context + # Import locally to prevent circular import + from langchain_core.beta.runnables.context import ( # noqa: PLC0415 + aconfig_with_context, + ) # setup callbacks and context config = aconfig_with_context(ensure_config(config), self.steps) @@ -3281,8 +3305,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return_exceptions: bool = False, **kwargs: Optional[Any], ) -> list[Output]: - from langchain_core.beta.runnables.context import config_with_context - from langchain_core.callbacks.manager import CallbackManager + # Import locally to prevent circular import + from langchain_core.beta.runnables.context import ( # noqa: PLC0415 + config_with_context, + ) if not inputs: return [] @@ -3411,8 +3437,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return_exceptions: bool = False, **kwargs: Optional[Any], ) -> list[Output]: - from langchain_core.beta.runnables.context import aconfig_with_context - from langchain_core.callbacks.manager import AsyncCallbackManager + # Import locally to prevent circular import + from langchain_core.beta.runnables.context import ( # noqa: PLC0415 + aconfig_with_context, + ) if not inputs: return [] @@ -3542,7 +3570,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> Iterator[Output]: - from langchain_core.beta.runnables.context import config_with_context + # Import locally to prevent circular import + from langchain_core.beta.runnables.context import ( # noqa: PLC0415 + config_with_context, + ) steps = [self.first, *self.middle, self.last] config = config_with_context(config, self.steps) @@ -3569,7 +3600,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> AsyncIterator[Output]: - from langchain_core.beta.runnables.context import aconfig_with_context + # Import locally to prevent circular import + from langchain_core.beta.runnables.context import ( # noqa: PLC0415 + aconfig_with_context, + ) steps = [self.first, *self.middle, self.last] config = aconfig_with_context(config, self.steps) @@ -3882,7 +3916,8 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): ValueError: If a ``Runnable`` has no first or last node. 
""" - from langchain_core.runnables.graph import Graph + # Import locally to prevent circular import + from langchain_core.runnables.graph import Graph # noqa: PLC0415 graph = Graph() input_node = graph.add_node(self.get_input_schema(config)) @@ -3918,8 +3953,6 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> dict[str, Any]: - from langchain_core.callbacks.manager import CallbackManager - # setup callbacks config = ensure_config(config) callback_manager = CallbackManager.configure( @@ -4767,6 +4800,9 @@ class RunnableLambda(Runnable[Input, Output]): @override def get_graph(self, config: RunnableConfig | None = None) -> Graph: if deps := self.deps: + # Import locally to prevent circular import + from langchain_core.runnables.graph import Graph # noqa: PLC0415 + graph = Graph() input_node = graph.add_node(self.get_input_schema(config)) output_node = graph.add_node(self.get_output_schema(config)) @@ -6030,7 +6066,6 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-re Returns: A new ``Runnable`` with the listeners bound. """ - from langchain_core.tracers.root_listeners import RootListenersTracer def listener_config_factory(config: RunnableConfig) -> RunnableConfig: return { diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index ed8d0ed0cdd..66ecb9592c8 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -12,6 +12,10 @@ from typing import ( from pydantic import BaseModel, ConfigDict from typing_extensions import override +from langchain_core.beta.runnables.context import ( + CONTEXT_CONFIG_PREFIX, + CONTEXT_CONFIG_SUFFIX_SET, +) from langchain_core.runnables.base import ( Runnable, RunnableLike, @@ -177,11 +181,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]): @property @override def config_specs(self) -> list[ConfigurableFieldSpec]: - from langchain_core.beta.runnables.context import ( - CONTEXT_CONFIG_PREFIX, - CONTEXT_CONFIG_SUFFIX_SET, - ) - specs = get_unique_config_specs( spec for step in ( diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 98090872a34..041d1d2629e 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -12,21 +12,22 @@ from contextvars import Context, ContextVar, Token, copy_context from functools import partial from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast +from langsmith.run_helpers import _set_tracing_context, get_tracing_context from typing_extensions import ParamSpec, TypedDict +from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain_core.runnables.utils import ( Input, Output, accepts_config, accepts_run_manager, ) +from langchain_core.tracers.langchain import LangChainTracer if TYPE_CHECKING: from langchain_core.callbacks.base import BaseCallbackManager, Callbacks from langchain_core.callbacks.manager import ( - AsyncCallbackManager, AsyncCallbackManagerForChainRun, - CallbackManager, CallbackManagerForChainRun, ) else: @@ -129,8 +130,6 @@ def _set_config_context( Returns: The token to reset the config and the previous tracing context. 
""" - from langchain_core.tracers.langchain import LangChainTracer - config_token = var_child_runnable_config.set(config) current_context = None if ( @@ -150,8 +149,6 @@ def _set_config_context( ) and (run := tracer.run_map.get(str(parent_run_id))) ): - from langsmith.run_helpers import _set_tracing_context, get_tracing_context - current_context = get_tracing_context() _set_tracing_context({"parent": run}) return config_token, current_context @@ -167,8 +164,6 @@ def set_config_context(config: RunnableConfig) -> Generator[Context, None, None] Yields: The config context. """ - from langsmith.run_helpers import _set_tracing_context - ctx = copy_context() config_token, _ = ctx.run(_set_config_context, config) try: @@ -481,8 +476,6 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: Returns: CallbackManager: The callback manager. """ - from langchain_core.callbacks.manager import CallbackManager - return CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), inheritable_tags=config.get("tags"), @@ -501,8 +494,6 @@ def get_async_callback_manager_for_config( Returns: AsyncCallbackManager: The async callback manager. """ - from langchain_core.callbacks.manager import AsyncCallbackManager - return AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), inheritable_tags=config.get("tags"), diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index ede6fcb572d..f72850f3413 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast from pydantic import BaseModel, ConfigDict from typing_extensions import override +from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.config import ( RunnableConfig, @@ -272,8 +273,6 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): return_exceptions: bool = False, **kwargs: Optional[Any], ) -> list[Output]: - from langchain_core.callbacks.manager import CallbackManager - if self.exception_key is not None and not all( isinstance(input_, dict) for input_ in inputs ): @@ -366,8 +365,6 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): return_exceptions: bool = False, **kwargs: Optional[Any], ) -> list[Output]: - from langchain_core.callbacks.manager import AsyncCallbackManager - if self.exception_key is not None and not all( isinstance(input_, dict) for input_ in inputs ): diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 389572a23b3..cebf2a667c1 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -19,6 +19,8 @@ from typing import ( ) from uuid import UUID, uuid4 +from langchain_core.load.serializable import to_json_not_implemented +from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass if TYPE_CHECKING: @@ -191,8 +193,6 @@ def node_data_str( Returns: A string representation of the data. 
""" - from langchain_core.runnables.base import Runnable - if not is_uuid(id) or data is None: return id data_str = data.get_name() if isinstance(data, Runnable) else data.__name__ @@ -212,9 +212,6 @@ def node_data_json( Returns: A dictionary with the type of the data and the data itself. """ - from langchain_core.load.serializable import to_json_not_implemented - from langchain_core.runnables.base import Runnable, RunnableSerializable - if node.data is None: json: dict[str, Any] = {} elif isinstance(node.data, RunnableSerializable): @@ -518,7 +515,8 @@ class Graph: Returns: The ASCII art string. """ - from langchain_core.runnables.graph_ascii import draw_ascii + # Import locally to prevent circular import + from langchain_core.runnables.graph_ascii import draw_ascii # noqa: PLC0415 return draw_ascii( {node.id: node.name for node in self.nodes.values()}, @@ -562,7 +560,8 @@ class Graph: Returns: The PNG image as bytes if output_file_path is None, None otherwise. """ - from langchain_core.runnables.graph_png import PngDrawer + # Import locally to prevent circular import + from langchain_core.runnables.graph_png import PngDrawer # noqa: PLC0415 default_node_labels = {node.id: node.name for node in self.nodes.values()} @@ -617,7 +616,8 @@ class Graph: The Mermaid syntax string. """ - from langchain_core.runnables.graph_mermaid import draw_mermaid + # Import locally to prevent circular import + from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415 graph = self.reid() first_node = graph.first_node() @@ -688,7 +688,10 @@ class Graph: The PNG image as bytes. """ - from langchain_core.runnables.graph_mermaid import draw_mermaid_png + # Import locally to prevent circular import + from langchain_core.runnables.graph_mermaid import ( # noqa: PLC0415 + draw_mermaid_png, + ) mermaid_syntax = self.draw_mermaid( curve_style=curve_style, diff --git a/libs/core/langchain_core/runnables/graph_ascii.py b/libs/core/langchain_core/runnables/graph_ascii.py index 9c11c585cd1..39ad3f165e6 100644 --- a/libs/core/langchain_core/runnables/graph_ascii.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -3,12 +3,24 @@ Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py. """ +from __future__ import annotations + import math import os from collections.abc import Mapping, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any -from langchain_core.runnables.graph import Edge as LangEdge +try: + from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped] + from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped] + from grandalf.routing import route_with_lines # type: ignore[import-untyped] + + _HAS_GRANDALF = True +except ImportError: + _HAS_GRANDALF = False + +if TYPE_CHECKING: + from langchain_core.runnables.graph import Edge as LangEdge class VertexViewer: @@ -185,13 +197,9 @@ class _EdgeViewer: def _build_sugiyama_layout( vertices: Mapping[str, str], edges: Sequence[LangEdge] ) -> Any: - try: - from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped] - from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped] - from grandalf.routing import route_with_lines # type: ignore[import-untyped] - except ImportError as exc: + if not _HAS_GRANDALF: msg = "Install grandalf to draw graphs: `pip install grandalf`." 
-        raise ImportError(msg) from exc
+        raise ImportError(msg)
 
     #
     # Just a reminder about naming conventions:
diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py
index 17ebc55a518..df1468fc437 100644
--- a/libs/core/langchain_core/runnables/graph_mermaid.py
+++ b/libs/core/langchain_core/runnables/graph_mermaid.py
@@ -1,5 +1,7 @@
 """Mermaid graph drawing utilities."""
 
+from __future__ import annotations
+
 import asyncio
 import base64
 import random
@@ -7,18 +9,34 @@ import re
 import time
 from dataclasses import asdict
 from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 import yaml
 
 from langchain_core.runnables.graph import (
     CurveStyle,
-    Edge,
     MermaidDrawMethod,
-    Node,
     NodeStyles,
 )
 
+if TYPE_CHECKING:
+    from langchain_core.runnables.graph import Edge, Node
+
+
+try:
+    import requests
+
+    _HAS_REQUESTS = True
+except ImportError:
+    _HAS_REQUESTS = False
+
+try:
+    from pyppeteer import launch  # type: ignore[import-not-found]
+
+    _HAS_PYPPETEER = True
+except ImportError:
+    _HAS_PYPPETEER = False
+
 MARKDOWN_SPECIAL_CHARS = "*_`"
@@ -283,8 +301,6 @@ def draw_mermaid_png(
         ValueError: If an invalid draw method is provided.
     """
     if draw_method == MermaidDrawMethod.PYPPETEER:
-        import asyncio
-
         img_bytes = asyncio.run(
             _render_mermaid_using_pyppeteer(
                 mermaid_syntax, output_file_path, background_color, padding
@@ -317,11 +333,9 @@ async def _render_mermaid_using_pyppeteer(
     device_scale_factor: int = 3,
 ) -> bytes:
     """Renders Mermaid graph using Pyppeteer."""
-    try:
-        from pyppeteer import launch  # type: ignore[import-not-found]
-    except ImportError as e:
+    if not _HAS_PYPPETEER:
         msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
-        raise ImportError(msg) from e
+        raise ImportError(msg)
 
     browser = await launch()
     page = await browser.newPage()
@@ -392,14 +406,12 @@ def _render_mermaid_using_api(
     retry_delay: float = 1.0,
 ) -> bytes:
     """Renders Mermaid graph using the Mermaid.INK API."""
-    try:
-        import requests
-    except ImportError as e:
+    if not _HAS_REQUESTS:
         msg = (
             "Install the `requests` module to use the Mermaid.INK API: "
             "`pip install requests`."
         )
-        raise ImportError(msg) from e
+        raise ImportError(msg)
 
     # Use Mermaid API to render the image
     mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
diff --git a/libs/core/langchain_core/runnables/graph_png.py b/libs/core/langchain_core/runnables/graph_png.py
index 27162907843..154f8e3ab44 100644
--- a/libs/core/langchain_core/runnables/graph_png.py
+++ b/libs/core/langchain_core/runnables/graph_png.py
@@ -4,6 +4,13 @@ from typing import Any, Optional
 
 from langchain_core.runnables.graph import Graph, LabelsDict
 
+try:
+    import pygraphviz as pgv  # type: ignore[import-not-found]
+
+    _HAS_PYGRAPHVIZ = True
+except ImportError:
+    _HAS_PYGRAPHVIZ = False
+
 
 class PngDrawer:
     """Helper class to draw a state graph into a PNG file.
@@ -125,11 +132,9 @@ class PngDrawer:
         Returns:
             The PNG bytes if ``output_path`` is None, else None.
         """
-        try:
-            import pygraphviz as pgv  # type: ignore[import-not-found]
-        except ImportError as exc:
+        if not _HAS_PYGRAPHVIZ:
             msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`."
-            raise ImportError(msg) from exc
+            raise ImportError(msg)
 
         # Create a directed graph
         viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)
diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py
index 24bfe22df0e..19158607bad 100644
--- a/libs/core/langchain_core/runnables/history.py
+++ b/libs/core/langchain_core/runnables/history.py
@@ -18,6 +18,7 @@ from typing_extensions import override
 
 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.load.load import load
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
 from langchain_core.runnables.passthrough import RunnablePassthrough
 from langchain_core.runnables.utils import (
@@ -29,7 +30,6 @@ from langchain_core.utils.pydantic import create_model_v2
 
 if TYPE_CHECKING:
     from langchain_core.language_models.base import LanguageModelLike
-    from langchain_core.messages.base import BaseMessage
     from langchain_core.runnables.config import RunnableConfig
     from langchain_core.tracers.schemas import Run
 
@@ -384,8 +384,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):  # type: ignore[no-redef]
     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
     ) -> type[BaseModel]:
-        from langchain_core.messages import BaseMessage
-
         fields: dict = {}
         if self.input_messages_key and self.history_messages_key:
             fields[self.input_messages_key] = (
@@ -447,8 +445,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):  # type: ignore[no-redef]
     def _get_input_messages(
         self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
     ) -> list[BaseMessage]:
-        from langchain_core.messages import BaseMessage
-
         # If dictionary, try to pluck the single key representing messages
         if isinstance(input_val, dict):
             if self.input_messages_key:
@@ -461,8 +457,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):  # type: ignore[no-redef]
 
         # If value is a string, convert to a human message
         if isinstance(input_val, str):
-            from langchain_core.messages import HumanMessage
-
             return [HumanMessage(content=input_val)]
         # If value is a single message, convert to a list
         if isinstance(input_val, BaseMessage):
@@ -489,8 +483,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):  # type: ignore[no-redef]
     def _get_output_messages(
         self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
     ) -> list[BaseMessage]:
-        from langchain_core.messages import BaseMessage
-
         # If dictionary, try to pluck the single key representing messages
         if isinstance(output_val, dict):
             if self.output_messages_key:
@@ -507,8 +499,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):  # type: ignore[no-redef]
                 output_val = output_val[key]
 
         if isinstance(output_val, str):
-            from langchain_core.messages import AIMessage
-
             return [AIMessage(content=output_val)]
         # If value is a single message, convert to a list
         if isinstance(output_val, BaseMessage):
diff --git a/libs/core/langchain_core/sys_info.py b/libs/core/langchain_core/sys_info.py
index b7073ffb9cd..dcd0a178bf3 100644
--- a/libs/core/langchain_core/sys_info.py
+++ b/libs/core/langchain_core/sys_info.py
@@ -4,13 +4,15 @@
 sys_info prints information about the system and langchain packages
 for debugging purposes.
""" +import pkgutil +import platform +import sys from collections.abc import Sequence +from importlib import metadata, util def _get_sub_deps(packages: Sequence[str]) -> list[str]: """Get any specified sub-dependencies.""" - from importlib import metadata - sub_deps = set() underscored_packages = {pkg.replace("-", "_") for pkg in packages} @@ -37,11 +39,6 @@ def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None: Args: additional_pkgs: Additional packages to include in the output. """ - import pkgutil - import platform - import sys - from importlib import metadata, util - # Packages that do not start with "langchain" prefix. other_langchain_packages = [ "langserve", diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index 5b150604516..61c3549cfb7 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -18,13 +18,14 @@ from uuid import UUID, uuid4 from typing_extensions import NotRequired, override -from langchain_core.callbacks.base import AsyncCallbackHandler +from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackManager from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk from langchain_core.outputs import ( ChatGenerationChunk, GenerationChunk, LLMResult, ) +from langchain_core.runnables import ensure_config from langchain_core.runnables.schema import ( CustomStreamEvent, EventData, @@ -37,6 +38,11 @@ from langchain_core.runnables.utils import ( _RootEventFilter, ) from langchain_core.tracers._streaming import _StreamingCallbackHandler +from langchain_core.tracers.log_stream import ( + LogStreamCallbackHandler, + RunLog, + _astream_log_implementation, +) from langchain_core.tracers.memory_stream import _MemoryStream from langchain_core.utils.aiter import aclosing, py_anext @@ -769,14 +775,6 @@ async def _astream_events_implementation_v1( exclude_tags: Optional[Sequence[str]] = None, **kwargs: Any, ) -> AsyncIterator[StandardStreamEvent]: - from langchain_core.runnables import ensure_config - from langchain_core.runnables.utils import _RootEventFilter - from langchain_core.tracers.log_stream import ( - LogStreamCallbackHandler, - RunLog, - _astream_log_implementation, - ) - stream = LogStreamCallbackHandler( auto_close=False, include_names=include_names, @@ -954,9 +952,6 @@ async def _astream_events_implementation_v2( **kwargs: Any, ) -> AsyncIterator[StandardStreamEvent]: """Implementation of the astream events API for V2 runnables.""" - from langchain_core.callbacks.base import BaseCallbackManager - from langchain_core.runnables import ensure_config - event_streamer = _AstreamEventsCallbackHandler( include_names=include_names, include_types=include_types, diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 2821ade0dc5..923f0275683 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -7,6 +7,7 @@ import contextlib import copy import threading from collections import defaultdict +from pprint import pformat from typing import ( TYPE_CHECKING, Any, @@ -20,10 +21,11 @@ from typing import ( import jsonpatch # type: ignore[import-untyped] from typing_extensions import NotRequired, TypedDict, override +from langchain_core.callbacks.base import BaseCallbackManager from langchain_core.load import dumps from langchain_core.load.load import load from langchain_core.outputs import 
diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py
index 2821ade0dc5..923f0275683 100644
--- a/libs/core/langchain_core/tracers/log_stream.py
+++ b/libs/core/langchain_core/tracers/log_stream.py
@@ -7,6 +7,7 @@ import contextlib
 import copy
 import threading
 from collections import defaultdict
+from pprint import pformat
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -20,10 +21,11 @@ from typing import (
 
 import jsonpatch  # type: ignore[import-untyped]
 from typing_extensions import NotRequired, TypedDict, override
 
+from langchain_core.callbacks.base import BaseCallbackManager
 from langchain_core.load import dumps
 from langchain_core.load.load import load
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
-from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
+from langchain_core.runnables import RunnableConfig, ensure_config
 from langchain_core.tracers._streaming import _StreamingCallbackHandler
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.memory_stream import _MemoryStream
@@ -32,6 +34,7 @@ if TYPE_CHECKING:
     from collections.abc import AsyncIterator, Iterator, Sequence
     from uuid import UUID
 
+    from langchain_core.runnables import Runnable
     from langchain_core.runnables.utils import Input, Output
     from langchain_core.tracers.schemas import Run
 
@@ -131,8 +134,6 @@ class RunLogPatch:
 
     @override
     def __repr__(self) -> str:
-        from pprint import pformat
-
         # 1:-1 to get rid of the [] around the list
         return f"RunLogPatch({pformat(self.ops)[1:-1]})"
 
@@ -181,8 +182,6 @@ class RunLog(RunLogPatch):
 
     @override
     def __repr__(self) -> str:
-        from pprint import pformat
-
         return f"RunLog({pformat(self.state)})"
 
     @override
@@ -672,14 +671,6 @@ async def _astream_log_implementation(
     Yields:
         The run log patches or states, depending on the value of ``diff``.
     """
-    import jsonpatch
-
-    from langchain_core.callbacks.base import BaseCallbackManager
-    from langchain_core.tracers.log_stream import (
-        RunLog,
-        RunLogPatch,
-    )
-
     # Assign the stream handler to the config
     config = ensure_config(config)
     callbacks = config.get("callbacks")
""" - from langchain_core.tools.base import create_schema_from_function - func_name = _get_python_function_name(function) - model = create_schema_from_function( + model = langchain_core.tools.base.create_schema_from_function( func_name, function, filter_args=(), @@ -264,9 +264,6 @@ def _convert_any_typed_dicts_to_pydantic( visited: dict, depth: int = 0, ) -> type: - from pydantic.v1 import Field as Field_v1 - from pydantic.v1 import create_model as create_model_v1 - if type_ in visited: return visited[type_] if depth >= _MAX_TYPED_DICT_RECURSION: @@ -297,7 +294,7 @@ def _convert_any_typed_dicts_to_pydantic( raise ValueError(msg) if arg_desc := arg_descriptions.get(arg): field_kwargs["description"] = arg_desc - fields[arg] = (new_arg_type, Field_v1(**field_kwargs)) + fields[arg] = (new_arg_type, Field(**field_kwargs)) else: new_arg_type = _convert_any_typed_dicts_to_pydantic( arg_type, depth=depth + 1, visited=visited @@ -305,8 +302,8 @@ def _convert_any_typed_dicts_to_pydantic( field_kwargs = {"default": ...} if arg_desc := arg_descriptions.get(arg): field_kwargs["description"] = arg_desc - fields[arg] = (new_arg_type, Field_v1(**field_kwargs)) - model = create_model_v1(typed_dict.__name__, **fields) + fields[arg] = (new_arg_type, Field(**field_kwargs)) + model = create_model(typed_dict.__name__, **fields) model.__doc__ = description visited[typed_dict] = model return model @@ -332,9 +329,9 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: Returns: The function description. """ - from langchain_core.tools import simple - - is_simple_oai_tool = isinstance(tool, simple.Tool) and not tool.args_schema + is_simple_oai_tool = ( + isinstance(tool, langchain_core.tools.simple.Tool) and not tool.args_schema + ) if tool.tool_call_schema and not is_simple_oai_tool: if isinstance(tool.tool_call_schema, dict): return _convert_json_schema_to_openai_function( @@ -435,8 +432,6 @@ def convert_to_openai_function( 'description' and 'parameters' keys are now optional. Only 'name' is required and guaranteed to be part of the output. """ - from langchain_core.tools import BaseTool - # an Anthropic format tool if isinstance(function, dict) and all( k in function for k in ("name", "input_schema") @@ -476,7 +471,7 @@ def convert_to_openai_function( oai_function = cast( "dict", _convert_typed_dict_to_openai_function(cast("type", function)) ) - elif isinstance(function, BaseTool): + elif isinstance(function, langchain_core.tools.base.BaseTool): oai_function = cast("dict", _format_tool_to_openai_function(function)) elif callable(function): oai_function = cast( @@ -582,7 +577,8 @@ def convert_to_openai_tool( Added support for OpenAI's image generation built-in tool. """ - from langchain_core.tools import Tool + # Import locally to prevent circular import + from langchain_core.tools import Tool # noqa: PLC0415 if isinstance(tool, dict): if tool.get("type") in _WellKnownOpenAITools: diff --git a/libs/core/langchain_core/utils/interactive_env.py b/libs/core/langchain_core/utils/interactive_env.py index b974b3091eb..305b8edc146 100644 --- a/libs/core/langchain_core/utils/interactive_env.py +++ b/libs/core/langchain_core/utils/interactive_env.py @@ -1,5 +1,7 @@ """Utilities for working with interactive environments.""" +import sys + def is_interactive_env() -> bool: """Determine if running within IPython or Jupyter. @@ -7,6 +9,4 @@ def is_interactive_env() -> bool: Returns: True if running in an interactive environment, False otherwise. 
""" - import sys - return hasattr(sys, "ps2") diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py index 888136b8929..9c06612b2ee 100644 --- a/libs/core/langchain_core/vectorstores/base.py +++ b/libs/core/langchain_core/vectorstores/base.py @@ -38,6 +38,7 @@ from typing import ( from pydantic import ConfigDict, Field, model_validator from typing_extensions import Self, override +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams from langchain_core.runnables.config import run_in_executor @@ -49,7 +50,6 @@ if TYPE_CHECKING: AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) - from langchain_core.documents import Document logger = logging.getLogger(__name__) @@ -85,9 +85,6 @@ class VectorStore(ABC): ValueError: If the number of ids does not match the number of texts. """ if type(self).add_documents != VectorStore.add_documents: - # Import document in local scope to avoid circular imports - from langchain_core.documents import Document - # This condition is triggered if the subclass has provided # an implementation of the upsert method. # The existing add_texts @@ -234,9 +231,6 @@ class VectorStore(ABC): # For backward compatibility kwargs["ids"] = ids if type(self).aadd_documents != VectorStore.aadd_documents: - # Import document in local scope to avoid circular imports - from langchain_core.documents import Document - # This condition is triggered if the subclass has provided # an implementation of the upsert method. # The existing add_texts diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py index 0395478dfe9..d08a68c8fc2 100644 --- a/libs/core/langchain_core/vectorstores/in_memory.py +++ b/libs/core/langchain_core/vectorstores/in_memory.py @@ -27,6 +27,13 @@ if TYPE_CHECKING: from langchain_core.embeddings import Embeddings from langchain_core.indexing import UpsertResponse +try: + import numpy as np + + _HAS_NUMPY = True +except ImportError: + _HAS_NUMPY = False + class InMemoryVectorStore(VectorStore): """In-memory vector store implementation. @@ -496,14 +503,12 @@ class InMemoryVectorStore(VectorStore): filter=filter, ) - try: - import numpy as np - except ImportError as e: + if not _HAS_NUMPY: msg = ( "numpy must be installed to use max_marginal_relevance_search " "pip install numpy" ) - raise ImportError(msg) from e + raise ImportError(msg) mmr_chosen_indices = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index 645306bccde..1f68269915a 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -10,9 +10,21 @@ import logging import warnings from typing import TYPE_CHECKING, Union -if TYPE_CHECKING: +try: import numpy as np + _HAS_NUMPY = True +except ImportError: + _HAS_NUMPY = False + +try: + import simsimd as simd # type: ignore[import-not-found] + + _HAS_SIMSIMD = True +except ImportError: + _HAS_SIMSIMD = False + +if TYPE_CHECKING: Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray] logger = logging.getLogger(__name__) @@ -33,14 +45,12 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: ValueError: If the number of columns in X and Y are not the same. ImportError: If numpy is not installed. 
""" - try: - import numpy as np - except ImportError as e: + if not _HAS_NUMPY: msg = ( "cosine_similarity requires numpy to be installed. " "Please install numpy with `pip install numpy`." ) - raise ImportError(msg) from e + raise ImportError(msg) if len(x) == 0 or len(y) == 0: return np.array([[]]) @@ -70,9 +80,7 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: f"and Y has shape {y.shape}." ) raise ValueError(msg) - try: - import simsimd as simd # type: ignore[import-not-found] - except ImportError: + if not _HAS_SIMSIMD: logger.debug( "Unable to import simsimd, defaulting to NumPy implementation. If you want " "to use simsimd please install with `pip install simsimd`." @@ -113,14 +121,12 @@ def maximal_marginal_relevance( Raises: ImportError: If numpy is not installed. """ - try: - import numpy as np - except ImportError as e: + if not _HAS_NUMPY: msg = ( "maximal_marginal_relevance requires numpy to be installed. " "Please install numpy with `pip install numpy`." ) - raise ImportError(msg) from e + raise ImportError(msg) if min(k, len(embedding_list)) <= 0: return [] diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 8f48ba5d5e8..50140edb77b 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -114,7 +114,6 @@ ignore = [ "BLE", # Blind exceptions "DOC", # Docstrings (preview) "ERA", # No commented-out code - "PLC0415", # Imports outside top level "PLR2004", # Comparison to magic number ] unfixable = ["PLW1510",] diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py index 39e4babc782..b8816fa1898 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py @@ -12,6 +12,7 @@ from langchain_core.language_models.fake_chat_models import ( FakeListChatModel, GenericFakeChatModel, ) +from langchain_core.load import dumps from langchain_core.messages import AIMessage from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs.chat_result import ChatResult @@ -318,8 +319,6 @@ def test_cache_with_generation_objects() -> None: cache = InMemoryCache() # Create a simple fake chat model that we can control - from langchain_core.messages import AIMessage - class SimpleFakeChat: """Simple fake chat model for testing.""" @@ -332,8 +331,6 @@ def test_cache_with_generation_objects() -> None: def generate_response(self, prompt: str) -> ChatResult: """Simulate the cache lookup and generation logic.""" - from langchain_core.load import dumps - llm_string = self._get_llm_string() prompt_str = dumps([prompt]) diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py index c4d6a50f6be..2cbc0001c3a 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py @@ -5,6 +5,7 @@ from blockbuster import BlockBuster from langchain_core.caches import InMemoryCache from langchain_core.language_models import GenericFakeChatModel +from langchain_core.load import dumps from langchain_core.rate_limiters import InMemoryRateLimiter @@ -229,8 +230,6 @@ class SerializableModel(GenericFakeChatModel): def test_serialization_with_rate_limiter() -> None: """Test model serialization with rate limiter.""" - from langchain_core.load import 
diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml
index 8f48ba5d5e8..50140edb77b 100644
--- a/libs/core/pyproject.toml
+++ b/libs/core/pyproject.toml
@@ -114,7 +114,6 @@ ignore = [
     "BLE",     # Blind exceptions
     "DOC",     # Docstrings (preview)
     "ERA",     # No commented-out code
-    "PLC0415", # Imports outside top level
     "PLR2004", # Comparison to magic number
 ]
 unfixable = ["PLW1510",]
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
index 39e4babc782..b8816fa1898 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
@@ -12,6 +12,7 @@ from langchain_core.language_models.fake_chat_models import (
     FakeListChatModel,
     GenericFakeChatModel,
 )
+from langchain_core.load import dumps
 from langchain_core.messages import AIMessage
 from langchain_core.outputs import ChatGeneration, Generation
 from langchain_core.outputs.chat_result import ChatResult
@@ -318,8 +319,6 @@ def test_cache_with_generation_objects() -> None:
     cache = InMemoryCache()
 
     # Create a simple fake chat model that we can control
-    from langchain_core.messages import AIMessage
-
     class SimpleFakeChat:
         """Simple fake chat model for testing."""
 
@@ -332,8 +331,6 @@ def test_cache_with_generation_objects() -> None:
 
         def generate_response(self, prompt: str) -> ChatResult:
             """Simulate the cache lookup and generation logic."""
-            from langchain_core.load import dumps
-
             llm_string = self._get_llm_string()
             prompt_str = dumps([prompt])
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py
index c4d6a50f6be..2cbc0001c3a 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py
@@ -5,6 +5,7 @@ from blockbuster import BlockBuster
 
 from langchain_core.caches import InMemoryCache
 from langchain_core.language_models import GenericFakeChatModel
+from langchain_core.load import dumps
 from langchain_core.rate_limiters import InMemoryRateLimiter
 
 
@@ -229,8 +230,6 @@ class SerializableModel(GenericFakeChatModel):
 
 def test_serialization_with_rate_limiter() -> None:
     """Test model serialization with rate limiter."""
-    from langchain_core.load import dumps
-
     model = SerializableModel(
         messages=iter(["hello", "world", "!"]),
         rate_limiter=InMemoryRateLimiter(
diff --git a/libs/core/tests/unit_tests/load/test_serializable.py b/libs/core/tests/unit_tests/load/test_serializable.py
index e782709fd41..8630df1b47e 100644
--- a/libs/core/tests/unit_tests/load/test_serializable.py
+++ b/libs/core/tests/unit_tests/load/test_serializable.py
@@ -1,7 +1,7 @@
 import json
 
 import pytest
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, SecretStr
 
 from langchain_core.load import Serializable, dumpd, dumps, load
 from langchain_core.load.serializable import _is_field_useful
@@ -62,9 +62,6 @@ def test_simple_serialization_is_serializable() -> None:
 
 def test_simple_serialization_secret() -> None:
     """Test handling of secrets."""
-    from pydantic import SecretStr
-
-    from langchain_core.load import Serializable
 
     class Foo(Serializable):
         bar: int
diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py
index 7258f5bd20f..edcf72e4f0f 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py
@@ -1,6 +1,7 @@
 from collections.abc import AsyncIterator, Iterator
 from typing import Any
 
+import pydantic
 import pytest
 from pydantic import BaseModel, Field, ValidationError
 
@@ -802,7 +803,6 @@ async def test_partial_pydantic_output_parser_async() -> None:
 
 def test_parse_with_different_pydantic_2_v1() -> None:
     """Test with pydantic.v1.BaseModel from pydantic 2."""
-    import pydantic
 
     class Forecast(pydantic.v1.BaseModel):
         temperature: int
@@ -836,9 +836,8 @@ def test_parse_with_different_pydantic_2_v1() -> None:
 
 def test_parse_with_different_pydantic_2_proper() -> None:
     """Test with pydantic.BaseModel from pydantic 2."""
-    import pydantic
 
-    class Forecast(pydantic.BaseModel):
+    class Forecast(BaseModel):
         temperature: int
         forecast: str
diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py
index a72aa926780..07acf94d7c2 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py
@@ -189,8 +189,6 @@ def test_pydantic_output_parser_type_inference() -> None:
 
 def test_format_instructions_preserves_language() -> None:
     """Test format instructions does not attempt to encode into ascii."""
-    from pydantic import BaseModel, Field
-
     description = (
         "你好, こんにちは, नमस्ते, Bonjour, Hola, "
         "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"  # noqa: RUF001
diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py
index 09d26438cd4..6c3e052089a 100644
--- a/libs/core/tests/unit_tests/prompts/test_prompt.py
+++ b/libs/core/tests/unit_tests/prompts/test_prompt.py
@@ -1,6 +1,7 @@
 """Test functionality related to prompts."""
 
 import re
+from tempfile import NamedTemporaryFile
 from typing import Any, Union
 from unittest import mock
 
@@ -32,8 +33,6 @@ def test_from_file_encoding() -> None:
     input_variables = ["foo"]
 
     # First write to a file using CP-1252 encoding.
-    from tempfile import NamedTemporaryFile
-
     with NamedTemporaryFile(delete=True, mode="w", encoding="cp1252") as f:
         f.write(template)
         f.flush()
@@ -434,11 +433,9 @@ Will it get confused{ }?
     assert prompt == expected_prompt
 
 
-@pytest.mark.requires("jinja2")
 def test_basic_sandboxing_with_jinja2() -> None:
     """Test basic sandboxing with jinja2."""
-    import jinja2
-
+    jinja2 = pytest.importorskip("jinja2")
     template = " {{''.__class__.__bases__[0] }} "  # malicious code
     prompt = PromptTemplate.from_template(template, template_format="jinja2")
     with pytest.raises(jinja2.exceptions.SecurityError):
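test_prompt.py replaces the custom @pytest.mark.requires("jinja2") marker plus function-local import with pytest.importorskip, which skips the test when the package is absent and returns the module object when it is present. A sketch of the same idiom; the test name and body are illustrative, not patch code:

import pytest


def test_sandbox_blocks_dunder_access() -> None:
    # Skips (rather than errors) when jinja2 is not installed, and hands back
    # the imported module for use in the test body.
    sandbox = pytest.importorskip("jinja2.sandbox")
    jinja2 = pytest.importorskip("jinja2")
    env = sandbox.SandboxedEnvironment()
    with pytest.raises(jinja2.exceptions.SecurityError):
        env.from_string("{{ ''.__class__.__bases__[0] }}").render()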
diff --git a/libs/core/tests/unit_tests/runnables/test_concurrency.py b/libs/core/tests/unit_tests/runnables/test_concurrency.py
index 24d4fad5d23..60022b803d7 100644
--- a/libs/core/tests/unit_tests/runnables/test_concurrency.py
+++ b/libs/core/tests/unit_tests/runnables/test_concurrency.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import time
+from threading import Lock
 from typing import Any
 
 import pytest
@@ -80,7 +81,6 @@ def test_batch_concurrency() -> None:
     """Test that batch respects max_concurrency."""
     running_tasks = 0
     max_running_tasks = 0
-    from threading import Lock
 
     lock = Lock()
 
@@ -112,7+112,6 @@ def test_batch_as_completed_concurrency() -> None:
     """Test that batch_as_completed respects max_concurrency."""
     running_tasks = 0
     max_running_tasks = 0
-    from threading import Lock
 
     lock = Lock()
diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py
index 0a4b5326ff3..249b4d6e45e 100644
--- a/libs/core/tests/unit_tests/runnables/test_history.py
+++ b/libs/core/tests/unit_tests/runnables/test_history.py
@@ -4,7 +4,7 @@ from typing import Any, Callable, Optional, Union
 
 import pytest
 from packaging import version
-from pydantic import BaseModel
+from pydantic import BaseModel, RootModel
 from typing_extensions import override
 
 from langchain_core.callbacks import (
@@ -20,6 +20,11 @@ from langchain_core.runnables.config import RunnableConfig
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
 from langchain_core.tracers import Run
+from langchain_core.tracers.root_listeners import (
+    AsyncListener,
+    AsyncRootListenersTracer,
+    RootListenersTracer,
+)
 from langchain_core.utils.pydantic import PYDANTIC_VERSION
 from tests.unit_tests.pydantic_utils import _schema
 
@@ -499,8 +504,6 @@ def test_get_output_schema() -> None:
 
 def test_get_input_schema_input_messages() -> None:
-    from pydantic import RootModel
-
     runnable_with_message_history_input = RootModel[Sequence[BaseMessage]]
 
     runnable = RunnableLambda(
@@ -776,8 +779,6 @@ def test_ignore_session_id() -> None:
 
 class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
-    from langchain_core.tracers.root_listeners import AsyncListener
-
     def with_listeners(
         self,
         *,
@@ -791,8 +792,6 @@ class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
             Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
         ] = None,
     ) -> Runnable[Input, Output]:
-        from langchain_core.tracers.root_listeners import RootListenersTracer
-
         def create_tracer(config: RunnableConfig) -> RunnableConfig:
             tracer = RootListenersTracer(
                 config=config,
@@ -817,8 +816,6 @@ class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
         on_end: Optional[AsyncListener] = None,
         on_error: Optional[AsyncListener] = None,
     ) -> Runnable[Input, Output]:
-        from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
-
         def create_tracer(config: RunnableConfig) -> RunnableConfig:
             tracer = AsyncRootListenersTracer(
                 config=config,
diff --git a/libs/core/tests/unit_tests/runnables/test_imports.py b/libs/core/tests/unit_tests/runnables/test_imports.py
index f0ebf48c446..e40ffc1fa71 100644
--- a/libs/core/tests/unit_tests/runnables/test_imports.py
+++ b/libs/core/tests/unit_tests/runnables/test_imports.py
@@ -40,6 +40,6 @@ def test_all_imports() -> None:
 
 def test_imports_for_specific_funcs() -> None:
     """Test that a few specific imports in more internal namespaces."""
     # create_model implementation has been moved to langchain_core.utils.pydantic
-    from langchain_core.runnables.utils import (  # type: ignore[attr-defined] # noqa: F401
+    from langchain_core.runnables.utils import (  # type: ignore[attr-defined] # noqa: F401,PLC0415
         create_model,
     )
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py
index 14a5334a844..36ef2f80834 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable.py
@@ -1,6 +1,7 @@
 import asyncio
 import re
 import sys
+import time
 import uuid
 import warnings
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
@@ -17,6 +18,7 @@ from pytest_mock import MockerFixture
 from syrupy.assertion import SnapshotAssertion
 from typing_extensions import TypedDict, override
 
+from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -29,6 +31,7 @@ from langchain_core.language_models import (
     FakeListLLM,
     FakeStreamingListLLM,
 )
+from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
 from langchain_core.load import dumpd, dumps
 from langchain_core.load.load import loads
 from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage
@@ -5516,9 +5519,6 @@ async def test_passthrough_atransform_with_dicts() -> None:
 
 def test_listeners() -> None:
-    from langchain_core.runnables import RunnableLambda
-    from langchain_core.tracers.schemas import Run
-
     def fake_chain(inputs: dict) -> dict:
         return {**inputs, "key": "extra"}
 
@@ -5546,9 +5546,6 @@ def test_listeners() -> None:
 
 async def test_listeners_async() -> None:
-    from langchain_core.runnables import RunnableLambda
-    from langchain_core.tracers.schemas import Run
-
     def fake_chain(inputs: dict) -> dict:
         return {**inputs, "key": "extra"}
 
@@ -5578,12 +5575,6 @@ async def test_listeners_async() -> None:
 
 def test_closing_iterator_doesnt_raise_error() -> None:
     """Test that closing an iterator calls on_chain_end rather than on_chain_error."""
-    import time
-
-    from langchain_core.callbacks import BaseCallbackHandler
-    from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
-    from langchain_core.output_parsers import StrOutputParser
-
     on_chain_error_triggered = False
     on_chain_end_triggered = False
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
index a7731053032..ec6451aeac8 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
@@ -20,6 +20,7 @@ from pydantic import BaseModel
 from typing_extensions import override
 
 from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
+from langchain_core.callbacks.manager import adispatch_custom_event
 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.documents import Document
 from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel
@@ -2548,7 +2549,6 @@ async def test_cancel_astream_events() -> None:
 
 async def test_custom_event() -> None:
     """Test adhoc event."""
-    from langchain_core.callbacks.manager import adispatch_custom_event
 
     # Ignoring type due to RunnableLamdba being dynamic when it comes to being
     # applied as a decorator to async functions.
@@ -2625,7 +2625,6 @@ async def test_custom_event() -> None:
 
 async def test_custom_event_nested() -> None:
     """Test adhoc event in a nested chain."""
-    from langchain_core.callbacks.manager import adispatch_custom_event
 
     # Ignoring type due to RunnableLamdba being dynamic when it comes to being
     # applied as a decorator to async functions.
@@ -2736,7 +2735,6 @@ async def test_custom_event_root_dispatch() -> None:
     # This just tests that nothing breaks on the path.
     # It shouldn't do anything at the moment, since the tracer isn't configured
     # to handle adhoc events.
-    from langchain_core.callbacks.manager import adispatch_custom_event
 
     # Expected behavior is that the event cannot be dispatched
     with pytest.raises(RuntimeError):
@@ -2750,8 +2748,6 @@ IS_GTE_3_11 = sys.version_info >= (3, 11)
 @pytest.mark.skipif(not IS_GTE_3_11, reason="Requires Python >=3.11")
 async def test_custom_event_root_dispatch_with_in_tool() -> None:
     """Test adhoc event in a nested chain."""
-    from langchain_core.callbacks.manager import adispatch_custom_event
-    from langchain_core.tools import tool
 
     @tool
     async def foo(x: int) -> int:
diff --git a/libs/core/tests/unit_tests/test_globals.py b/libs/core/tests/unit_tests/test_globals.py
index a0b2457bb48..58d6ad38ab8 100644
--- a/libs/core/tests/unit_tests/test_globals.py
+++ b/libs/core/tests/unit_tests/test_globals.py
@@ -1,18 +1,17 @@
+import langchain_core
+from langchain_core.callbacks.manager import _get_debug
 from langchain_core.globals import get_debug, set_debug
 
 
 def test_debug_is_settable_via_setter() -> None:
-    from langchain_core import globals as globals_
-    from langchain_core.callbacks.manager import _get_debug
-
-    previous_value = globals_._debug
+    previous_value = langchain_core.globals._debug
     previous_fn_reading = _get_debug()
     assert previous_value == previous_fn_reading
 
     # Flip the value of the flag.
     set_debug(not previous_value)
 
-    new_value = globals_._debug
+    new_value = langchain_core.globals._debug
     new_fn_reading = _get_debug()
 
     try:
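test_globals.py swaps the aliased `from langchain_core import globals as globals_` for a plain `import langchain_core` plus attribute access. The detail worth keeping: set_debug rebinds the module-level _debug, so the test must read it through the module object each time; a from-import of the value would capture a stale copy. A compressed sketch of why that matters (reading the private _debug mirrors the test itself; not patch code):

import langchain_core
from langchain_core.globals import get_debug, set_debug

before = langchain_core.globals._debug  # live lookup, not a snapshot
set_debug(not before)
assert langchain_core.globals._debug == (not before) == get_debug()
set_debug(before)  # restore global state for other tests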
""" - # Check with whatever pydantic model is passed in and not via v1 namespace - from pydantic import BaseModel class Foo(BaseModel): a: int @@ -1964,7 +1961,6 @@ def test_args_schema_explicitly_typed() -> None: @pytest.mark.parametrize("pydantic_model", TEST_MODELS) def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None: """This should test that one can type the args schema as a pydantic model.""" - from langchain_core.tools import StructuredTool def foo(a: int, b: str) -> str: """Hahaha.""" @@ -2063,16 +2059,13 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None: A = TypeVar("A") if use_v1_namespace: - from pydantic.v1 import BaseModel as BaseModel1 - class ModelA(BaseModel1, Generic[A], extra="allow"): + class ModelA(BaseModelV1, Generic[A], extra="allow"): a: A else: - from pydantic import BaseModel as BaseModel2 - from pydantic import ConfigDict - class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef] + class ModelA(BaseModel, Generic[A]): # type: ignore[no-redef] a: A model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") @@ -2208,12 +2201,8 @@ def test_create_retriever_tool() -> None: def test_tool_args_schema_pydantic_v2_with_metadata() -> None: - from pydantic import BaseModel as BaseModelV2 - from pydantic import Field as FieldV2 - from pydantic import ValidationError as ValidationErrorV2 - - class Foo(BaseModelV2): - x: list[int] = FieldV2( + class Foo(BaseModel): + x: list[int] = Field( description="List of integers", min_length=10, max_length=15 ) @@ -2240,7 +2229,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None: } assert foo.invoke({"x": [0] * 10}) - with pytest.raises(ValidationErrorV2): + with pytest.raises(ValidationError): foo.invoke({"x": [0] * 9}) @@ -2576,8 +2565,6 @@ def test_title_property_preserved() -> None: https://github.com/langchain-ai/langchain/issues/30456 """ - from langchain_core.tools import tool - schema_to_be_extracted = { "type": "object", "required": [], diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py index dadce0e3f7e..3afeca68c73 100644 --- a/libs/core/tests/unit_tests/utils/test_pydantic.py +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -3,7 +3,8 @@ import warnings from typing import Any, Optional -from pydantic import ConfigDict +from pydantic import BaseModel, ConfigDict, Field +from pydantic.v1 import BaseModel as BaseModelV1 from langchain_core.utils.pydantic import ( _create_subset_model_v2, @@ -16,8 +17,6 @@ from langchain_core.utils.pydantic import ( def test_pre_init_decorator() -> None: - from pydantic import BaseModel - class Foo(BaseModel): x: int = 5 y: int @@ -35,8 +34,6 @@ def test_pre_init_decorator() -> None: def test_pre_init_decorator_with_more_defaults() -> None: - from pydantic import BaseModel, Field - class Foo(BaseModel): a: int = 1 b: Optional[int] = None @@ -56,8 +53,6 @@ def test_pre_init_decorator_with_more_defaults() -> None: def test_with_aliases() -> None: - from pydantic import BaseModel, Field - class Foo(BaseModel): x: int = Field(default=1, alias="y") z: int @@ -92,19 +87,14 @@ def test_with_aliases() -> None: def test_is_basemodel_subclass() -> None: """Test pydantic.""" - from pydantic import BaseModel as BaseModelV2 - from pydantic.v1 import BaseModel as BaseModelV1 - - assert is_basemodel_subclass(BaseModelV2) + assert is_basemodel_subclass(BaseModel) assert is_basemodel_subclass(BaseModelV1) def test_is_basemodel_instance() -> 
diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py
index dadce0e3f7e..3afeca68c73 100644
--- a/libs/core/tests/unit_tests/utils/test_pydantic.py
+++ b/libs/core/tests/unit_tests/utils/test_pydantic.py
@@ -3,7 +3,8 @@
 import warnings
 from typing import Any, Optional
 
-from pydantic import ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
+from pydantic.v1 import BaseModel as BaseModelV1
 
 from langchain_core.utils.pydantic import (
     _create_subset_model_v2,
@@ -16,8 +17,6 @@ from langchain_core.utils.pydantic import (
 
 def test_pre_init_decorator() -> None:
-    from pydantic import BaseModel
-
     class Foo(BaseModel):
         x: int = 5
         y: int
@@ -35,8 +34,6 @@ def test_pre_init_decorator() -> None:
 
 def test_pre_init_decorator_with_more_defaults() -> None:
-    from pydantic import BaseModel, Field
-
     class Foo(BaseModel):
         a: int = 1
         b: Optional[int] = None
@@ -56,8 +53,6 @@ def test_pre_init_decorator_with_more_defaults() -> None:
 
 def test_with_aliases() -> None:
-    from pydantic import BaseModel, Field
-
     class Foo(BaseModel):
         x: int = Field(default=1, alias="y")
         z: int
@@ -92,19 +87,14 @@ def test_with_aliases() -> None:
 
 def test_is_basemodel_subclass() -> None:
     """Test pydantic."""
-    from pydantic import BaseModel as BaseModelV2
-    from pydantic.v1 import BaseModel as BaseModelV1
-
-    assert is_basemodel_subclass(BaseModelV2)
+    assert is_basemodel_subclass(BaseModel)
     assert is_basemodel_subclass(BaseModelV1)
 
 
 def test_is_basemodel_instance() -> None:
     """Test pydantic."""
-    from pydantic import BaseModel as BaseModelV2
-    from pydantic.v1 import BaseModel as BaseModelV1
-
-    class Foo(BaseModelV2):
+    class Foo(BaseModel):
         x: int
 
     assert is_basemodel_instance(Foo(x=5))
@@ -117,11 +107,9 @@ def test_is_basemodel_instance() -> None:
 
 def test_with_field_metadata() -> None:
     """Test pydantic with field metadata."""
-    from pydantic import BaseModel as BaseModelV2
-    from pydantic import Field as FieldV2
-
-    class Foo(BaseModelV2):
-        x: list[int] = FieldV2(
+    class Foo(BaseModel):
+        x: list[int] = Field(
             description="List of integers", min_length=10, max_length=15
         )
 
@@ -144,8 +132,6 @@ def test_with_field_metadata() -> None:
 
 def test_fields_pydantic_v2_proper() -> None:
-    from pydantic import BaseModel
-
     class Foo(BaseModel):
         x: int
 
@@ -154,9 +140,7 @@ def test_fields_pydantic_v2_proper() -> None:
 
 def test_fields_pydantic_v1_from_2() -> None:
-    from pydantic.v1 import BaseModel
-
-    class Foo(BaseModel):
+    class Foo(BaseModelV1):
         x: int
 
     fields = get_fields(Foo)
diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py
index 51967b82a19..1e6a88b6375 100644
--- a/libs/core/tests/unit_tests/utils/test_utils.py
+++ b/libs/core/tests/unit_tests/utils/test_utils.py
@@ -6,7 +6,9 @@ from typing import Any, Callable, Optional, Union
 from unittest.mock import patch
 
 import pytest
-from pydantic import SecretStr
+from pydantic import BaseModel, Field, SecretStr
+from pydantic.v1 import BaseModel as PydanticV1BaseModel
+from pydantic.v1 import Field as PydanticV1Field
 
 from langchain_core import utils
 from langchain_core.outputs import GenerationChunk
@@ -212,13 +214,10 @@ def test_guard_import_failure(
 
 def test_get_pydantic_field_names_v1_in_2() -> None:
-    from pydantic.v1 import BaseModel as PydanticV1BaseModel
-    from pydantic.v1 import Field
-
     class PydanticV1Model(PydanticV1BaseModel):
         field1: str
         field2: int
-        alias_field: int = Field(alias="aliased_field")
+        alias_field: int = PydanticV1Field(alias="aliased_field")
 
     result = get_pydantic_field_names(PydanticV1Model)
     expected = {"field1", "field2", "aliased_field", "alias_field"}
@@ -226,8 +225,6 @@ def test_get_pydantic_field_names_v1_in_2() -> None:
 
 def test_get_pydantic_field_names_v2_in_2() -> None:
-    from pydantic import BaseModel, Field
-
     class PydanticModel(BaseModel):
         field1: str
         field2: int
@@ -341,8 +338,6 @@ def test_secret_from_env_with_custom_error_message(
 
 def test_using_secret_from_env_as_default_factory(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    from pydantic import BaseModel, Field
-
     class Foo(BaseModel):
         secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))
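Across the test suite the per-test pydantic imports collapse into one top-of-file block: v2 names are imported bare, and the pydantic.v1 compatibility names carry a V1 suffix, so a reader can tell at a glance which generation a model class exercises. A minimal, self-contained sketch of that convention (the model names are illustrative, not from the patch):

from pydantic import BaseModel, Field
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as FieldV1


class ModelV2(BaseModel):
    x: int = Field(description="a v2 field")


class ModelV1(BaseModelV1):
    x: int = FieldV1(description="a v1 field")


if __name__ == "__main__":
    # v2 models expose model_dump(); the v1 shim keeps the old .dict() API.
    print(ModelV2(x=1).model_dump(), ModelV1(x=1).dict())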