chore(core): add ruff rule PLC0415 (#32351)

See https://docs.astral.sh/ruff/rules/import-outside-top-level/

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet
2025-09-08 20:15:04 +02:00
committed by GitHub
parent 16420cad71
commit cc98fb9bee
49 changed files with 360 additions and 412 deletions

View File

@@ -28,8 +28,18 @@ from langchain_core.callbacks.base import (
ToolManagerMixin, ToolManagerMixin,
) )
from langchain_core.callbacks.stdout import StdOutCallbackHandler from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.globals import get_debug
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.tracers.context import (
_configure_hooks,
_get_trace_callbacks,
_get_tracer_project,
_tracing_v2_is_enabled,
tracing_v2_callback_var,
)
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.schemas import Run from langchain_core.tracers.schemas import Run
from langchain_core.tracers.stdout import ConsoleCallbackHandler
from langchain_core.utils.env import env_var_is_set from langchain_core.utils.env import env_var_is_set
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -46,8 +56,6 @@ logger = logging.getLogger(__name__)
def _get_debug() -> bool: def _get_debug() -> bool:
from langchain_core.globals import get_debug
return get_debug() return get_debug()
@@ -103,8 +111,6 @@ def trace_as_chain_group(
manager.on_chain_end({"output": res}) manager.on_chain_end({"output": res})
""" """
from langchain_core.tracers.context import _get_trace_callbacks
cb = _get_trace_callbacks( cb = _get_trace_callbacks(
project_name, example_id, callback_manager=callback_manager project_name, example_id, callback_manager=callback_manager
) )
@@ -189,8 +195,6 @@ async def atrace_as_chain_group(
await manager.on_chain_end({"output": res}) await manager.on_chain_end({"output": res})
""" """
from langchain_core.tracers.context import _get_trace_callbacks
cb = _get_trace_callbacks( cb = _get_trace_callbacks(
project_name, example_id, callback_manager=callback_manager project_name, example_id, callback_manager=callback_manager
) )
@@ -2376,13 +2380,6 @@ def _configure(
Returns: Returns:
T: The configured callback manager. T: The configured callback manager.
""" """
from langchain_core.tracers.context import (
_configure_hooks,
_get_tracer_project,
_tracing_v2_is_enabled,
tracing_v2_callback_var,
)
tracing_context = get_tracing_context() tracing_context = get_tracing_context()
tracing_metadata = tracing_context["metadata"] tracing_metadata = tracing_context["metadata"]
tracing_tags = tracing_context["tags"] tracing_tags = tracing_context["tags"]
@@ -2459,9 +2456,6 @@ def _configure(
tracer_project = _get_tracer_project() tracer_project = _get_tracer_project()
debug = _get_debug() debug = _get_debug()
if verbose or debug or tracing_v2_enabled_: if verbose or debug or tracing_v2_enabled_:
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.stdout import ConsoleCallbackHandler
if verbose and not any( if verbose and not any(
isinstance(handler, StdOutCallbackHandler) isinstance(handler, StdOutCallbackHandler)
for handler in callback_manager.handlers for handler in callback_manager.handlers
@@ -2630,7 +2624,8 @@ async def adispatch_custom_event(
.. versionadded:: 0.2.15 .. versionadded:: 0.2.15
""" """
from langchain_core.runnables.config import ( # Import locally to prevent circular imports.
from langchain_core.runnables.config import ( # noqa: PLC0415
ensure_config, ensure_config,
get_async_callback_manager_for_config, get_async_callback_manager_for_config,
) )
@@ -2705,7 +2700,8 @@ def dispatch_custom_event(
.. versionadded:: 0.2.15 .. versionadded:: 0.2.15
""" """
from langchain_core.runnables.config import ( # Import locally to prevent circular imports.
from langchain_core.runnables.config import ( # noqa: PLC0415
ensure_config, ensure_config,
get_callback_manager_for_config, get_callback_manager_for_config,
) )

View File

@@ -12,6 +12,7 @@ from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata, add_usage from langchain_core.messages.ai import UsageMetadata, add_usage
from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.tracers.context import register_configure_hook
class UsageMetadataCallbackHandler(BaseCallbackHandler): class UsageMetadataCallbackHandler(BaseCallbackHandler):
@@ -133,8 +134,6 @@ def get_usage_metadata_callback(
.. versionadded:: 0.3.49 .. versionadded:: 0.3.49
""" """
from langchain_core.tracers.context import register_configure_hook
usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = ( usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
ContextVar(name, default=None) ContextVar(name, default=None)
) )

View File

@@ -27,6 +27,7 @@ from langchain_core.messages import (
HumanMessage, HumanMessage,
get_buffer_string, get_buffer_string,
) )
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence from collections.abc import Sequence
@@ -113,8 +114,6 @@ class BaseChatMessageHistory(ABC):
Returns: Returns:
The messages. The messages.
""" """
from langchain_core.runnables.config import run_in_executor
return await run_in_executor(None, lambda: self.messages) return await run_in_executor(None, lambda: self.messages)
def add_user_message(self, message: Union[HumanMessage, str]) -> None: def add_user_message(self, message: Union[HumanMessage, str]) -> None:
@@ -190,8 +189,6 @@ class BaseChatMessageHistory(ABC):
Args: Args:
messages: A sequence of BaseMessage objects to store. messages: A sequence of BaseMessage objects to store.
""" """
from langchain_core.runnables.config import run_in_executor
await run_in_executor(None, self.add_messages, messages) await run_in_executor(None, self.add_messages, messages)
@abstractmethod @abstractmethod
@@ -200,8 +197,6 @@ class BaseChatMessageHistory(ABC):
async def aclear(self) -> None: async def aclear(self) -> None:
"""Async remove all messages from the store.""" """Async remove all messages from the store."""
from langchain_core.runnables.config import run_in_executor
await run_in_executor(None, self.clear) await run_in_executor(None, self.clear)
def __str__(self) -> str: def __str__(self) -> str:

View File

@@ -15,6 +15,13 @@ if TYPE_CHECKING:
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.documents.base import Blob from langchain_core.documents.base import Blob
try:
from langchain_text_splitters import RecursiveCharacterTextSplitter
_HAS_TEXT_SPLITTERS = True
except ImportError:
_HAS_TEXT_SPLITTERS = False
class BaseLoader(ABC): # noqa: B024 class BaseLoader(ABC): # noqa: B024
"""Interface for Document Loader. """Interface for Document Loader.
@@ -62,15 +69,13 @@ class BaseLoader(ABC): # noqa: B024
List of Documents. List of Documents.
""" """
if text_splitter is None: if text_splitter is None:
try: if not _HAS_TEXT_SPLITTERS:
from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError as e:
msg = ( msg = (
"Unable to import from langchain_text_splitters. Please specify " "Unable to import from langchain_text_splitters. Please specify "
"text_splitter or install langchain_text_splitters with " "text_splitter or install langchain_text_splitters with "
"`pip install -U langchain-text-splitters`." "`pip install -U langchain-text-splitters`."
) )
raise ImportError(msg) from e raise ImportError(msg)
text_splitter_: TextSplitter = RecursiveCharacterTextSplitter() text_splitter_: TextSplitter = RecursiveCharacterTextSplitter()
else: else:

View File

@@ -1,6 +1,7 @@
"""Module contains a few fake embedding models for testing purposes.""" """Module contains a few fake embedding models for testing purposes."""
# Please do not add additional fake embedding model implementations here. # Please do not add additional fake embedding model implementations here.
import contextlib
import hashlib import hashlib
from pydantic import BaseModel from pydantic import BaseModel
@@ -8,6 +9,9 @@ from typing_extensions import override
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
with contextlib.suppress(ImportError):
import numpy as np
class FakeEmbeddings(Embeddings, BaseModel): class FakeEmbeddings(Embeddings, BaseModel):
"""Fake embedding model for unit testing purposes. """Fake embedding model for unit testing purposes.
@@ -54,8 +58,6 @@ class FakeEmbeddings(Embeddings, BaseModel):
"""The size of the embedding vector.""" """The size of the embedding vector."""
def _get_embedding(self) -> list[float]: def _get_embedding(self) -> list[float]:
import numpy as np
return list(np.random.default_rng().normal(size=self.size)) return list(np.random.default_rng().normal(size=self.size))
@override @override
@@ -113,8 +115,6 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
"""The size of the embedding vector.""" """The size of the embedding vector."""
def _get_embedding(self, seed: int) -> list[float]: def _get_embedding(self, seed: int) -> list[float]:
import numpy as np
# set the seed for the random generator # set the seed for the random generator
rng = np.random.default_rng(seed) rng = np.random.default_rng(seed)
return list(rng.normal(size=self.size)) return list(rng.normal(size=self.size))

View File

@@ -3,6 +3,8 @@
import platform import platform
from functools import lru_cache from functools import lru_cache
from langchain_core import __version__
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def get_runtime_environment() -> dict: def get_runtime_environment() -> dict:
@@ -11,9 +13,6 @@ def get_runtime_environment() -> dict:
Returns: Returns:
A dictionary with information about the runtime environment. A dictionary with information about the runtime environment.
""" """
# Lazy import to avoid circular imports
from langchain_core import __version__
return { return {
"library_version": __version__, "library_version": __version__,
"library": "langchain-core", "library": "langchain-core",

View File

@@ -6,6 +6,13 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.caches import BaseCache from langchain_core.caches import BaseCache
try:
import langchain # type: ignore[import-not-found]
_HAS_LANGCHAIN = True
except ImportError:
_HAS_LANGCHAIN = False
# DO NOT USE THESE VALUES DIRECTLY! # DO NOT USE THESE VALUES DIRECTLY!
# Use them only via `get_<X>()` and `set_<X>()` below, # Use them only via `get_<X>()` and `set_<X>()` below,
@@ -22,9 +29,7 @@ def set_verbose(value: bool) -> None: # noqa: FBT001
Args: Args:
value: The new value for the `verbose` global setting. value: The new value for the `verbose` global setting.
""" """
try: if _HAS_LANGCHAIN:
import langchain # type: ignore[import-not-found]
# We're about to run some deprecated code, don't report warnings from it. # We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get # The user called the correct (non-deprecated) code path and shouldn't get
# warnings. # warnings.
@@ -43,8 +48,6 @@ def set_verbose(value: bool) -> None: # noqa: FBT001
# Remove it once `langchain.verbose` is no longer supported, and once all # Remove it once `langchain.verbose` is no longer supported, and once all
# users have migrated to using `set_verbose()` here. # users have migrated to using `set_verbose()` here.
langchain.verbose = value langchain.verbose = value
except ImportError:
pass
global _verbose # noqa: PLW0603 global _verbose # noqa: PLW0603
_verbose = value _verbose = value
@@ -56,9 +59,7 @@ def get_verbose() -> bool:
Returns: Returns:
The value of the `verbose` global setting. The value of the `verbose` global setting.
""" """
try: if _HAS_LANGCHAIN:
import langchain
# We're about to run some deprecated code, don't report warnings from it. # We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get # The user called the correct (non-deprecated) code path and shouldn't get
# warnings. # warnings.
@@ -83,7 +84,7 @@ def get_verbose() -> bool:
# deprecation warnings directing them to use `set_verbose()` when they # deprecation warnings directing them to use `set_verbose()` when they
# import `langchain.verbose`. # import `langchain.verbose`.
old_verbose = langchain.verbose old_verbose = langchain.verbose
except ImportError: else:
old_verbose = False old_verbose = False
return _verbose or old_verbose return _verbose or old_verbose
@@ -95,9 +96,7 @@ def set_debug(value: bool) -> None: # noqa: FBT001
Args: Args:
value: The new value for the `debug` global setting. value: The new value for the `debug` global setting.
""" """
try: if _HAS_LANGCHAIN:
import langchain
# We're about to run some deprecated code, don't report warnings from it. # We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get # The user called the correct (non-deprecated) code path and shouldn't get
# warnings. # warnings.
@@ -114,8 +113,6 @@ def set_debug(value: bool) -> None: # noqa: FBT001
# Remove it once `langchain.debug` is no longer supported, and once all # Remove it once `langchain.debug` is no longer supported, and once all
# users have migrated to using `set_debug()` here. # users have migrated to using `set_debug()` here.
langchain.debug = value langchain.debug = value
except ImportError:
pass
global _debug # noqa: PLW0603 global _debug # noqa: PLW0603
_debug = value _debug = value
@@ -127,9 +124,7 @@ def get_debug() -> bool:
Returns: Returns:
The value of the `debug` global setting. The value of the `debug` global setting.
""" """
try: if _HAS_LANGCHAIN:
import langchain
# We're about to run some deprecated code, don't report warnings from it. # We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get # The user called the correct (non-deprecated) code path and shouldn't get
# warnings. # warnings.
@@ -151,7 +146,7 @@ def get_debug() -> bool:
# to using `set_debug()` yet. Those users are getting deprecation warnings # to using `set_debug()` yet. Those users are getting deprecation warnings
# directing them to use `set_debug()` when they import `langchain.debug`. # directing them to use `set_debug()` when they import `langchain.debug`.
old_debug = langchain.debug old_debug = langchain.debug
except ImportError: else:
old_debug = False old_debug = False
return _debug or old_debug return _debug or old_debug
@@ -163,9 +158,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
Args: Args:
value: The new LLM cache to use. If `None`, the LLM cache is disabled. value: The new LLM cache to use. If `None`, the LLM cache is disabled.
""" """
try: if _HAS_LANGCHAIN:
import langchain
# We're about to run some deprecated code, don't report warnings from it. # We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get # The user called the correct (non-deprecated) code path and shouldn't get
# warnings. # warnings.
@@ -184,8 +177,6 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
# Remove it once `langchain.llm_cache` is no longer supported, and # Remove it once `langchain.llm_cache` is no longer supported, and
# once all users have migrated to using `set_llm_cache()` here. # once all users have migrated to using `set_llm_cache()` here.
langchain.llm_cache = value langchain.llm_cache = value
except ImportError:
pass
global _llm_cache # noqa: PLW0603 global _llm_cache # noqa: PLW0603
_llm_cache = value _llm_cache = value
@@ -197,9 +188,7 @@ def get_llm_cache() -> Optional["BaseCache"]:
Returns: Returns:
The value of the `llm_cache` global setting. The value of the `llm_cache` global setting.
""" """
try: if _HAS_LANGCHAIN:
import langchain
# We're about to run some deprecated code, don't report warnings from it. # We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get # The user called the correct (non-deprecated) code path and shouldn't get
# warnings. # warnings.
@@ -225,7 +214,7 @@ def get_llm_cache() -> Optional["BaseCache"]:
# Those users are getting deprecation warnings directing them # Those users are getting deprecation warnings directing them
# to use `set_llm_cache()` when they import `langchain.llm_cache`. # to use `set_llm_cache()` when they import `langchain.llm_cache`.
old_llm_cache = langchain.llm_cache old_llm_cache = langchain.llm_cache
except ImportError: else:
old_llm_cache = None old_llm_cache = None
return _llm_cache or old_llm_cache return _llm_cache or old_llm_cache

View File

@@ -22,19 +22,31 @@ from typing_extensions import TypeAlias, TypedDict, override
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.caches import BaseCache from langchain_core.caches import BaseCache
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
from langchain_core.globals import get_verbose
from langchain_core.messages import ( from langchain_core.messages import (
AnyMessage, AnyMessage,
BaseMessage, BaseMessage,
MessageLikeRepresentation, MessageLikeRepresentation,
get_buffer_string, get_buffer_string,
) )
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import (
ChatPromptValueConcrete,
PromptValue,
StringPromptValue,
)
from langchain_core.runnables import Runnable, RunnableSerializable from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names from langchain_core.utils import get_pydantic_field_names
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.outputs import LLMResult from langchain_core.outputs import LLMResult
try:
from transformers import GPT2TokenizerFast # type: ignore[import-not-found]
_HAS_TRANSFORMERS = True
except ImportError:
_HAS_TRANSFORMERS = False
class LangSmithParams(TypedDict, total=False): class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing.""" """LangSmith parameters for tracing."""
@@ -66,15 +78,13 @@ def get_tokenizer() -> Any:
The GPT-2 tokenizer instance. The GPT-2 tokenizer instance.
""" """
try: if not _HAS_TRANSFORMERS:
from transformers import GPT2TokenizerFast # type: ignore[import-not-found]
except ImportError as e:
msg = ( msg = (
"Could not import transformers python package. " "Could not import transformers python package. "
"This is needed in order to calculate get_token_ids. " "This is needed in order to calculate get_token_ids. "
"Please install it with `pip install transformers`." "Please install it with `pip install transformers`."
) )
raise ImportError(msg) from e raise ImportError(msg)
# create a GPT-2 tokenizer instance # create a GPT-2 tokenizer instance
return GPT2TokenizerFast.from_pretrained("gpt2") return GPT2TokenizerFast.from_pretrained("gpt2")
@@ -95,8 +105,6 @@ LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
def _get_verbosity() -> bool: def _get_verbosity() -> bool:
from langchain_core.globals import get_verbose
return get_verbose() return get_verbose()
@@ -158,11 +166,6 @@ class BaseLanguageModel(
@override @override
def InputType(self) -> TypeAlias: def InputType(self) -> TypeAlias:
"""Get the input type for this runnable.""" """Get the input type for this runnable."""
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
StringPromptValue,
)
# This is a version of LanguageModelInput which replaces the abstract # This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes # base class BaseMessage with a union of its subclasses, which makes
# for a much better schema. # for a much better schema.

View File

@@ -46,6 +46,10 @@ from langchain_core.messages import (
message_chunk_to_message, message_chunk_to_message,
) )
from langchain_core.messages.ai import _LC_ID_PREFIX from langchain_core.messages.ai import _LC_ID_PREFIX
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ( from langchain_core.outputs import (
ChatGeneration, ChatGeneration,
ChatGenerationChunk, ChatGenerationChunk,
@@ -1590,11 +1594,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
msg = f"Received unsupported arguments {kwargs}" msg = f"Received unsupported arguments {kwargs}"
raise ValueError(msg) raise ValueError(msg)
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
if type(self).bind_tools is BaseChatModel.bind_tools: if type(self).bind_tools is BaseChatModel.bind_tools:
msg = "with_structured_output is not implemented for this model." msg = "with_structured_output is not implemented for this model."
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@@ -6,6 +6,8 @@ from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from langchain_core.load.serializable import Serializable, to_json_not_implemented from langchain_core.load.serializable import Serializable, to_json_not_implemented
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
def default(obj: Any) -> Any: def default(obj: Any) -> Any:
@@ -23,9 +25,6 @@ def default(obj: Any) -> Any:
def _dump_pydantic_models(obj: Any) -> Any: def _dump_pydantic_models(obj: Any) -> Any:
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
if ( if (
isinstance(obj, ChatGeneration) isinstance(obj, ChatGeneration)
and isinstance(obj.message, AIMessage) and isinstance(obj.message, AIMessage)

View File

@@ -118,7 +118,8 @@ class BaseMessage(Serializable):
Returns: Returns:
A ChatPromptTemplate containing both messages. A ChatPromptTemplate containing both messages.
""" """
from langchain_core.prompts.chat import ChatPromptTemplate # Import locally to prevent circular imports.
from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415
prompt = ChatPromptTemplate(messages=[self]) prompt = ChatPromptTemplate(messages=[self])
return prompt + other return prompt + other

View File

@@ -42,12 +42,17 @@ from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_text_splitters import TextSplitter
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import PromptValue
from langchain_core.runnables.base import Runnable from langchain_core.runnables.base import Runnable
try:
from langchain_text_splitters import TextSplitter
_HAS_LANGCHAIN_TEXT_SPLITTERS = True
except ImportError:
_HAS_LANGCHAIN_TEXT_SPLITTERS = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -361,7 +366,7 @@ def convert_to_messages(
list of messages (BaseMessages). list of messages (BaseMessages).
""" """
# Import here to avoid circular imports # Import here to avoid circular imports
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import PromptValue # noqa: PLC0415
if isinstance(messages, PromptValue): if isinstance(messages, PromptValue):
return messages.to_messages() return messages.to_messages()
@@ -386,7 +391,8 @@ def _runnable_support(func: Callable) -> Callable:
list[BaseMessage], list[BaseMessage],
Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]], Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]],
]: ]:
from langchain_core.runnables.base import RunnableLambda # Import locally to prevent circular import.
from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415
if messages is not None: if messages is not None:
return func(messages, **kwargs) return func(messages, **kwargs)
@@ -989,17 +995,12 @@ def trim_messages(
) )
raise ValueError(msg) raise ValueError(msg)
try: if _HAS_LANGCHAIN_TEXT_SPLITTERS and isinstance(text_splitter, TextSplitter):
from langchain_text_splitters import TextSplitter text_splitter_fn = text_splitter.split_text
except ImportError: elif text_splitter:
text_splitter_fn: Optional[Callable] = cast("Optional[Callable]", text_splitter) text_splitter_fn = cast("Callable", text_splitter)
else: else:
if isinstance(text_splitter, TextSplitter): text_splitter_fn = _default_text_splitter
text_splitter_fn = text_splitter.split_text
else:
text_splitter_fn = text_splitter
text_splitter_fn = text_splitter_fn or _default_text_splitter
if strategy == "first": if strategy == "first":
return _first_max_tokens( return _first_max_tokens(

View File

@@ -15,6 +15,14 @@ from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict from langchain_core.runnables.utils import AddableDict
try:
from defusedxml import ElementTree # type: ignore[import-untyped]
from defusedxml.ElementTree import XMLParser # type: ignore[import-untyped]
_HAS_DEFUSEDXML = True
except ImportError:
_HAS_DEFUSEDXML = False
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file. XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
1. Output should conform to the tags below. 1. Output should conform to the tags below.
2. If tags are not given, make them on your own. 2. If tags are not given, make them on your own.
@@ -50,17 +58,13 @@ class _StreamingParser:
parser is requested. parser is requested.
""" """
if parser == "defusedxml": if parser == "defusedxml":
try: if not _HAS_DEFUSEDXML:
from defusedxml.ElementTree import ( # type: ignore[import-untyped]
XMLParser,
)
except ImportError as e:
msg = ( msg = (
"defusedxml is not installed. " "defusedxml is not installed. "
"Please install it to use the defusedxml parser." "Please install it to use the defusedxml parser."
"You can install it with `pip install defusedxml` " "You can install it with `pip install defusedxml` "
) )
raise ImportError(msg) from e raise ImportError(msg)
parser_ = XMLParser(target=TreeBuilder()) parser_ = XMLParser(target=TreeBuilder())
else: else:
parser_ = None parser_ = None
@@ -207,16 +211,14 @@ class XMLOutputParser(BaseTransformOutputParser):
# Imports are temporarily placed here to avoid issue with caching on CI # Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file # likely if you're reading this you can move them to the top of the file
if self.parser == "defusedxml": if self.parser == "defusedxml":
try: if not _HAS_DEFUSEDXML:
from defusedxml import ElementTree # type: ignore[import-untyped]
except ImportError as e:
msg = ( msg = (
"defusedxml is not installed. " "defusedxml is not installed. "
"Please install it to use the defusedxml parser." "Please install it to use the defusedxml parser."
"You can install it with `pip install defusedxml`" "You can install it with `pip install defusedxml`"
"See https://github.com/tiran/defusedxml for more details" "See https://github.com/tiran/defusedxml for more details"
) )
raise ImportError(msg) from e raise ImportError(msg)
et = ElementTree # Use the defusedxml parser et = ElementTree # Use the defusedxml parser
else: else:
et = ET # Use the standard library parser et = ET # Use the standard library parser

View File

@@ -88,7 +88,8 @@ class BaseMessagePromptTemplate(Serializable, ABC):
Returns: Returns:
Combined prompt template. Combined prompt template.
""" """
from langchain_core.prompts.chat import ChatPromptTemplate # Import locally to avoid circular import.
from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415
prompt = ChatPromptTemplate(messages=[self]) prompt = ChatPromptTemplate(messages=[self])
return prompt + other return prompt + other

View File

@@ -15,6 +15,14 @@ from langchain_core.utils import get_colored_text, mustache
from langchain_core.utils.formatting import formatter from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env from langchain_core.utils.interactive_env import is_interactive_env
try:
from jinja2 import Environment, meta
from jinja2.sandbox import SandboxedEnvironment
_HAS_JINJA2 = True
except ImportError:
_HAS_JINJA2 = False
PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"] PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]
@@ -40,9 +48,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
Raises: Raises:
ImportError: If jinja2 is not installed. ImportError: If jinja2 is not installed.
""" """
try: if not _HAS_JINJA2:
from jinja2.sandbox import SandboxedEnvironment
except ImportError as e:
msg = ( msg = (
"jinja2 not installed, which is needed to use the jinja2_formatter. " "jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`." "Please install it with `pip install jinja2`."
@@ -50,7 +56,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
"Do not expand jinja2 templates using unverified or user-controlled " "Do not expand jinja2 templates using unverified or user-controlled "
"inputs as that can result in arbitrary Python code execution." "inputs as that can result in arbitrary Python code execution."
) )
raise ImportError(msg) from e raise ImportError(msg)
# This uses a sandboxed environment to prevent arbitrary code execution. # This uses a sandboxed environment to prevent arbitrary code execution.
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing. # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
@@ -88,14 +94,12 @@ def validate_jinja2(template: str, input_variables: list[str]) -> None:
def _get_jinja2_variables_from_template(template: str) -> set[str]: def _get_jinja2_variables_from_template(template: str) -> set[str]:
try: if not _HAS_JINJA2:
from jinja2 import Environment, meta
except ImportError as e:
msg = ( msg = (
"jinja2 not installed, which is needed to use the jinja2_formatter. " "jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`." "Please install it with `pip install jinja2`."
) )
raise ImportError(msg) from e raise ImportError(msg)
env = Environment() # noqa: S701 env = Environment() # noqa: S701
ast = env.parse(template) ast = env.parse(template)
return meta.find_undeclared_variables(ast) return meta.find_undeclared_variables(ast)

View File

@@ -31,6 +31,7 @@ from typing_extensions import Self, TypedDict, override
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.runnables import ( from langchain_core.runnables import (
Runnable, Runnable,
@@ -236,8 +237,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
retriever.invoke("query") retriever.invoke("query")
""" """
from langchain_core.callbacks.manager import CallbackManager
config = ensure_config(config) config = ensure_config(config)
inheritable_metadata = { inheritable_metadata = {
**(config.get("metadata") or {}), **(config.get("metadata") or {}),
@@ -301,8 +300,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
await retriever.ainvoke("query") await retriever.ainvoke("query")
""" """
from langchain_core.callbacks.manager import AsyncCallbackManager
config = ensure_config(config) config = ensure_config(config)
inheritable_metadata = { inheritable_metadata = {
**(config.get("metadata") or {}), **(config.get("metadata") or {}),

View File

@@ -41,6 +41,7 @@ from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args, override from typing_extensions import Literal, get_args, override
from langchain_core._api import beta_decorator from langchain_core._api import beta_decorator
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
from langchain_core.load.serializable import ( from langchain_core.load.serializable import (
Serializable, Serializable,
SerializedConstructor, SerializedConstructor,
@@ -60,7 +61,6 @@ from langchain_core.runnables.config import (
run_in_executor, run_in_executor,
set_config_context, set_config_context,
) )
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
AddableDict, AddableDict,
AnyConfigurableField, AnyConfigurableField,
@@ -81,6 +81,19 @@ from langchain_core.runnables.utils import (
is_async_callable, is_async_callable,
is_async_generator, is_async_generator,
) )
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.event_stream import (
_astream_events_implementation_v1,
_astream_events_implementation_v2,
)
from langchain_core.tracers.log_stream import (
LogStreamCallbackHandler,
_astream_log_implementation,
)
from langchain_core.tracers.root_listeners import (
AsyncRootListenersTracer,
RootListenersTracer,
)
from langchain_core.utils.aiter import aclosing, atee, py_anext from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import create_model_v2 from langchain_core.utils.pydantic import create_model_v2
@@ -94,6 +107,7 @@ if TYPE_CHECKING:
from langchain_core.runnables.fallbacks import ( from langchain_core.runnables.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT, RunnableWithFallbacks as RunnableWithFallbacksT,
) )
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.retry import ExponentialJitterParams from langchain_core.runnables.retry import ExponentialJitterParams
from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@@ -565,6 +579,9 @@ class Runnable(ABC, Generic[Input, Output]):
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
"""Return a graph representation of this ``Runnable``.""" """Return a graph representation of this ``Runnable``."""
# Import locally to prevent circular import
from langchain_core.runnables.graph import Graph # noqa: PLC0415
graph = Graph() graph = Graph()
try: try:
input_node = graph.add_node(self.get_input_schema(config)) input_node = graph.add_node(self.get_input_schema(config))
@@ -585,7 +602,8 @@ class Runnable(ABC, Generic[Input, Output]):
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> list[BasePromptTemplate]: ) -> list[BasePromptTemplate]:
"""Return a list of prompts used by this ``Runnable``.""" """Return a list of prompts used by this ``Runnable``."""
from langchain_core.prompts.base import BasePromptTemplate # Import locally to prevent circular import
from langchain_core.prompts.base import BasePromptTemplate # noqa: PLC0415
return [ return [
node.data node.data
@@ -747,7 +765,8 @@ class Runnable(ABC, Generic[Input, Output]):
a new ``Runnable``. a new ``Runnable``.
""" """
from langchain_core.runnables.passthrough import RunnablePick # Import locally to prevent circular import
from langchain_core.runnables.passthrough import RunnablePick # noqa: PLC0415
return self | RunnablePick(keys) return self | RunnablePick(keys)
@@ -798,7 +817,8 @@ class Runnable(ABC, Generic[Input, Output]):
A new ``Runnable``. A new ``Runnable``.
""" """
from langchain_core.runnables.passthrough import RunnableAssign # Import locally to prevent circular import
from langchain_core.runnables.passthrough import RunnableAssign # noqa: PLC0415
return self | RunnableAssign(RunnableParallel[dict[str, Any]](kwargs)) return self | RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
@@ -1231,11 +1251,6 @@ class Runnable(ABC, Generic[Input, Output]):
A ``RunLogPatch`` or ``RunLog`` object. A ``RunLogPatch`` or ``RunLog`` object.
""" """
from langchain_core.tracers.log_stream import (
LogStreamCallbackHandler,
_astream_log_implementation,
)
stream = LogStreamCallbackHandler( stream = LogStreamCallbackHandler(
auto_close=False, auto_close=False,
include_names=include_names, include_names=include_names,
@@ -1489,11 +1504,6 @@ class Runnable(ABC, Generic[Input, Output]):
NotImplementedError: If the version is not ``'v1'`` or ``'v2'``. NotImplementedError: If the version is not ``'v1'`` or ``'v2'``.
""" # noqa: E501 """ # noqa: E501
from langchain_core.tracers.event_stream import (
_astream_events_implementation_v1,
_astream_events_implementation_v2,
)
if version == "v2": if version == "v2":
event_stream = _astream_events_implementation_v2( event_stream = _astream_events_implementation_v2(
self, self,
@@ -1740,8 +1750,6 @@ class Runnable(ABC, Generic[Input, Output]):
chain.invoke(2) chain.invoke(2)
""" """
from langchain_core.tracers.root_listeners import RootListenersTracer
return RunnableBinding( return RunnableBinding(
bound=self, bound=self,
config_factories=[ config_factories=[
@@ -1834,8 +1842,6 @@ class Runnable(ABC, Generic[Input, Output]):
on end callback ends at 2025-03-01T07:05:30.884831+00:00 on end callback ends at 2025-03-01T07:05:30.884831+00:00
""" """
from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
return RunnableBinding( return RunnableBinding(
bound=self, bound=self,
config_factories=[ config_factories=[
@@ -1928,7 +1934,8 @@ class Runnable(ABC, Generic[Input, Output]):
assert count == 2 assert count == 2
""" """
from langchain_core.runnables.retry import RunnableRetry # Import locally to prevent circular import
from langchain_core.runnables.retry import RunnableRetry # noqa: PLC0415
return RunnableRetry( return RunnableRetry(
bound=self, bound=self,
@@ -2030,7 +2037,10 @@ class Runnable(ABC, Generic[Input, Output]):
fallback in order, upon failures. fallback in order, upon failures.
""" """
from langchain_core.runnables.fallbacks import RunnableWithFallbacks # Import locally to prevent circular import
from langchain_core.runnables.fallbacks import ( # noqa: PLC0415
RunnableWithFallbacks,
)
return RunnableWithFallbacks( return RunnableWithFallbacks(
runnable=self, runnable=self,
@@ -2316,9 +2326,6 @@ class Runnable(ABC, Generic[Input, Output]):
Use this to implement ``stream`` or ``transform`` in ``Runnable`` subclasses. Use this to implement ``stream`` or ``transform`` in ``Runnable`` subclasses.
""" """
# Mixin that is used by both astream log and astream events implementation
from langchain_core.tracers._streaming import _StreamingCallbackHandler
# tee the input so we can iterate over it twice # tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(inputs, 2) input_for_tracing, input_for_transform = tee(inputs, 2)
# Start the input iterator to ensure the input Runnable starts before this one # Start the input iterator to ensure the input Runnable starts before this one
@@ -2422,9 +2429,6 @@ class Runnable(ABC, Generic[Input, Output]):
Use this to implement ``astream`` or ``atransform`` in ``Runnable`` subclasses. Use this to implement ``astream`` or ``atransform`` in ``Runnable`` subclasses.
""" """
# Mixin that is used by both astream log and astream events implementation
from langchain_core.tracers._streaming import _StreamingCallbackHandler
# tee the input so we can iterate over it twice # tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = atee(inputs, 2) input_for_tracing, input_for_transform = atee(inputs, 2)
# Start the input iterator to ensure the input Runnable starts before this one # Start the input iterator to ensure the input Runnable starts before this one
@@ -2614,7 +2618,7 @@ class Runnable(ABC, Generic[Input, Output]):
""" """
# Avoid circular import # Avoid circular import
from langchain_core.tools import convert_runnable_to_tool from langchain_core.tools import convert_runnable_to_tool # noqa: PLC0415
return convert_runnable_to_tool( return convert_runnable_to_tool(
self, self,
@@ -2690,7 +2694,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
) )
""" """
from langchain_core.runnables.configurable import RunnableConfigurableFields # Import locally to prevent circular import
from langchain_core.runnables.configurable import ( # noqa: PLC0415
RunnableConfigurableFields,
)
model_fields = type(self).model_fields model_fields = type(self).model_fields
for key in kwargs: for key in kwargs:
@@ -2751,7 +2758,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
) )
""" """
from langchain_core.runnables.configurable import ( # Import locally to prevent circular import
from langchain_core.runnables.configurable import ( # noqa: PLC0415
RunnableConfigurableAlternatives, RunnableConfigurableAlternatives,
) )
@@ -2767,7 +2775,11 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
def _seq_input_schema( def _seq_input_schema(
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig] steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> type[BaseModel]: ) -> type[BaseModel]:
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick # Import locally to prevent circular import
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
RunnableAssign,
RunnablePick,
)
first = steps[0] first = steps[0]
if len(steps) == 1: if len(steps) == 1:
@@ -2793,7 +2805,11 @@ def _seq_input_schema(
def _seq_output_schema( def _seq_output_schema(
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig] steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> type[BaseModel]: ) -> type[BaseModel]:
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick # Import locally to prevent circular import
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
RunnableAssign,
RunnablePick,
)
last = steps[-1] last = steps[-1]
if len(steps) == 1: if len(steps) == 1:
@@ -3050,7 +3066,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
The config specs of the ``Runnable``. The config specs of the ``Runnable``.
""" """
from langchain_core.beta.runnables.context import ( # Import locally to prevent circular import
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
CONTEXT_CONFIG_PREFIX, CONTEXT_CONFIG_PREFIX,
_key_from_id, _key_from_id,
) )
@@ -3108,7 +3125,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
ValueError: If a ``Runnable`` has no first or last node. ValueError: If a ``Runnable`` has no first or last node.
""" """
from langchain_core.runnables.graph import Graph # Import locally to prevent circular import
from langchain_core.runnables.graph import Graph # noqa: PLC0415
graph = Graph() graph = Graph()
for step in self.steps: for step in self.steps:
@@ -3196,7 +3214,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
from langchain_core.beta.runnables.context import config_with_context # Import locally to prevent circular import
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
config_with_context,
)
# setup callbacks and context # setup callbacks and context
config = config_with_context(ensure_config(config), self.steps) config = config_with_context(ensure_config(config), self.steps)
@@ -3237,7 +3258,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
from langchain_core.beta.runnables.context import aconfig_with_context # Import locally to prevent circular import
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
aconfig_with_context,
)
# setup callbacks and context # setup callbacks and context
config = aconfig_with_context(ensure_config(config), self.steps) config = aconfig_with_context(ensure_config(config), self.steps)
@@ -3281,8 +3305,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> list[Output]: ) -> list[Output]:
from langchain_core.beta.runnables.context import config_with_context # Import locally to prevent circular import
from langchain_core.callbacks.manager import CallbackManager from langchain_core.beta.runnables.context import ( # noqa: PLC0415
config_with_context,
)
if not inputs: if not inputs:
return [] return []
@@ -3411,8 +3437,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> list[Output]: ) -> list[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context # Import locally to prevent circular import
from langchain_core.callbacks.manager import AsyncCallbackManager from langchain_core.beta.runnables.context import ( # noqa: PLC0415
aconfig_with_context,
)
if not inputs: if not inputs:
return [] return []
@@ -3542,7 +3570,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Output]: ) -> Iterator[Output]:
from langchain_core.beta.runnables.context import config_with_context # Import locally to prevent circular import
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
config_with_context,
)
steps = [self.first, *self.middle, self.last] steps = [self.first, *self.middle, self.last]
config = config_with_context(config, self.steps) config = config_with_context(config, self.steps)
@@ -3569,7 +3600,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context # Import locally to prevent circular import
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
aconfig_with_context,
)
steps = [self.first, *self.middle, self.last] steps = [self.first, *self.middle, self.last]
config = aconfig_with_context(config, self.steps) config = aconfig_with_context(config, self.steps)
@@ -3882,7 +3916,8 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
ValueError: If a ``Runnable`` has no first or last node. ValueError: If a ``Runnable`` has no first or last node.
""" """
from langchain_core.runnables.graph import Graph # Import locally to prevent circular import
from langchain_core.runnables.graph import Graph # noqa: PLC0415
graph = Graph() graph = Graph()
input_node = graph.add_node(self.get_input_schema(config)) input_node = graph.add_node(self.get_input_schema(config))
@@ -3918,8 +3953,6 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> dict[str, Any]: ) -> dict[str, Any]:
from langchain_core.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
config = ensure_config(config) config = ensure_config(config)
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
@@ -4767,6 +4800,9 @@ class RunnableLambda(Runnable[Input, Output]):
@override @override
def get_graph(self, config: RunnableConfig | None = None) -> Graph: def get_graph(self, config: RunnableConfig | None = None) -> Graph:
if deps := self.deps: if deps := self.deps:
# Import locally to prevent circular import
from langchain_core.runnables.graph import Graph # noqa: PLC0415
graph = Graph() graph = Graph()
input_node = graph.add_node(self.get_input_schema(config)) input_node = graph.add_node(self.get_input_schema(config))
output_node = graph.add_node(self.get_output_schema(config)) output_node = graph.add_node(self.get_output_schema(config))
@@ -6030,7 +6066,6 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-re
Returns: Returns:
A new ``Runnable`` with the listeners bound. A new ``Runnable`` with the listeners bound.
""" """
from langchain_core.tracers.root_listeners import RootListenersTracer
def listener_config_factory(config: RunnableConfig) -> RunnableConfig: def listener_config_factory(config: RunnableConfig) -> RunnableConfig:
return { return {

View File

@@ -12,6 +12,10 @@ from typing import (
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing_extensions import override from typing_extensions import override
from langchain_core.beta.runnables.context import (
CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET,
)
from langchain_core.runnables.base import ( from langchain_core.runnables.base import (
Runnable, Runnable,
RunnableLike, RunnableLike,
@@ -177,11 +181,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
@property @property
@override @override
def config_specs(self) -> list[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
from langchain_core.beta.runnables.context import (
CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET,
)
specs = get_unique_config_specs( specs = get_unique_config_specs(
spec spec
for step in ( for step in (

View File

@@ -12,21 +12,22 @@ from contextvars import Context, ContextVar, Token, copy_context
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
from langsmith.run_helpers import _set_tracing_context, get_tracing_context
from typing_extensions import ParamSpec, TypedDict from typing_extensions import ParamSpec, TypedDict
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
Input, Input,
Output, Output,
accepts_config, accepts_config,
accepts_run_manager, accepts_run_manager,
) )
from langchain_core.tracers.langchain import LangChainTracer
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.callbacks.base import BaseCallbackManager, Callbacks from langchain_core.callbacks.base import BaseCallbackManager, Callbacks
from langchain_core.callbacks.manager import ( from langchain_core.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
else: else:
@@ -129,8 +130,6 @@ def _set_config_context(
Returns: Returns:
The token to reset the config and the previous tracing context. The token to reset the config and the previous tracing context.
""" """
from langchain_core.tracers.langchain import LangChainTracer
config_token = var_child_runnable_config.set(config) config_token = var_child_runnable_config.set(config)
current_context = None current_context = None
if ( if (
@@ -150,8 +149,6 @@ def _set_config_context(
) )
and (run := tracer.run_map.get(str(parent_run_id))) and (run := tracer.run_map.get(str(parent_run_id)))
): ):
from langsmith.run_helpers import _set_tracing_context, get_tracing_context
current_context = get_tracing_context() current_context = get_tracing_context()
_set_tracing_context({"parent": run}) _set_tracing_context({"parent": run})
return config_token, current_context return config_token, current_context
@@ -167,8 +164,6 @@ def set_config_context(config: RunnableConfig) -> Generator[Context, None, None]
Yields: Yields:
The config context. The config context.
""" """
from langsmith.run_helpers import _set_tracing_context
ctx = copy_context() ctx = copy_context()
config_token, _ = ctx.run(_set_config_context, config) config_token, _ = ctx.run(_set_config_context, config)
try: try:
@@ -481,8 +476,6 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
Returns: Returns:
CallbackManager: The callback manager. CallbackManager: The callback manager.
""" """
from langchain_core.callbacks.manager import CallbackManager
return CallbackManager.configure( return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"), inheritable_tags=config.get("tags"),
@@ -501,8 +494,6 @@ def get_async_callback_manager_for_config(
Returns: Returns:
AsyncCallbackManager: The async callback manager. AsyncCallbackManager: The async callback manager.
""" """
from langchain_core.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure( return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"), inheritable_tags=config.get("tags"),

View File

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing_extensions import override from typing_extensions import override
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
RunnableConfig, RunnableConfig,
@@ -272,8 +273,6 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> list[Output]: ) -> list[Output]:
from langchain_core.callbacks.manager import CallbackManager
if self.exception_key is not None and not all( if self.exception_key is not None and not all(
isinstance(input_, dict) for input_ in inputs isinstance(input_, dict) for input_ in inputs
): ):
@@ -366,8 +365,6 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> list[Output]: ) -> list[Output]:
from langchain_core.callbacks.manager import AsyncCallbackManager
if self.exception_key is not None and not all( if self.exception_key is not None and not all(
isinstance(input_, dict) for input_ in inputs isinstance(input_, dict) for input_ in inputs
): ):

View File

@@ -19,6 +19,8 @@ from typing import (
) )
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -191,8 +193,6 @@ def node_data_str(
Returns: Returns:
A string representation of the data. A string representation of the data.
""" """
from langchain_core.runnables.base import Runnable
if not is_uuid(id) or data is None: if not is_uuid(id) or data is None:
return id return id
data_str = data.get_name() if isinstance(data, Runnable) else data.__name__ data_str = data.get_name() if isinstance(data, Runnable) else data.__name__
@@ -212,9 +212,6 @@ def node_data_json(
Returns: Returns:
A dictionary with the type of the data and the data itself. A dictionary with the type of the data and the data itself.
""" """
from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable
if node.data is None: if node.data is None:
json: dict[str, Any] = {} json: dict[str, Any] = {}
elif isinstance(node.data, RunnableSerializable): elif isinstance(node.data, RunnableSerializable):
@@ -518,7 +515,8 @@ class Graph:
Returns: Returns:
The ASCII art string. The ASCII art string.
""" """
from langchain_core.runnables.graph_ascii import draw_ascii # Import locally to prevent circular import
from langchain_core.runnables.graph_ascii import draw_ascii # noqa: PLC0415
return draw_ascii( return draw_ascii(
{node.id: node.name for node in self.nodes.values()}, {node.id: node.name for node in self.nodes.values()},
@@ -562,7 +560,8 @@ class Graph:
Returns: Returns:
The PNG image as bytes if output_file_path is None, None otherwise. The PNG image as bytes if output_file_path is None, None otherwise.
""" """
from langchain_core.runnables.graph_png import PngDrawer # Import locally to prevent circular import
from langchain_core.runnables.graph_png import PngDrawer # noqa: PLC0415
default_node_labels = {node.id: node.name for node in self.nodes.values()} default_node_labels = {node.id: node.name for node in self.nodes.values()}
@@ -617,7 +616,8 @@ class Graph:
The Mermaid syntax string. The Mermaid syntax string.
""" """
from langchain_core.runnables.graph_mermaid import draw_mermaid # Import locally to prevent circular import
from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415
graph = self.reid() graph = self.reid()
first_node = graph.first_node() first_node = graph.first_node()
@@ -688,7 +688,10 @@ class Graph:
The PNG image as bytes. The PNG image as bytes.
""" """
from langchain_core.runnables.graph_mermaid import draw_mermaid_png # Import locally to prevent circular import
from langchain_core.runnables.graph_mermaid import ( # noqa: PLC0415
draw_mermaid_png,
)
mermaid_syntax = self.draw_mermaid( mermaid_syntax = self.draw_mermaid(
curve_style=curve_style, curve_style=curve_style,

View File

@@ -3,12 +3,24 @@
Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py. Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py.
""" """
from __future__ import annotations
import math import math
import os import os
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any from typing import TYPE_CHECKING, Any
from langchain_core.runnables.graph import Edge as LangEdge try:
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
from grandalf.routing import route_with_lines # type: ignore[import-untyped]
_HAS_GRANDALF = True
except ImportError:
_HAS_GRANDALF = False
if TYPE_CHECKING:
from langchain_core.runnables.graph import Edge as LangEdge
class VertexViewer: class VertexViewer:
@@ -185,13 +197,9 @@ class _EdgeViewer:
def _build_sugiyama_layout( def _build_sugiyama_layout(
vertices: Mapping[str, str], edges: Sequence[LangEdge] vertices: Mapping[str, str], edges: Sequence[LangEdge]
) -> Any: ) -> Any:
try: if not _HAS_GRANDALF:
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
from grandalf.routing import route_with_lines # type: ignore[import-untyped]
except ImportError as exc:
msg = "Install grandalf to draw graphs: `pip install grandalf`." msg = "Install grandalf to draw graphs: `pip install grandalf`."
raise ImportError(msg) from exc raise ImportError(msg)
# #
# Just a reminder about naming conventions: # Just a reminder about naming conventions:

View File

@@ -1,5 +1,7 @@
"""Mermaid graph drawing utilities.""" """Mermaid graph drawing utilities."""
from __future__ import annotations
import asyncio import asyncio
import base64 import base64
import random import random
@@ -7,18 +9,34 @@ import re
import time import time
from dataclasses import asdict from dataclasses import asdict
from pathlib import Path from pathlib import Path
from typing import Any, Literal, Optional from typing import TYPE_CHECKING, Any, Literal, Optional
import yaml import yaml
from langchain_core.runnables.graph import ( from langchain_core.runnables.graph import (
CurveStyle, CurveStyle,
Edge,
MermaidDrawMethod, MermaidDrawMethod,
Node,
NodeStyles, NodeStyles,
) )
if TYPE_CHECKING:
from langchain_core.runnables.graph import Edge, Node
try:
import requests
_HAS_REQUESTS = True
except ImportError:
_HAS_REQUESTS = False
try:
from pyppeteer import launch # type: ignore[import-not-found]
_HAS_PYPPETEER = True
except ImportError:
_HAS_PYPPETEER = False
MARKDOWN_SPECIAL_CHARS = "*_`" MARKDOWN_SPECIAL_CHARS = "*_`"
@@ -283,8 +301,6 @@ def draw_mermaid_png(
ValueError: If an invalid draw method is provided. ValueError: If an invalid draw method is provided.
""" """
if draw_method == MermaidDrawMethod.PYPPETEER: if draw_method == MermaidDrawMethod.PYPPETEER:
import asyncio
img_bytes = asyncio.run( img_bytes = asyncio.run(
_render_mermaid_using_pyppeteer( _render_mermaid_using_pyppeteer(
mermaid_syntax, output_file_path, background_color, padding mermaid_syntax, output_file_path, background_color, padding
@@ -317,11 +333,9 @@ async def _render_mermaid_using_pyppeteer(
device_scale_factor: int = 3, device_scale_factor: int = 3,
) -> bytes: ) -> bytes:
"""Renders Mermaid graph using Pyppeteer.""" """Renders Mermaid graph using Pyppeteer."""
try: if not _HAS_PYPPETEER:
from pyppeteer import launch # type: ignore[import-not-found]
except ImportError as e:
msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`." msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
raise ImportError(msg) from e raise ImportError(msg)
browser = await launch() browser = await launch()
page = await browser.newPage() page = await browser.newPage()
@@ -392,14 +406,12 @@ def _render_mermaid_using_api(
retry_delay: float = 1.0, retry_delay: float = 1.0,
) -> bytes: ) -> bytes:
"""Renders Mermaid graph using the Mermaid.INK API.""" """Renders Mermaid graph using the Mermaid.INK API."""
try: if not _HAS_REQUESTS:
import requests
except ImportError as e:
msg = ( msg = (
"Install the `requests` module to use the Mermaid.INK API: " "Install the `requests` module to use the Mermaid.INK API: "
"`pip install requests`." "`pip install requests`."
) )
raise ImportError(msg) from e raise ImportError(msg)
# Use Mermaid API to render the image # Use Mermaid API to render the image
mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode( mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(

View File

@@ -4,6 +4,13 @@ from typing import Any, Optional
from langchain_core.runnables.graph import Graph, LabelsDict from langchain_core.runnables.graph import Graph, LabelsDict
try:
import pygraphviz as pgv # type: ignore[import-not-found]
_HAS_PYGRAPHVIZ = True
except ImportError:
_HAS_PYGRAPHVIZ = False
class PngDrawer: class PngDrawer:
"""Helper class to draw a state graph into a PNG file. """Helper class to draw a state graph into a PNG file.
@@ -125,11 +132,9 @@ class PngDrawer:
Returns: Returns:
The PNG bytes if ``output_path`` is None, else None. The PNG bytes if ``output_path`` is None, else None.
""" """
try: if not _HAS_PYGRAPHVIZ:
import pygraphviz as pgv # type: ignore[import-not-found]
except ImportError as exc:
msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`." msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`."
raise ImportError(msg) from exc raise ImportError(msg)
# Create a directed graph # Create a directed graph
viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0) viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)

View File

@@ -18,6 +18,7 @@ from typing_extensions import override
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load from langchain_core.load.load import load
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
@@ -29,7 +30,6 @@ from langchain_core.utils.pydantic import create_model_v2
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.language_models.base import LanguageModelLike from langchain_core.language_models.base import LanguageModelLike
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from langchain_core.tracers.schemas import Run from langchain_core.tracers.schemas import Run
@@ -384,8 +384,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]: ) -> type[BaseModel]:
from langchain_core.messages import BaseMessage
fields: dict = {} fields: dict = {}
if self.input_messages_key and self.history_messages_key: if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = ( fields[self.input_messages_key] = (
@@ -447,8 +445,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
def _get_input_messages( def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> list[BaseMessage]: ) -> list[BaseMessage]:
from langchain_core.messages import BaseMessage
# If dictionary, try to pluck the single key representing messages # If dictionary, try to pluck the single key representing messages
if isinstance(input_val, dict): if isinstance(input_val, dict):
if self.input_messages_key: if self.input_messages_key:
@@ -461,8 +457,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
# If value is a string, convert to a human message # If value is a string, convert to a human message
if isinstance(input_val, str): if isinstance(input_val, str):
from langchain_core.messages import HumanMessage
return [HumanMessage(content=input_val)] return [HumanMessage(content=input_val)]
# If value is a single message, convert to a list # If value is a single message, convert to a list
if isinstance(input_val, BaseMessage): if isinstance(input_val, BaseMessage):
@@ -489,8 +483,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
def _get_output_messages( def _get_output_messages(
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> list[BaseMessage]: ) -> list[BaseMessage]:
from langchain_core.messages import BaseMessage
# If dictionary, try to pluck the single key representing messages # If dictionary, try to pluck the single key representing messages
if isinstance(output_val, dict): if isinstance(output_val, dict):
if self.output_messages_key: if self.output_messages_key:
@@ -507,8 +499,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
output_val = output_val[key] output_val = output_val[key]
if isinstance(output_val, str): if isinstance(output_val, str):
from langchain_core.messages import AIMessage
return [AIMessage(content=output_val)] return [AIMessage(content=output_val)]
# If value is a single message, convert to a list # If value is a single message, convert to a list
if isinstance(output_val, BaseMessage): if isinstance(output_val, BaseMessage):

View File

@@ -4,13 +4,15 @@ sys_info prints information about the system and langchain packages for
debugging purposes. debugging purposes.
""" """
import pkgutil
import platform
import sys
from collections.abc import Sequence from collections.abc import Sequence
from importlib import metadata, util
def _get_sub_deps(packages: Sequence[str]) -> list[str]: def _get_sub_deps(packages: Sequence[str]) -> list[str]:
"""Get any specified sub-dependencies.""" """Get any specified sub-dependencies."""
from importlib import metadata
sub_deps = set() sub_deps = set()
underscored_packages = {pkg.replace("-", "_") for pkg in packages} underscored_packages = {pkg.replace("-", "_") for pkg in packages}
@@ -37,11 +39,6 @@ def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None:
Args: Args:
additional_pkgs: Additional packages to include in the output. additional_pkgs: Additional packages to include in the output.
""" """
import pkgutil
import platform
import sys
from importlib import metadata, util
# Packages that do not start with "langchain" prefix. # Packages that do not start with "langchain" prefix.
other_langchain_packages = [ other_langchain_packages = [
"langserve", "langserve",

View File

@@ -18,13 +18,14 @@ from uuid import UUID, uuid4
from typing_extensions import NotRequired, override from typing_extensions import NotRequired, override
from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
from langchain_core.outputs import ( from langchain_core.outputs import (
ChatGenerationChunk, ChatGenerationChunk,
GenerationChunk, GenerationChunk,
LLMResult, LLMResult,
) )
from langchain_core.runnables import ensure_config
from langchain_core.runnables.schema import ( from langchain_core.runnables.schema import (
CustomStreamEvent, CustomStreamEvent,
EventData, EventData,
@@ -37,6 +38,11 @@ from langchain_core.runnables.utils import (
_RootEventFilter, _RootEventFilter,
) )
from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.log_stream import (
LogStreamCallbackHandler,
RunLog,
_astream_log_implementation,
)
from langchain_core.tracers.memory_stream import _MemoryStream from langchain_core.tracers.memory_stream import _MemoryStream
from langchain_core.utils.aiter import aclosing, py_anext from langchain_core.utils.aiter import aclosing, py_anext
@@ -769,14 +775,6 @@ async def _astream_events_implementation_v1(
exclude_tags: Optional[Sequence[str]] = None, exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[StandardStreamEvent]: ) -> AsyncIterator[StandardStreamEvent]:
from langchain_core.runnables import ensure_config
from langchain_core.runnables.utils import _RootEventFilter
from langchain_core.tracers.log_stream import (
LogStreamCallbackHandler,
RunLog,
_astream_log_implementation,
)
stream = LogStreamCallbackHandler( stream = LogStreamCallbackHandler(
auto_close=False, auto_close=False,
include_names=include_names, include_names=include_names,
@@ -954,9 +952,6 @@ async def _astream_events_implementation_v2(
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[StandardStreamEvent]: ) -> AsyncIterator[StandardStreamEvent]:
"""Implementation of the astream events API for V2 runnables.""" """Implementation of the astream events API for V2 runnables."""
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.runnables import ensure_config
event_streamer = _AstreamEventsCallbackHandler( event_streamer = _AstreamEventsCallbackHandler(
include_names=include_names, include_names=include_names,
include_types=include_types, include_types=include_types,

View File

@@ -7,6 +7,7 @@ import contextlib
import copy import copy
import threading import threading
from collections import defaultdict from collections import defaultdict
from pprint import pformat
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@@ -20,10 +21,11 @@ from typing import (
import jsonpatch # type: ignore[import-untyped] import jsonpatch # type: ignore[import-untyped]
from typing_extensions import NotRequired, TypedDict, override from typing_extensions import NotRequired, TypedDict, override
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.load import dumps from langchain_core.load import dumps
from langchain_core.load.load import load from langchain_core.load.load import load
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from langchain_core.runnables import RunnableConfig, ensure_config
from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.memory_stream import _MemoryStream from langchain_core.tracers.memory_stream import _MemoryStream
@@ -32,6 +34,7 @@ if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence from collections.abc import AsyncIterator, Iterator, Sequence
from uuid import UUID from uuid import UUID
from langchain_core.runnables import Runnable
from langchain_core.runnables.utils import Input, Output from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers.schemas import Run from langchain_core.tracers.schemas import Run
@@ -131,8 +134,6 @@ class RunLogPatch:
@override @override
def __repr__(self) -> str: def __repr__(self) -> str:
from pprint import pformat
# 1:-1 to get rid of the [] around the list # 1:-1 to get rid of the [] around the list
return f"RunLogPatch({pformat(self.ops)[1:-1]})" return f"RunLogPatch({pformat(self.ops)[1:-1]})"
@@ -181,8 +182,6 @@ class RunLog(RunLogPatch):
@override @override
def __repr__(self) -> str: def __repr__(self) -> str:
from pprint import pformat
return f"RunLog({pformat(self.state)})" return f"RunLog({pformat(self.state)})"
@override @override
@@ -672,14 +671,6 @@ async def _astream_log_implementation(
Yields: Yields:
The run log patches or states, depending on the value of ``diff``. The run log patches or states, depending on the value of ``diff``.
""" """
import jsonpatch
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.tracers.log_stream import (
RunLog,
RunLogPatch,
)
# Assign the stream handler to the config # Assign the stream handler to the config
config = ensure_config(config) config = ensure_config(config)
callbacks = config.get("callbacks") callbacks = config.get("callbacks")

View File

@@ -21,8 +21,10 @@ from typing import (
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field, create_model
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
import langchain_core
from langchain_core._api import beta, deprecated from langchain_core._api import beta, deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.utils.json_schema import dereference_refs from langchain_core.utils.json_schema import dereference_refs
@@ -220,10 +222,8 @@ def _convert_python_function_to_openai_function(
Returns: Returns:
The OpenAI function description. The OpenAI function description.
""" """
from langchain_core.tools.base import create_schema_from_function
func_name = _get_python_function_name(function) func_name = _get_python_function_name(function)
model = create_schema_from_function( model = langchain_core.tools.base.create_schema_from_function(
func_name, func_name,
function, function,
filter_args=(), filter_args=(),
@@ -264,9 +264,6 @@ def _convert_any_typed_dicts_to_pydantic(
visited: dict, visited: dict,
depth: int = 0, depth: int = 0,
) -> type: ) -> type:
from pydantic.v1 import Field as Field_v1
from pydantic.v1 import create_model as create_model_v1
if type_ in visited: if type_ in visited:
return visited[type_] return visited[type_]
if depth >= _MAX_TYPED_DICT_RECURSION: if depth >= _MAX_TYPED_DICT_RECURSION:
@@ -297,7 +294,7 @@ def _convert_any_typed_dicts_to_pydantic(
raise ValueError(msg) raise ValueError(msg)
if arg_desc := arg_descriptions.get(arg): if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc field_kwargs["description"] = arg_desc
fields[arg] = (new_arg_type, Field_v1(**field_kwargs)) fields[arg] = (new_arg_type, Field(**field_kwargs))
else: else:
new_arg_type = _convert_any_typed_dicts_to_pydantic( new_arg_type = _convert_any_typed_dicts_to_pydantic(
arg_type, depth=depth + 1, visited=visited arg_type, depth=depth + 1, visited=visited
@@ -305,8 +302,8 @@ def _convert_any_typed_dicts_to_pydantic(
field_kwargs = {"default": ...} field_kwargs = {"default": ...}
if arg_desc := arg_descriptions.get(arg): if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc field_kwargs["description"] = arg_desc
fields[arg] = (new_arg_type, Field_v1(**field_kwargs)) fields[arg] = (new_arg_type, Field(**field_kwargs))
model = create_model_v1(typed_dict.__name__, **fields) model = create_model(typed_dict.__name__, **fields)
model.__doc__ = description model.__doc__ = description
visited[typed_dict] = model visited[typed_dict] = model
return model return model
@@ -332,9 +329,9 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
Returns: Returns:
The function description. The function description.
""" """
from langchain_core.tools import simple is_simple_oai_tool = (
isinstance(tool, langchain_core.tools.simple.Tool) and not tool.args_schema
is_simple_oai_tool = isinstance(tool, simple.Tool) and not tool.args_schema )
if tool.tool_call_schema and not is_simple_oai_tool: if tool.tool_call_schema and not is_simple_oai_tool:
if isinstance(tool.tool_call_schema, dict): if isinstance(tool.tool_call_schema, dict):
return _convert_json_schema_to_openai_function( return _convert_json_schema_to_openai_function(
@@ -435,8 +432,6 @@ def convert_to_openai_function(
'description' and 'parameters' keys are now optional. Only 'name' is 'description' and 'parameters' keys are now optional. Only 'name' is
required and guaranteed to be part of the output. required and guaranteed to be part of the output.
""" """
from langchain_core.tools import BaseTool
# an Anthropic format tool # an Anthropic format tool
if isinstance(function, dict) and all( if isinstance(function, dict) and all(
k in function for k in ("name", "input_schema") k in function for k in ("name", "input_schema")
@@ -476,7 +471,7 @@ def convert_to_openai_function(
oai_function = cast( oai_function = cast(
"dict", _convert_typed_dict_to_openai_function(cast("type", function)) "dict", _convert_typed_dict_to_openai_function(cast("type", function))
) )
elif isinstance(function, BaseTool): elif isinstance(function, langchain_core.tools.base.BaseTool):
oai_function = cast("dict", _format_tool_to_openai_function(function)) oai_function = cast("dict", _format_tool_to_openai_function(function))
elif callable(function): elif callable(function):
oai_function = cast( oai_function = cast(
@@ -582,7 +577,8 @@ def convert_to_openai_tool(
Added support for OpenAI's image generation built-in tool. Added support for OpenAI's image generation built-in tool.
""" """
from langchain_core.tools import Tool # Import locally to prevent circular import
from langchain_core.tools import Tool # noqa: PLC0415
if isinstance(tool, dict): if isinstance(tool, dict):
if tool.get("type") in _WellKnownOpenAITools: if tool.get("type") in _WellKnownOpenAITools:

View File

@@ -1,5 +1,7 @@
"""Utilities for working with interactive environments.""" """Utilities for working with interactive environments."""
import sys
def is_interactive_env() -> bool: def is_interactive_env() -> bool:
"""Determine if running within IPython or Jupyter. """Determine if running within IPython or Jupyter.
@@ -7,6 +9,4 @@ def is_interactive_env() -> bool:
Returns: Returns:
True if running in an interactive environment, False otherwise. True if running in an interactive environment, False otherwise.
""" """
import sys
return hasattr(sys, "ps2") return hasattr(sys, "ps2")

View File

@@ -38,6 +38,7 @@ from typing import (
from pydantic import ConfigDict, Field, model_validator from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self, override from typing_extensions import Self, override
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor
@@ -49,7 +50,6 @@ if TYPE_CHECKING:
AsyncCallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun,
) )
from langchain_core.documents import Document
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -85,9 +85,6 @@ class VectorStore(ABC):
ValueError: If the number of ids does not match the number of texts. ValueError: If the number of ids does not match the number of texts.
""" """
if type(self).add_documents != VectorStore.add_documents: if type(self).add_documents != VectorStore.add_documents:
# Import document in local scope to avoid circular imports
from langchain_core.documents import Document
# This condition is triggered if the subclass has provided # This condition is triggered if the subclass has provided
# an implementation of the upsert method. # an implementation of the upsert method.
# The existing add_texts # The existing add_texts
@@ -234,9 +231,6 @@ class VectorStore(ABC):
# For backward compatibility # For backward compatibility
kwargs["ids"] = ids kwargs["ids"] = ids
if type(self).aadd_documents != VectorStore.aadd_documents: if type(self).aadd_documents != VectorStore.aadd_documents:
# Import document in local scope to avoid circular imports
from langchain_core.documents import Document
# This condition is triggered if the subclass has provided # This condition is triggered if the subclass has provided
# an implementation of the upsert method. # an implementation of the upsert method.
# The existing add_texts # The existing add_texts

View File

@@ -27,6 +27,13 @@ if TYPE_CHECKING:
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.indexing import UpsertResponse from langchain_core.indexing import UpsertResponse
try:
import numpy as np
_HAS_NUMPY = True
except ImportError:
_HAS_NUMPY = False
class InMemoryVectorStore(VectorStore): class InMemoryVectorStore(VectorStore):
"""In-memory vector store implementation. """In-memory vector store implementation.
@@ -496,14 +503,12 @@ class InMemoryVectorStore(VectorStore):
filter=filter, filter=filter,
) )
try: if not _HAS_NUMPY:
import numpy as np
except ImportError as e:
msg = ( msg = (
"numpy must be installed to use max_marginal_relevance_search " "numpy must be installed to use max_marginal_relevance_search "
"pip install numpy" "pip install numpy"
) )
raise ImportError(msg) from e raise ImportError(msg)
mmr_chosen_indices = maximal_marginal_relevance( mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32), np.array(embedding, dtype=np.float32),

View File

@@ -10,9 +10,21 @@ import logging
import warnings import warnings
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
if TYPE_CHECKING: try:
import numpy as np import numpy as np
_HAS_NUMPY = True
except ImportError:
_HAS_NUMPY = False
try:
import simsimd as simd # type: ignore[import-not-found]
_HAS_SIMSIMD = True
except ImportError:
_HAS_SIMSIMD = False
if TYPE_CHECKING:
Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray] Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -33,14 +45,12 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
ValueError: If the number of columns in X and Y are not the same. ValueError: If the number of columns in X and Y are not the same.
ImportError: If numpy is not installed. ImportError: If numpy is not installed.
""" """
try: if not _HAS_NUMPY:
import numpy as np
except ImportError as e:
msg = ( msg = (
"cosine_similarity requires numpy to be installed. " "cosine_similarity requires numpy to be installed. "
"Please install numpy with `pip install numpy`." "Please install numpy with `pip install numpy`."
) )
raise ImportError(msg) from e raise ImportError(msg)
if len(x) == 0 or len(y) == 0: if len(x) == 0 or len(y) == 0:
return np.array([[]]) return np.array([[]])
@@ -70,9 +80,7 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
f"and Y has shape {y.shape}." f"and Y has shape {y.shape}."
) )
raise ValueError(msg) raise ValueError(msg)
try: if not _HAS_SIMSIMD:
import simsimd as simd # type: ignore[import-not-found]
except ImportError:
logger.debug( logger.debug(
"Unable to import simsimd, defaulting to NumPy implementation. If you want " "Unable to import simsimd, defaulting to NumPy implementation. If you want "
"to use simsimd please install with `pip install simsimd`." "to use simsimd please install with `pip install simsimd`."
@@ -113,14 +121,12 @@ def maximal_marginal_relevance(
Raises: Raises:
ImportError: If numpy is not installed. ImportError: If numpy is not installed.
""" """
try: if not _HAS_NUMPY:
import numpy as np
except ImportError as e:
msg = ( msg = (
"maximal_marginal_relevance requires numpy to be installed. " "maximal_marginal_relevance requires numpy to be installed. "
"Please install numpy with `pip install numpy`." "Please install numpy with `pip install numpy`."
) )
raise ImportError(msg) from e raise ImportError(msg)
if min(k, len(embedding_list)) <= 0: if min(k, len(embedding_list)) <= 0:
return [] return []

View File

@@ -114,7 +114,6 @@ ignore = [
"BLE", # Blind exceptions "BLE", # Blind exceptions
"DOC", # Docstrings (preview) "DOC", # Docstrings (preview)
"ERA", # No commented-out code "ERA", # No commented-out code
"PLC0415", # Imports outside top level
"PLR2004", # Comparison to magic number "PLR2004", # Comparison to magic number
] ]
unfixable = ["PLW1510",] unfixable = ["PLW1510",]

View File

@@ -12,6 +12,7 @@ from langchain_core.language_models.fake_chat_models import (
FakeListChatModel, FakeListChatModel,
GenericFakeChatModel, GenericFakeChatModel,
) )
from langchain_core.load import dumps
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.outputs.chat_result import ChatResult from langchain_core.outputs.chat_result import ChatResult
@@ -318,8 +319,6 @@ def test_cache_with_generation_objects() -> None:
cache = InMemoryCache() cache = InMemoryCache()
# Create a simple fake chat model that we can control # Create a simple fake chat model that we can control
from langchain_core.messages import AIMessage
class SimpleFakeChat: class SimpleFakeChat:
"""Simple fake chat model for testing.""" """Simple fake chat model for testing."""
@@ -332,8 +331,6 @@ def test_cache_with_generation_objects() -> None:
def generate_response(self, prompt: str) -> ChatResult: def generate_response(self, prompt: str) -> ChatResult:
"""Simulate the cache lookup and generation logic.""" """Simulate the cache lookup and generation logic."""
from langchain_core.load import dumps
llm_string = self._get_llm_string() llm_string = self._get_llm_string()
prompt_str = dumps([prompt]) prompt_str = dumps([prompt])

View File

@@ -5,6 +5,7 @@ from blockbuster import BlockBuster
from langchain_core.caches import InMemoryCache from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel from langchain_core.language_models import GenericFakeChatModel
from langchain_core.load import dumps
from langchain_core.rate_limiters import InMemoryRateLimiter from langchain_core.rate_limiters import InMemoryRateLimiter
@@ -229,8 +230,6 @@ class SerializableModel(GenericFakeChatModel):
def test_serialization_with_rate_limiter() -> None: def test_serialization_with_rate_limiter() -> None:
"""Test model serialization with rate limiter.""" """Test model serialization with rate limiter."""
from langchain_core.load import dumps
model = SerializableModel( model = SerializableModel(
messages=iter(["hello", "world", "!"]), messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter( rate_limiter=InMemoryRateLimiter(

View File

@@ -1,7 +1,7 @@
import json import json
import pytest import pytest
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, SecretStr
from langchain_core.load import Serializable, dumpd, dumps, load from langchain_core.load import Serializable, dumpd, dumps, load
from langchain_core.load.serializable import _is_field_useful from langchain_core.load.serializable import _is_field_useful
@@ -62,9 +62,6 @@ def test_simple_serialization_is_serializable() -> None:
def test_simple_serialization_secret() -> None: def test_simple_serialization_secret() -> None:
"""Test handling of secrets.""" """Test handling of secrets."""
from pydantic import SecretStr
from langchain_core.load import Serializable
class Foo(Serializable): class Foo(Serializable):
bar: int bar: int

View File

@@ -1,6 +1,7 @@
from collections.abc import AsyncIterator, Iterator from collections.abc import AsyncIterator, Iterator
from typing import Any from typing import Any
import pydantic
import pytest import pytest
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
@@ -802,7 +803,6 @@ async def test_partial_pydantic_output_parser_async() -> None:
def test_parse_with_different_pydantic_2_v1() -> None: def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2.""" """Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic
class Forecast(pydantic.v1.BaseModel): class Forecast(pydantic.v1.BaseModel):
temperature: int temperature: int
@@ -836,9 +836,8 @@ def test_parse_with_different_pydantic_2_v1() -> None:
def test_parse_with_different_pydantic_2_proper() -> None: def test_parse_with_different_pydantic_2_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 2.""" """Test with pydantic.BaseModel from pydantic 2."""
import pydantic
class Forecast(pydantic.BaseModel): class Forecast(BaseModel):
temperature: int temperature: int
forecast: str forecast: str

View File

@@ -189,8 +189,6 @@ def test_pydantic_output_parser_type_inference() -> None:
def test_format_instructions_preserves_language() -> None: def test_format_instructions_preserves_language() -> None:
"""Test format instructions does not attempt to encode into ascii.""" """Test format instructions does not attempt to encode into ascii."""
from pydantic import BaseModel, Field
description = ( description = (
"你好, こんにちは, नमस्ते, Bonjour, Hola, " "你好, こんにちは, नमस्ते, Bonjour, Hola, "
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001 "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001

View File

@@ -1,6 +1,7 @@
"""Test functionality related to prompts.""" """Test functionality related to prompts."""
import re import re
from tempfile import NamedTemporaryFile
from typing import Any, Union from typing import Any, Union
from unittest import mock from unittest import mock
@@ -32,8 +33,6 @@ def test_from_file_encoding() -> None:
input_variables = ["foo"] input_variables = ["foo"]
# First write to a file using CP-1252 encoding. # First write to a file using CP-1252 encoding.
from tempfile import NamedTemporaryFile
with NamedTemporaryFile(delete=True, mode="w", encoding="cp1252") as f: with NamedTemporaryFile(delete=True, mode="w", encoding="cp1252") as f:
f.write(template) f.write(template)
f.flush() f.flush()
@@ -434,11 +433,9 @@ Will it get confused{ }?
assert prompt == expected_prompt assert prompt == expected_prompt
@pytest.mark.requires("jinja2")
def test_basic_sandboxing_with_jinja2() -> None: def test_basic_sandboxing_with_jinja2() -> None:
"""Test basic sandboxing with jinja2.""" """Test basic sandboxing with jinja2."""
import jinja2 jinja2 = pytest.importorskip("jinja2")
template = " {{''.__class__.__bases__[0] }} " # malicious code template = " {{''.__class__.__bases__[0] }} " # malicious code
prompt = PromptTemplate.from_template(template, template_format="jinja2") prompt = PromptTemplate.from_template(template, template_format="jinja2")
with pytest.raises(jinja2.exceptions.SecurityError): with pytest.raises(jinja2.exceptions.SecurityError):

View File

@@ -2,6 +2,7 @@
import asyncio import asyncio
import time import time
from threading import Lock
from typing import Any from typing import Any
import pytest import pytest
@@ -80,7 +81,6 @@ def test_batch_concurrency() -> None:
"""Test that batch respects max_concurrency.""" """Test that batch respects max_concurrency."""
running_tasks = 0 running_tasks = 0
max_running_tasks = 0 max_running_tasks = 0
from threading import Lock
lock = Lock() lock = Lock()
@@ -112,7 +112,6 @@ def test_batch_as_completed_concurrency() -> None:
"""Test that batch_as_completed respects max_concurrency.""" """Test that batch_as_completed respects max_concurrency."""
running_tasks = 0 running_tasks = 0
max_running_tasks = 0 max_running_tasks = 0
from threading import Lock
lock = Lock() lock = Lock()

View File

@@ -4,7 +4,7 @@ from typing import Any, Callable, Optional, Union
import pytest import pytest
from packaging import version from packaging import version
from pydantic import BaseModel from pydantic import BaseModel, RootModel
from typing_extensions import override from typing_extensions import override
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@@ -20,6 +20,11 @@ from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
from langchain_core.tracers import Run from langchain_core.tracers import Run
from langchain_core.tracers.root_listeners import (
AsyncListener,
AsyncRootListenersTracer,
RootListenersTracer,
)
from langchain_core.utils.pydantic import PYDANTIC_VERSION from langchain_core.utils.pydantic import PYDANTIC_VERSION
from tests.unit_tests.pydantic_utils import _schema from tests.unit_tests.pydantic_utils import _schema
@@ -499,8 +504,6 @@ def test_get_output_schema() -> None:
def test_get_input_schema_input_messages() -> None: def test_get_input_schema_input_messages() -> None:
from pydantic import RootModel
runnable_with_message_history_input = RootModel[Sequence[BaseMessage]] runnable_with_message_history_input = RootModel[Sequence[BaseMessage]]
runnable = RunnableLambda( runnable = RunnableLambda(
@@ -776,8 +779,6 @@ def test_ignore_session_id() -> None:
class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]): class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
from langchain_core.tracers.root_listeners import AsyncListener
def with_listeners( def with_listeners(
self, self,
*, *,
@@ -791,8 +792,6 @@ class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
] = None, ] = None,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
from langchain_core.tracers.root_listeners import RootListenersTracer
def create_tracer(config: RunnableConfig) -> RunnableConfig: def create_tracer(config: RunnableConfig) -> RunnableConfig:
tracer = RootListenersTracer( tracer = RootListenersTracer(
config=config, config=config,
@@ -817,8 +816,6 @@ class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
on_end: Optional[AsyncListener] = None, on_end: Optional[AsyncListener] = None,
on_error: Optional[AsyncListener] = None, on_error: Optional[AsyncListener] = None,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
def create_tracer(config: RunnableConfig) -> RunnableConfig: def create_tracer(config: RunnableConfig) -> RunnableConfig:
tracer = AsyncRootListenersTracer( tracer = AsyncRootListenersTracer(
config=config, config=config,

View File

@@ -40,6 +40,6 @@ def test_all_imports() -> None:
def test_imports_for_specific_funcs() -> None: def test_imports_for_specific_funcs() -> None:
"""Test that a few specific imports in more internal namespaces.""" """Test that a few specific imports in more internal namespaces."""
# create_model implementation has been moved to langchain_core.utils.pydantic # create_model implementation has been moved to langchain_core.utils.pydantic
from langchain_core.runnables.utils import ( # type: ignore[attr-defined] # noqa: F401 from langchain_core.runnables.utils import ( # type: ignore[attr-defined] # noqa: F401,PLC0415
create_model, create_model,
) )

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import re import re
import sys import sys
import time
import uuid import uuid
import warnings import warnings
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
@@ -17,6 +18,7 @@ from pytest_mock import MockerFixture
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from typing_extensions import TypedDict, override from typing_extensions import TypedDict, override
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.callbacks.manager import ( from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun,
@@ -29,6 +31,7 @@ from langchain_core.language_models import (
FakeListLLM, FakeListLLM,
FakeStreamingListLLM, FakeStreamingListLLM,
) )
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.load import dumpd, dumps from langchain_core.load import dumpd, dumps
from langchain_core.load.load import loads from langchain_core.load.load import loads
from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage
@@ -5516,9 +5519,6 @@ async def test_passthrough_atransform_with_dicts() -> None:
def test_listeners() -> None: def test_listeners() -> None:
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run
def fake_chain(inputs: dict) -> dict: def fake_chain(inputs: dict) -> dict:
return {**inputs, "key": "extra"} return {**inputs, "key": "extra"}
@@ -5546,9 +5546,6 @@ def test_listeners() -> None:
async def test_listeners_async() -> None: async def test_listeners_async() -> None:
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run
def fake_chain(inputs: dict) -> dict: def fake_chain(inputs: dict) -> dict:
return {**inputs, "key": "extra"} return {**inputs, "key": "extra"}
@@ -5578,12 +5575,6 @@ async def test_listeners_async() -> None:
def test_closing_iterator_doesnt_raise_error() -> None: def test_closing_iterator_doesnt_raise_error() -> None:
"""Test that closing an iterator calls on_chain_end rather than on_chain_error.""" """Test that closing an iterator calls on_chain_end rather than on_chain_error."""
import time
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.output_parsers import StrOutputParser
on_chain_error_triggered = False on_chain_error_triggered = False
on_chain_end_triggered = False on_chain_end_triggered = False

View File

@@ -20,6 +20,7 @@ from pydantic import BaseModel
from typing_extensions import override from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.callbacks.manager import adispatch_custom_event
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel
@@ -2548,7 +2549,6 @@ async def test_cancel_astream_events() -> None:
async def test_custom_event() -> None: async def test_custom_event() -> None:
"""Test adhoc event.""" """Test adhoc event."""
from langchain_core.callbacks.manager import adispatch_custom_event
# Ignoring type due to RunnableLamdba being dynamic when it comes to being # Ignoring type due to RunnableLamdba being dynamic when it comes to being
# applied as a decorator to async functions. # applied as a decorator to async functions.
@@ -2625,7 +2625,6 @@ async def test_custom_event() -> None:
async def test_custom_event_nested() -> None: async def test_custom_event_nested() -> None:
"""Test adhoc event in a nested chain.""" """Test adhoc event in a nested chain."""
from langchain_core.callbacks.manager import adispatch_custom_event
# Ignoring type due to RunnableLamdba being dynamic when it comes to being # Ignoring type due to RunnableLamdba being dynamic when it comes to being
# applied as a decorator to async functions. # applied as a decorator to async functions.
@@ -2736,7 +2735,6 @@ async def test_custom_event_root_dispatch() -> None:
# This just tests that nothing breaks on the path. # This just tests that nothing breaks on the path.
# It shouldn't do anything at the moment, since the tracer isn't configured # It shouldn't do anything at the moment, since the tracer isn't configured
# to handle adhoc events. # to handle adhoc events.
from langchain_core.callbacks.manager import adispatch_custom_event
# Expected behavior is that the event cannot be dispatched # Expected behavior is that the event cannot be dispatched
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
@@ -2750,8 +2748,6 @@ IS_GTE_3_11 = sys.version_info >= (3, 11)
@pytest.mark.skipif(not IS_GTE_3_11, reason="Requires Python >=3.11") @pytest.mark.skipif(not IS_GTE_3_11, reason="Requires Python >=3.11")
async def test_custom_event_root_dispatch_with_in_tool() -> None: async def test_custom_event_root_dispatch_with_in_tool() -> None:
"""Test adhoc event in a nested chain.""" """Test adhoc event in a nested chain."""
from langchain_core.callbacks.manager import adispatch_custom_event
from langchain_core.tools import tool
@tool @tool
async def foo(x: int) -> int: async def foo(x: int) -> int:

View File

@@ -1,18 +1,17 @@
import langchain_core
from langchain_core.callbacks.manager import _get_debug
from langchain_core.globals import get_debug, set_debug from langchain_core.globals import get_debug, set_debug
def test_debug_is_settable_via_setter() -> None: def test_debug_is_settable_via_setter() -> None:
from langchain_core import globals as globals_ previous_value = langchain_core.globals._debug
from langchain_core.callbacks.manager import _get_debug
previous_value = globals_._debug
previous_fn_reading = _get_debug() previous_fn_reading = _get_debug()
assert previous_value == previous_fn_reading assert previous_value == previous_fn_reading
# Flip the value of the flag. # Flip the value of the flag.
set_debug(not previous_value) set_debug(not previous_value)
new_value = globals_._debug new_value = langchain_core.globals._debug
new_fn_reading = _get_debug() new_fn_reading = _get_debug()
try: try:

View File

@@ -21,7 +21,7 @@ from typing import (
) )
import pytest import pytest
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1 from pydantic.v1 import ValidationError as ValidationErrorV1
from typing_extensions import TypedDict, override from typing_extensions import TypedDict, override
@@ -1852,7 +1852,6 @@ def generate_models() -> list[Any]:
def generate_backwards_compatible_v1() -> list[Any]: def generate_backwards_compatible_v1() -> list[Any]:
"""Generate a model with pydantic 2 from the v1 namespace.""" """Generate a model with pydantic 2 from the v1 namespace."""
from pydantic.v1 import BaseModel as BaseModelV1
class FooV1Namespace(BaseModelV1): class FooV1Namespace(BaseModelV1):
a: int a: int
@@ -1920,8 +1919,6 @@ def test_args_schema_explicitly_typed() -> None:
is a pydantic 1 model! is a pydantic 1 model!
""" """
# Check with whatever pydantic model is passed in and not via v1 namespace
from pydantic import BaseModel
class Foo(BaseModel): class Foo(BaseModel):
a: int a: int
@@ -1964,7 +1961,6 @@ def test_args_schema_explicitly_typed() -> None:
@pytest.mark.parametrize("pydantic_model", TEST_MODELS) @pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None: def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
"""This should test that one can type the args schema as a pydantic model.""" """This should test that one can type the args schema as a pydantic model."""
from langchain_core.tools import StructuredTool
def foo(a: int, b: str) -> str: def foo(a: int, b: str) -> str:
"""Hahaha.""" """Hahaha."""
@@ -2063,16 +2059,13 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
A = TypeVar("A") A = TypeVar("A")
if use_v1_namespace: if use_v1_namespace:
from pydantic.v1 import BaseModel as BaseModel1
class ModelA(BaseModel1, Generic[A], extra="allow"): class ModelA(BaseModelV1, Generic[A], extra="allow"):
a: A a: A
else: else:
from pydantic import BaseModel as BaseModel2
from pydantic import ConfigDict
class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef] class ModelA(BaseModel, Generic[A]): # type: ignore[no-redef]
a: A a: A
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
@@ -2208,12 +2201,8 @@ def test_create_retriever_tool() -> None:
def test_tool_args_schema_pydantic_v2_with_metadata() -> None: def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
from pydantic import BaseModel as BaseModelV2 class Foo(BaseModel):
from pydantic import Field as FieldV2 x: list[int] = Field(
from pydantic import ValidationError as ValidationErrorV2
class Foo(BaseModelV2):
x: list[int] = FieldV2(
description="List of integers", min_length=10, max_length=15 description="List of integers", min_length=10, max_length=15
) )
@@ -2240,7 +2229,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
} }
assert foo.invoke({"x": [0] * 10}) assert foo.invoke({"x": [0] * 10})
with pytest.raises(ValidationErrorV2): with pytest.raises(ValidationError):
foo.invoke({"x": [0] * 9}) foo.invoke({"x": [0] * 9})
@@ -2576,8 +2565,6 @@ def test_title_property_preserved() -> None:
https://github.com/langchain-ai/langchain/issues/30456 https://github.com/langchain-ai/langchain/issues/30456
""" """
from langchain_core.tools import tool
schema_to_be_extracted = { schema_to_be_extracted = {
"type": "object", "type": "object",
"required": [], "required": [],

View File

@@ -3,7 +3,8 @@
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
from pydantic import ConfigDict from pydantic import BaseModel, ConfigDict, Field
from pydantic.v1 import BaseModel as BaseModelV1
from langchain_core.utils.pydantic import ( from langchain_core.utils.pydantic import (
_create_subset_model_v2, _create_subset_model_v2,
@@ -16,8 +17,6 @@ from langchain_core.utils.pydantic import (
def test_pre_init_decorator() -> None: def test_pre_init_decorator() -> None:
from pydantic import BaseModel
class Foo(BaseModel): class Foo(BaseModel):
x: int = 5 x: int = 5
y: int y: int
@@ -35,8 +34,6 @@ def test_pre_init_decorator() -> None:
def test_pre_init_decorator_with_more_defaults() -> None: def test_pre_init_decorator_with_more_defaults() -> None:
from pydantic import BaseModel, Field
class Foo(BaseModel): class Foo(BaseModel):
a: int = 1 a: int = 1
b: Optional[int] = None b: Optional[int] = None
@@ -56,8 +53,6 @@ def test_pre_init_decorator_with_more_defaults() -> None:
def test_with_aliases() -> None: def test_with_aliases() -> None:
from pydantic import BaseModel, Field
class Foo(BaseModel): class Foo(BaseModel):
x: int = Field(default=1, alias="y") x: int = Field(default=1, alias="y")
z: int z: int
@@ -92,19 +87,14 @@ def test_with_aliases() -> None:
def test_is_basemodel_subclass() -> None: def test_is_basemodel_subclass() -> None:
"""Test pydantic.""" """Test pydantic."""
from pydantic import BaseModel as BaseModelV2 assert is_basemodel_subclass(BaseModel)
from pydantic.v1 import BaseModel as BaseModelV1
assert is_basemodel_subclass(BaseModelV2)
assert is_basemodel_subclass(BaseModelV1) assert is_basemodel_subclass(BaseModelV1)
def test_is_basemodel_instance() -> None: def test_is_basemodel_instance() -> None:
"""Test pydantic.""" """Test pydantic."""
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
class Foo(BaseModelV2): class Foo(BaseModel):
x: int x: int
assert is_basemodel_instance(Foo(x=5)) assert is_basemodel_instance(Foo(x=5))
@@ -117,11 +107,9 @@ def test_is_basemodel_instance() -> None:
def test_with_field_metadata() -> None: def test_with_field_metadata() -> None:
"""Test pydantic with field metadata.""" """Test pydantic with field metadata."""
from pydantic import BaseModel as BaseModelV2
from pydantic import Field as FieldV2
class Foo(BaseModelV2): class Foo(BaseModel):
x: list[int] = FieldV2( x: list[int] = Field(
description="List of integers", min_length=10, max_length=15 description="List of integers", min_length=10, max_length=15
) )
@@ -144,8 +132,6 @@ def test_with_field_metadata() -> None:
def test_fields_pydantic_v2_proper() -> None: def test_fields_pydantic_v2_proper() -> None:
from pydantic import BaseModel
class Foo(BaseModel): class Foo(BaseModel):
x: int x: int
@@ -154,9 +140,7 @@ def test_fields_pydantic_v2_proper() -> None:
def test_fields_pydantic_v1_from_2() -> None: def test_fields_pydantic_v1_from_2() -> None:
from pydantic.v1 import BaseModel class Foo(BaseModelV1):
class Foo(BaseModel):
x: int x: int
fields = get_fields(Foo) fields = get_fields(Foo)

View File

@@ -6,7 +6,9 @@ from typing import Any, Callable, Optional, Union
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from pydantic import SecretStr from pydantic import BaseModel, Field, SecretStr
from pydantic.v1 import BaseModel as PydanticV1BaseModel
from pydantic.v1 import Field as PydanticV1Field
from langchain_core import utils from langchain_core import utils
from langchain_core.outputs import GenerationChunk from langchain_core.outputs import GenerationChunk
@@ -212,13 +214,10 @@ def test_guard_import_failure(
def test_get_pydantic_field_names_v1_in_2() -> None: def test_get_pydantic_field_names_v1_in_2() -> None:
from pydantic.v1 import BaseModel as PydanticV1BaseModel
from pydantic.v1 import Field
class PydanticV1Model(PydanticV1BaseModel): class PydanticV1Model(PydanticV1BaseModel):
field1: str field1: str
field2: int field2: int
alias_field: int = Field(alias="aliased_field") alias_field: int = PydanticV1Field(alias="aliased_field")
result = get_pydantic_field_names(PydanticV1Model) result = get_pydantic_field_names(PydanticV1Model)
expected = {"field1", "field2", "aliased_field", "alias_field"} expected = {"field1", "field2", "aliased_field", "alias_field"}
@@ -226,8 +225,6 @@ def test_get_pydantic_field_names_v1_in_2() -> None:
def test_get_pydantic_field_names_v2_in_2() -> None: def test_get_pydantic_field_names_v2_in_2() -> None:
from pydantic import BaseModel, Field
class PydanticModel(BaseModel): class PydanticModel(BaseModel):
field1: str field1: str
field2: int field2: int
@@ -341,8 +338,6 @@ def test_secret_from_env_with_custom_error_message(
def test_using_secret_from_env_as_default_factory( def test_using_secret_from_env_as_default_factory(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
from pydantic import BaseModel, Field
class Foo(BaseModel): class Foo(BaseModel):
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY")) secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))