mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
chore(core): add ruff rule PLC0415 (#32351)
See https://docs.astral.sh/ruff/rules/import-outside-top-level/ Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
16420cad71
commit
cc98fb9bee
@@ -28,8 +28,18 @@ from langchain_core.callbacks.base import (
|
|||||||
ToolManagerMixin,
|
ToolManagerMixin,
|
||||||
)
|
)
|
||||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||||
|
from langchain_core.globals import get_debug
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from langchain_core.tracers.context import (
|
||||||
|
_configure_hooks,
|
||||||
|
_get_trace_callbacks,
|
||||||
|
_get_tracer_project,
|
||||||
|
_tracing_v2_is_enabled,
|
||||||
|
tracing_v2_callback_var,
|
||||||
|
)
|
||||||
|
from langchain_core.tracers.langchain import LangChainTracer
|
||||||
from langchain_core.tracers.schemas import Run
|
from langchain_core.tracers.schemas import Run
|
||||||
|
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||||
from langchain_core.utils.env import env_var_is_set
|
from langchain_core.utils.env import env_var_is_set
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -46,8 +56,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def _get_debug() -> bool:
|
def _get_debug() -> bool:
|
||||||
from langchain_core.globals import get_debug
|
|
||||||
|
|
||||||
return get_debug()
|
return get_debug()
|
||||||
|
|
||||||
|
|
||||||
@@ -103,8 +111,6 @@ def trace_as_chain_group(
|
|||||||
manager.on_chain_end({"output": res})
|
manager.on_chain_end({"output": res})
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.tracers.context import _get_trace_callbacks
|
|
||||||
|
|
||||||
cb = _get_trace_callbacks(
|
cb = _get_trace_callbacks(
|
||||||
project_name, example_id, callback_manager=callback_manager
|
project_name, example_id, callback_manager=callback_manager
|
||||||
)
|
)
|
||||||
@@ -189,8 +195,6 @@ async def atrace_as_chain_group(
|
|||||||
await manager.on_chain_end({"output": res})
|
await manager.on_chain_end({"output": res})
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.tracers.context import _get_trace_callbacks
|
|
||||||
|
|
||||||
cb = _get_trace_callbacks(
|
cb = _get_trace_callbacks(
|
||||||
project_name, example_id, callback_manager=callback_manager
|
project_name, example_id, callback_manager=callback_manager
|
||||||
)
|
)
|
||||||
@@ -2376,13 +2380,6 @@ def _configure(
|
|||||||
Returns:
|
Returns:
|
||||||
T: The configured callback manager.
|
T: The configured callback manager.
|
||||||
"""
|
"""
|
||||||
from langchain_core.tracers.context import (
|
|
||||||
_configure_hooks,
|
|
||||||
_get_tracer_project,
|
|
||||||
_tracing_v2_is_enabled,
|
|
||||||
tracing_v2_callback_var,
|
|
||||||
)
|
|
||||||
|
|
||||||
tracing_context = get_tracing_context()
|
tracing_context = get_tracing_context()
|
||||||
tracing_metadata = tracing_context["metadata"]
|
tracing_metadata = tracing_context["metadata"]
|
||||||
tracing_tags = tracing_context["tags"]
|
tracing_tags = tracing_context["tags"]
|
||||||
@@ -2459,9 +2456,6 @@ def _configure(
|
|||||||
tracer_project = _get_tracer_project()
|
tracer_project = _get_tracer_project()
|
||||||
debug = _get_debug()
|
debug = _get_debug()
|
||||||
if verbose or debug or tracing_v2_enabled_:
|
if verbose or debug or tracing_v2_enabled_:
|
||||||
from langchain_core.tracers.langchain import LangChainTracer
|
|
||||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
|
||||||
|
|
||||||
if verbose and not any(
|
if verbose and not any(
|
||||||
isinstance(handler, StdOutCallbackHandler)
|
isinstance(handler, StdOutCallbackHandler)
|
||||||
for handler in callback_manager.handlers
|
for handler in callback_manager.handlers
|
||||||
@@ -2630,7 +2624,8 @@ async def adispatch_custom_event(
|
|||||||
.. versionadded:: 0.2.15
|
.. versionadded:: 0.2.15
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.config import (
|
# Import locally to prevent circular imports.
|
||||||
|
from langchain_core.runnables.config import ( # noqa: PLC0415
|
||||||
ensure_config,
|
ensure_config,
|
||||||
get_async_callback_manager_for_config,
|
get_async_callback_manager_for_config,
|
||||||
)
|
)
|
||||||
@@ -2705,7 +2700,8 @@ def dispatch_custom_event(
|
|||||||
.. versionadded:: 0.2.15
|
.. versionadded:: 0.2.15
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.config import (
|
# Import locally to prevent circular imports.
|
||||||
|
from langchain_core.runnables.config import ( # noqa: PLC0415
|
||||||
ensure_config,
|
ensure_config,
|
||||||
get_callback_manager_for_config,
|
get_callback_manager_for_config,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from langchain_core.callbacks import BaseCallbackHandler
|
|||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.messages.ai import UsageMetadata, add_usage
|
from langchain_core.messages.ai import UsageMetadata, add_usage
|
||||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||||
|
from langchain_core.tracers.context import register_configure_hook
|
||||||
|
|
||||||
|
|
||||||
class UsageMetadataCallbackHandler(BaseCallbackHandler):
|
class UsageMetadataCallbackHandler(BaseCallbackHandler):
|
||||||
@@ -133,8 +134,6 @@ def get_usage_metadata_callback(
|
|||||||
.. versionadded:: 0.3.49
|
.. versionadded:: 0.3.49
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.tracers.context import register_configure_hook
|
|
||||||
|
|
||||||
usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
|
usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
|
||||||
ContextVar(name, default=None)
|
ContextVar(name, default=None)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from langchain_core.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
@@ -113,8 +114,6 @@ class BaseChatMessageHistory(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
The messages.
|
The messages.
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.config import run_in_executor
|
|
||||||
|
|
||||||
return await run_in_executor(None, lambda: self.messages)
|
return await run_in_executor(None, lambda: self.messages)
|
||||||
|
|
||||||
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
|
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
|
||||||
@@ -190,8 +189,6 @@ class BaseChatMessageHistory(ABC):
|
|||||||
Args:
|
Args:
|
||||||
messages: A sequence of BaseMessage objects to store.
|
messages: A sequence of BaseMessage objects to store.
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.config import run_in_executor
|
|
||||||
|
|
||||||
await run_in_executor(None, self.add_messages, messages)
|
await run_in_executor(None, self.add_messages, messages)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -200,8 +197,6 @@ class BaseChatMessageHistory(ABC):
|
|||||||
|
|
||||||
async def aclear(self) -> None:
|
async def aclear(self) -> None:
|
||||||
"""Async remove all messages from the store."""
|
"""Async remove all messages from the store."""
|
||||||
from langchain_core.runnables.config import run_in_executor
|
|
||||||
|
|
||||||
await run_in_executor(None, self.clear)
|
await run_in_executor(None, self.clear)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|||||||
@@ -15,6 +15,13 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.documents.base import Blob
|
from langchain_core.documents.base import Blob
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
_HAS_TEXT_SPLITTERS = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_TEXT_SPLITTERS = False
|
||||||
|
|
||||||
|
|
||||||
class BaseLoader(ABC): # noqa: B024
|
class BaseLoader(ABC): # noqa: B024
|
||||||
"""Interface for Document Loader.
|
"""Interface for Document Loader.
|
||||||
@@ -62,15 +69,13 @@ class BaseLoader(ABC): # noqa: B024
|
|||||||
List of Documents.
|
List of Documents.
|
||||||
"""
|
"""
|
||||||
if text_splitter is None:
|
if text_splitter is None:
|
||||||
try:
|
if not _HAS_TEXT_SPLITTERS:
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"Unable to import from langchain_text_splitters. Please specify "
|
"Unable to import from langchain_text_splitters. Please specify "
|
||||||
"text_splitter or install langchain_text_splitters with "
|
"text_splitter or install langchain_text_splitters with "
|
||||||
"`pip install -U langchain-text-splitters`."
|
"`pip install -U langchain-text-splitters`."
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
|
|
||||||
text_splitter_: TextSplitter = RecursiveCharacterTextSplitter()
|
text_splitter_: TextSplitter = RecursiveCharacterTextSplitter()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Module contains a few fake embedding models for testing purposes."""
|
"""Module contains a few fake embedding models for testing purposes."""
|
||||||
|
|
||||||
# Please do not add additional fake embedding model implementations here.
|
# Please do not add additional fake embedding model implementations here.
|
||||||
|
import contextlib
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -8,6 +9,9 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
|
with contextlib.suppress(ImportError):
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class FakeEmbeddings(Embeddings, BaseModel):
|
class FakeEmbeddings(Embeddings, BaseModel):
|
||||||
"""Fake embedding model for unit testing purposes.
|
"""Fake embedding model for unit testing purposes.
|
||||||
@@ -54,8 +58,6 @@ class FakeEmbeddings(Embeddings, BaseModel):
|
|||||||
"""The size of the embedding vector."""
|
"""The size of the embedding vector."""
|
||||||
|
|
||||||
def _get_embedding(self) -> list[float]:
|
def _get_embedding(self) -> list[float]:
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
return list(np.random.default_rng().normal(size=self.size))
|
return list(np.random.default_rng().normal(size=self.size))
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@@ -113,8 +115,6 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
|
|||||||
"""The size of the embedding vector."""
|
"""The size of the embedding vector."""
|
||||||
|
|
||||||
def _get_embedding(self, seed: int) -> list[float]:
|
def _get_embedding(self, seed: int) -> list[float]:
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# set the seed for the random generator
|
# set the seed for the random generator
|
||||||
rng = np.random.default_rng(seed)
|
rng = np.random.default_rng(seed)
|
||||||
return list(rng.normal(size=self.size))
|
return list(rng.normal(size=self.size))
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
import platform
|
import platform
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from langchain_core import __version__
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def get_runtime_environment() -> dict:
|
def get_runtime_environment() -> dict:
|
||||||
@@ -11,9 +13,6 @@ def get_runtime_environment() -> dict:
|
|||||||
Returns:
|
Returns:
|
||||||
A dictionary with information about the runtime environment.
|
A dictionary with information about the runtime environment.
|
||||||
"""
|
"""
|
||||||
# Lazy import to avoid circular imports
|
|
||||||
from langchain_core import __version__
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"library_version": __version__,
|
"library_version": __version__,
|
||||||
"library": "langchain-core",
|
"library": "langchain-core",
|
||||||
|
|||||||
@@ -6,6 +6,13 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.caches import BaseCache
|
from langchain_core.caches import BaseCache
|
||||||
|
|
||||||
|
try:
|
||||||
|
import langchain # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
_HAS_LANGCHAIN = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_LANGCHAIN = False
|
||||||
|
|
||||||
|
|
||||||
# DO NOT USE THESE VALUES DIRECTLY!
|
# DO NOT USE THESE VALUES DIRECTLY!
|
||||||
# Use them only via `get_<X>()` and `set_<X>()` below,
|
# Use them only via `get_<X>()` and `set_<X>()` below,
|
||||||
@@ -22,9 +29,7 @@ def set_verbose(value: bool) -> None: # noqa: FBT001
|
|||||||
Args:
|
Args:
|
||||||
value: The new value for the `verbose` global setting.
|
value: The new value for the `verbose` global setting.
|
||||||
"""
|
"""
|
||||||
try:
|
if _HAS_LANGCHAIN:
|
||||||
import langchain # type: ignore[import-not-found]
|
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get
|
# The user called the correct (non-deprecated) code path and shouldn't get
|
||||||
# warnings.
|
# warnings.
|
||||||
@@ -43,8 +48,6 @@ def set_verbose(value: bool) -> None: # noqa: FBT001
|
|||||||
# Remove it once `langchain.verbose` is no longer supported, and once all
|
# Remove it once `langchain.verbose` is no longer supported, and once all
|
||||||
# users have migrated to using `set_verbose()` here.
|
# users have migrated to using `set_verbose()` here.
|
||||||
langchain.verbose = value
|
langchain.verbose = value
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
global _verbose # noqa: PLW0603
|
global _verbose # noqa: PLW0603
|
||||||
_verbose = value
|
_verbose = value
|
||||||
@@ -56,9 +59,7 @@ def get_verbose() -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
The value of the `verbose` global setting.
|
The value of the `verbose` global setting.
|
||||||
"""
|
"""
|
||||||
try:
|
if _HAS_LANGCHAIN:
|
||||||
import langchain
|
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get
|
# The user called the correct (non-deprecated) code path and shouldn't get
|
||||||
# warnings.
|
# warnings.
|
||||||
@@ -83,7 +84,7 @@ def get_verbose() -> bool:
|
|||||||
# deprecation warnings directing them to use `set_verbose()` when they
|
# deprecation warnings directing them to use `set_verbose()` when they
|
||||||
# import `langchain.verbose`.
|
# import `langchain.verbose`.
|
||||||
old_verbose = langchain.verbose
|
old_verbose = langchain.verbose
|
||||||
except ImportError:
|
else:
|
||||||
old_verbose = False
|
old_verbose = False
|
||||||
|
|
||||||
return _verbose or old_verbose
|
return _verbose or old_verbose
|
||||||
@@ -95,9 +96,7 @@ def set_debug(value: bool) -> None: # noqa: FBT001
|
|||||||
Args:
|
Args:
|
||||||
value: The new value for the `debug` global setting.
|
value: The new value for the `debug` global setting.
|
||||||
"""
|
"""
|
||||||
try:
|
if _HAS_LANGCHAIN:
|
||||||
import langchain
|
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get
|
# The user called the correct (non-deprecated) code path and shouldn't get
|
||||||
# warnings.
|
# warnings.
|
||||||
@@ -114,8 +113,6 @@ def set_debug(value: bool) -> None: # noqa: FBT001
|
|||||||
# Remove it once `langchain.debug` is no longer supported, and once all
|
# Remove it once `langchain.debug` is no longer supported, and once all
|
||||||
# users have migrated to using `set_debug()` here.
|
# users have migrated to using `set_debug()` here.
|
||||||
langchain.debug = value
|
langchain.debug = value
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
global _debug # noqa: PLW0603
|
global _debug # noqa: PLW0603
|
||||||
_debug = value
|
_debug = value
|
||||||
@@ -127,9 +124,7 @@ def get_debug() -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
The value of the `debug` global setting.
|
The value of the `debug` global setting.
|
||||||
"""
|
"""
|
||||||
try:
|
if _HAS_LANGCHAIN:
|
||||||
import langchain
|
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get
|
# The user called the correct (non-deprecated) code path and shouldn't get
|
||||||
# warnings.
|
# warnings.
|
||||||
@@ -151,7 +146,7 @@ def get_debug() -> bool:
|
|||||||
# to using `set_debug()` yet. Those users are getting deprecation warnings
|
# to using `set_debug()` yet. Those users are getting deprecation warnings
|
||||||
# directing them to use `set_debug()` when they import `langchain.debug`.
|
# directing them to use `set_debug()` when they import `langchain.debug`.
|
||||||
old_debug = langchain.debug
|
old_debug = langchain.debug
|
||||||
except ImportError:
|
else:
|
||||||
old_debug = False
|
old_debug = False
|
||||||
|
|
||||||
return _debug or old_debug
|
return _debug or old_debug
|
||||||
@@ -163,9 +158,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
|
|||||||
Args:
|
Args:
|
||||||
value: The new LLM cache to use. If `None`, the LLM cache is disabled.
|
value: The new LLM cache to use. If `None`, the LLM cache is disabled.
|
||||||
"""
|
"""
|
||||||
try:
|
if _HAS_LANGCHAIN:
|
||||||
import langchain
|
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get
|
# The user called the correct (non-deprecated) code path and shouldn't get
|
||||||
# warnings.
|
# warnings.
|
||||||
@@ -184,8 +177,6 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
|
|||||||
# Remove it once `langchain.llm_cache` is no longer supported, and
|
# Remove it once `langchain.llm_cache` is no longer supported, and
|
||||||
# once all users have migrated to using `set_llm_cache()` here.
|
# once all users have migrated to using `set_llm_cache()` here.
|
||||||
langchain.llm_cache = value
|
langchain.llm_cache = value
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
global _llm_cache # noqa: PLW0603
|
global _llm_cache # noqa: PLW0603
|
||||||
_llm_cache = value
|
_llm_cache = value
|
||||||
@@ -197,9 +188,7 @@ def get_llm_cache() -> Optional["BaseCache"]:
|
|||||||
Returns:
|
Returns:
|
||||||
The value of the `llm_cache` global setting.
|
The value of the `llm_cache` global setting.
|
||||||
"""
|
"""
|
||||||
try:
|
if _HAS_LANGCHAIN:
|
||||||
import langchain
|
|
||||||
|
|
||||||
# We're about to run some deprecated code, don't report warnings from it.
|
# We're about to run some deprecated code, don't report warnings from it.
|
||||||
# The user called the correct (non-deprecated) code path and shouldn't get
|
# The user called the correct (non-deprecated) code path and shouldn't get
|
||||||
# warnings.
|
# warnings.
|
||||||
@@ -225,7 +214,7 @@ def get_llm_cache() -> Optional["BaseCache"]:
|
|||||||
# Those users are getting deprecation warnings directing them
|
# Those users are getting deprecation warnings directing them
|
||||||
# to use `set_llm_cache()` when they import `langchain.llm_cache`.
|
# to use `set_llm_cache()` when they import `langchain.llm_cache`.
|
||||||
old_llm_cache = langchain.llm_cache
|
old_llm_cache = langchain.llm_cache
|
||||||
except ImportError:
|
else:
|
||||||
old_llm_cache = None
|
old_llm_cache = None
|
||||||
|
|
||||||
return _llm_cache or old_llm_cache
|
return _llm_cache or old_llm_cache
|
||||||
|
|||||||
@@ -22,19 +22,31 @@ from typing_extensions import TypeAlias, TypedDict, override
|
|||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.caches import BaseCache
|
from langchain_core.caches import BaseCache
|
||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
|
from langchain_core.globals import get_verbose
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AnyMessage,
|
AnyMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
MessageLikeRepresentation,
|
MessageLikeRepresentation,
|
||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
)
|
)
|
||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import (
|
||||||
|
ChatPromptValueConcrete,
|
||||||
|
PromptValue,
|
||||||
|
StringPromptValue,
|
||||||
|
)
|
||||||
from langchain_core.runnables import Runnable, RunnableSerializable
|
from langchain_core.runnables import Runnable, RunnableSerializable
|
||||||
from langchain_core.utils import get_pydantic_field_names
|
from langchain_core.utils import get_pydantic_field_names
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.outputs import LLMResult
|
from langchain_core.outputs import LLMResult
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import GPT2TokenizerFast # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
_HAS_TRANSFORMERS = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_TRANSFORMERS = False
|
||||||
|
|
||||||
|
|
||||||
class LangSmithParams(TypedDict, total=False):
|
class LangSmithParams(TypedDict, total=False):
|
||||||
"""LangSmith parameters for tracing."""
|
"""LangSmith parameters for tracing."""
|
||||||
@@ -66,15 +78,13 @@ def get_tokenizer() -> Any:
|
|||||||
The GPT-2 tokenizer instance.
|
The GPT-2 tokenizer instance.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
if not _HAS_TRANSFORMERS:
|
||||||
from transformers import GPT2TokenizerFast # type: ignore[import-not-found]
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"Could not import transformers python package. "
|
"Could not import transformers python package. "
|
||||||
"This is needed in order to calculate get_token_ids. "
|
"This is needed in order to calculate get_token_ids. "
|
||||||
"Please install it with `pip install transformers`."
|
"Please install it with `pip install transformers`."
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
# create a GPT-2 tokenizer instance
|
# create a GPT-2 tokenizer instance
|
||||||
return GPT2TokenizerFast.from_pretrained("gpt2")
|
return GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
|
|
||||||
@@ -95,8 +105,6 @@ LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
|
|||||||
|
|
||||||
|
|
||||||
def _get_verbosity() -> bool:
|
def _get_verbosity() -> bool:
|
||||||
from langchain_core.globals import get_verbose
|
|
||||||
|
|
||||||
return get_verbose()
|
return get_verbose()
|
||||||
|
|
||||||
|
|
||||||
@@ -158,11 +166,6 @@ class BaseLanguageModel(
|
|||||||
@override
|
@override
|
||||||
def InputType(self) -> TypeAlias:
|
def InputType(self) -> TypeAlias:
|
||||||
"""Get the input type for this runnable."""
|
"""Get the input type for this runnable."""
|
||||||
from langchain_core.prompt_values import (
|
|
||||||
ChatPromptValueConcrete,
|
|
||||||
StringPromptValue,
|
|
||||||
)
|
|
||||||
|
|
||||||
# This is a version of LanguageModelInput which replaces the abstract
|
# This is a version of LanguageModelInput which replaces the abstract
|
||||||
# base class BaseMessage with a union of its subclasses, which makes
|
# base class BaseMessage with a union of its subclasses, which makes
|
||||||
# for a much better schema.
|
# for a much better schema.
|
||||||
|
|||||||
@@ -46,6 +46,10 @@ from langchain_core.messages import (
|
|||||||
message_chunk_to_message,
|
message_chunk_to_message,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.ai import _LC_ID_PREFIX
|
from langchain_core.messages.ai import _LC_ID_PREFIX
|
||||||
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
|
JsonOutputKeyToolsParser,
|
||||||
|
PydanticToolsParser,
|
||||||
|
)
|
||||||
from langchain_core.outputs import (
|
from langchain_core.outputs import (
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
ChatGenerationChunk,
|
ChatGenerationChunk,
|
||||||
@@ -1590,11 +1594,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
msg = f"Received unsupported arguments {kwargs}"
|
msg = f"Received unsupported arguments {kwargs}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
from langchain_core.output_parsers.openai_tools import (
|
|
||||||
JsonOutputKeyToolsParser,
|
|
||||||
PydanticToolsParser,
|
|
||||||
)
|
|
||||||
|
|
||||||
if type(self).bind_tools is BaseChatModel.bind_tools:
|
if type(self).bind_tools is BaseChatModel.bind_tools:
|
||||||
msg = "with_structured_output is not implemented for this model."
|
msg = "with_structured_output is not implemented for this model."
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from typing import Any
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain_core.load.serializable import Serializable, to_json_not_implemented
|
from langchain_core.load.serializable import Serializable, to_json_not_implemented
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.outputs import ChatGeneration
|
||||||
|
|
||||||
|
|
||||||
def default(obj: Any) -> Any:
|
def default(obj: Any) -> Any:
|
||||||
@@ -23,9 +25,6 @@ def default(obj: Any) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
def _dump_pydantic_models(obj: Any) -> Any:
|
def _dump_pydantic_models(obj: Any) -> Any:
|
||||||
from langchain_core.messages import AIMessage
|
|
||||||
from langchain_core.outputs import ChatGeneration
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(obj, ChatGeneration)
|
isinstance(obj, ChatGeneration)
|
||||||
and isinstance(obj.message, AIMessage)
|
and isinstance(obj.message, AIMessage)
|
||||||
|
|||||||
@@ -118,7 +118,8 @@ class BaseMessage(Serializable):
|
|||||||
Returns:
|
Returns:
|
||||||
A ChatPromptTemplate containing both messages.
|
A ChatPromptTemplate containing both messages.
|
||||||
"""
|
"""
|
||||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
# Import locally to prevent circular imports.
|
||||||
|
from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415
|
||||||
|
|
||||||
prompt = ChatPromptTemplate(messages=[self])
|
prompt = ChatPromptTemplate(messages=[self])
|
||||||
return prompt + other
|
return prompt + other
|
||||||
|
|||||||
@@ -42,12 +42,17 @@ from langchain_core.messages.system import SystemMessage, SystemMessageChunk
|
|||||||
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
|
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_text_splitters import TextSplitter
|
|
||||||
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.runnables.base import Runnable
|
from langchain_core.runnables.base import Runnable
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_text_splitters import TextSplitter
|
||||||
|
|
||||||
|
_HAS_LANGCHAIN_TEXT_SPLITTERS = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_LANGCHAIN_TEXT_SPLITTERS = False
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -361,7 +366,7 @@ def convert_to_messages(
|
|||||||
list of messages (BaseMessages).
|
list of messages (BaseMessages).
|
||||||
"""
|
"""
|
||||||
# Import here to avoid circular imports
|
# Import here to avoid circular imports
|
||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import PromptValue # noqa: PLC0415
|
||||||
|
|
||||||
if isinstance(messages, PromptValue):
|
if isinstance(messages, PromptValue):
|
||||||
return messages.to_messages()
|
return messages.to_messages()
|
||||||
@@ -386,7 +391,8 @@ def _runnable_support(func: Callable) -> Callable:
|
|||||||
list[BaseMessage],
|
list[BaseMessage],
|
||||||
Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]],
|
Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]],
|
||||||
]:
|
]:
|
||||||
from langchain_core.runnables.base import RunnableLambda
|
# Import locally to prevent circular import.
|
||||||
|
from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415
|
||||||
|
|
||||||
if messages is not None:
|
if messages is not None:
|
||||||
return func(messages, **kwargs)
|
return func(messages, **kwargs)
|
||||||
@@ -989,17 +995,12 @@ def trim_messages(
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
try:
|
if _HAS_LANGCHAIN_TEXT_SPLITTERS and isinstance(text_splitter, TextSplitter):
|
||||||
from langchain_text_splitters import TextSplitter
|
text_splitter_fn = text_splitter.split_text
|
||||||
except ImportError:
|
elif text_splitter:
|
||||||
text_splitter_fn: Optional[Callable] = cast("Optional[Callable]", text_splitter)
|
text_splitter_fn = cast("Callable", text_splitter)
|
||||||
else:
|
else:
|
||||||
if isinstance(text_splitter, TextSplitter):
|
text_splitter_fn = _default_text_splitter
|
||||||
text_splitter_fn = text_splitter.split_text
|
|
||||||
else:
|
|
||||||
text_splitter_fn = text_splitter
|
|
||||||
|
|
||||||
text_splitter_fn = text_splitter_fn or _default_text_splitter
|
|
||||||
|
|
||||||
if strategy == "first":
|
if strategy == "first":
|
||||||
return _first_max_tokens(
|
return _first_max_tokens(
|
||||||
|
|||||||
@@ -15,6 +15,14 @@ from langchain_core.messages import BaseMessage
|
|||||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||||
from langchain_core.runnables.utils import AddableDict
|
from langchain_core.runnables.utils import AddableDict
|
||||||
|
|
||||||
|
try:
|
||||||
|
from defusedxml import ElementTree # type: ignore[import-untyped]
|
||||||
|
from defusedxml.ElementTree import XMLParser # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
_HAS_DEFUSEDXML = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_DEFUSEDXML = False
|
||||||
|
|
||||||
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
|
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
|
||||||
1. Output should conform to the tags below.
|
1. Output should conform to the tags below.
|
||||||
2. If tags are not given, make them on your own.
|
2. If tags are not given, make them on your own.
|
||||||
@@ -50,17 +58,13 @@ class _StreamingParser:
|
|||||||
parser is requested.
|
parser is requested.
|
||||||
"""
|
"""
|
||||||
if parser == "defusedxml":
|
if parser == "defusedxml":
|
||||||
try:
|
if not _HAS_DEFUSEDXML:
|
||||||
from defusedxml.ElementTree import ( # type: ignore[import-untyped]
|
|
||||||
XMLParser,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"defusedxml is not installed. "
|
"defusedxml is not installed. "
|
||||||
"Please install it to use the defusedxml parser."
|
"Please install it to use the defusedxml parser."
|
||||||
"You can install it with `pip install defusedxml` "
|
"You can install it with `pip install defusedxml` "
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
parser_ = XMLParser(target=TreeBuilder())
|
parser_ = XMLParser(target=TreeBuilder())
|
||||||
else:
|
else:
|
||||||
parser_ = None
|
parser_ = None
|
||||||
@@ -207,16 +211,14 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
# Imports are temporarily placed here to avoid issue with caching on CI
|
# Imports are temporarily placed here to avoid issue with caching on CI
|
||||||
# likely if you're reading this you can move them to the top of the file
|
# likely if you're reading this you can move them to the top of the file
|
||||||
if self.parser == "defusedxml":
|
if self.parser == "defusedxml":
|
||||||
try:
|
if not _HAS_DEFUSEDXML:
|
||||||
from defusedxml import ElementTree # type: ignore[import-untyped]
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"defusedxml is not installed. "
|
"defusedxml is not installed. "
|
||||||
"Please install it to use the defusedxml parser."
|
"Please install it to use the defusedxml parser."
|
||||||
"You can install it with `pip install defusedxml`"
|
"You can install it with `pip install defusedxml`"
|
||||||
"See https://github.com/tiran/defusedxml for more details"
|
"See https://github.com/tiran/defusedxml for more details"
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
et = ElementTree # Use the defusedxml parser
|
et = ElementTree # Use the defusedxml parser
|
||||||
else:
|
else:
|
||||||
et = ET # Use the standard library parser
|
et = ET # Use the standard library parser
|
||||||
|
|||||||
@@ -88,7 +88,8 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
Combined prompt template.
|
Combined prompt template.
|
||||||
"""
|
"""
|
||||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
# Import locally to avoid circular import.
|
||||||
|
from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415
|
||||||
|
|
||||||
prompt = ChatPromptTemplate(messages=[self])
|
prompt = ChatPromptTemplate(messages=[self])
|
||||||
return prompt + other
|
return prompt + other
|
||||||
|
|||||||
@@ -15,6 +15,14 @@ from langchain_core.utils import get_colored_text, mustache
|
|||||||
from langchain_core.utils.formatting import formatter
|
from langchain_core.utils.formatting import formatter
|
||||||
from langchain_core.utils.interactive_env import is_interactive_env
|
from langchain_core.utils.interactive_env import is_interactive_env
|
||||||
|
|
||||||
|
try:
|
||||||
|
from jinja2 import Environment, meta
|
||||||
|
from jinja2.sandbox import SandboxedEnvironment
|
||||||
|
|
||||||
|
_HAS_JINJA2 = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_JINJA2 = False
|
||||||
|
|
||||||
PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]
|
PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]
|
||||||
|
|
||||||
|
|
||||||
@@ -40,9 +48,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
|
|||||||
Raises:
|
Raises:
|
||||||
ImportError: If jinja2 is not installed.
|
ImportError: If jinja2 is not installed.
|
||||||
"""
|
"""
|
||||||
try:
|
if not _HAS_JINJA2:
|
||||||
from jinja2.sandbox import SandboxedEnvironment
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||||
"Please install it with `pip install jinja2`."
|
"Please install it with `pip install jinja2`."
|
||||||
@@ -50,7 +56,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
|
|||||||
"Do not expand jinja2 templates using unverified or user-controlled "
|
"Do not expand jinja2 templates using unverified or user-controlled "
|
||||||
"inputs as that can result in arbitrary Python code execution."
|
"inputs as that can result in arbitrary Python code execution."
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
|
|
||||||
# This uses a sandboxed environment to prevent arbitrary code execution.
|
# This uses a sandboxed environment to prevent arbitrary code execution.
|
||||||
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
|
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
|
||||||
@@ -88,14 +94,12 @@ def validate_jinja2(template: str, input_variables: list[str]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _get_jinja2_variables_from_template(template: str) -> set[str]:
|
def _get_jinja2_variables_from_template(template: str) -> set[str]:
|
||||||
try:
|
if not _HAS_JINJA2:
|
||||||
from jinja2 import Environment, meta
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||||
"Please install it with `pip install jinja2`."
|
"Please install it with `pip install jinja2`."
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
env = Environment() # noqa: S701
|
env = Environment() # noqa: S701
|
||||||
ast = env.parse(template)
|
ast = env.parse(template)
|
||||||
return meta.find_undeclared_variables(ast)
|
return meta.find_undeclared_variables(ast)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from typing_extensions import Self, TypedDict, override
|
|||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
|
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
Runnable,
|
Runnable,
|
||||||
@@ -236,8 +237,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
retriever.invoke("query")
|
retriever.invoke("query")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
inheritable_metadata = {
|
inheritable_metadata = {
|
||||||
**(config.get("metadata") or {}),
|
**(config.get("metadata") or {}),
|
||||||
@@ -301,8 +300,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
await retriever.ainvoke("query")
|
await retriever.ainvoke("query")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
inheritable_metadata = {
|
inheritable_metadata = {
|
||||||
**(config.get("metadata") or {}),
|
**(config.get("metadata") or {}),
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from pydantic import BaseModel, ConfigDict, Field, RootModel
|
|||||||
from typing_extensions import Literal, get_args, override
|
from typing_extensions import Literal, get_args, override
|
||||||
|
|
||||||
from langchain_core._api import beta_decorator
|
from langchain_core._api import beta_decorator
|
||||||
|
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||||
from langchain_core.load.serializable import (
|
from langchain_core.load.serializable import (
|
||||||
Serializable,
|
Serializable,
|
||||||
SerializedConstructor,
|
SerializedConstructor,
|
||||||
@@ -60,7 +61,6 @@ from langchain_core.runnables.config import (
|
|||||||
run_in_executor,
|
run_in_executor,
|
||||||
set_config_context,
|
set_config_context,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.graph import Graph
|
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
AddableDict,
|
AddableDict,
|
||||||
AnyConfigurableField,
|
AnyConfigurableField,
|
||||||
@@ -81,6 +81,19 @@ from langchain_core.runnables.utils import (
|
|||||||
is_async_callable,
|
is_async_callable,
|
||||||
is_async_generator,
|
is_async_generator,
|
||||||
)
|
)
|
||||||
|
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||||
|
from langchain_core.tracers.event_stream import (
|
||||||
|
_astream_events_implementation_v1,
|
||||||
|
_astream_events_implementation_v2,
|
||||||
|
)
|
||||||
|
from langchain_core.tracers.log_stream import (
|
||||||
|
LogStreamCallbackHandler,
|
||||||
|
_astream_log_implementation,
|
||||||
|
)
|
||||||
|
from langchain_core.tracers.root_listeners import (
|
||||||
|
AsyncRootListenersTracer,
|
||||||
|
RootListenersTracer,
|
||||||
|
)
|
||||||
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
||||||
from langchain_core.utils.iter import safetee
|
from langchain_core.utils.iter import safetee
|
||||||
from langchain_core.utils.pydantic import create_model_v2
|
from langchain_core.utils.pydantic import create_model_v2
|
||||||
@@ -94,6 +107,7 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.runnables.fallbacks import (
|
from langchain_core.runnables.fallbacks import (
|
||||||
RunnableWithFallbacks as RunnableWithFallbacksT,
|
RunnableWithFallbacks as RunnableWithFallbacksT,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables.graph import Graph
|
||||||
from langchain_core.runnables.retry import ExponentialJitterParams
|
from langchain_core.runnables.retry import ExponentialJitterParams
|
||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
@@ -565,6 +579,9 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
|
|
||||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||||
"""Return a graph representation of this ``Runnable``."""
|
"""Return a graph representation of this ``Runnable``."""
|
||||||
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.graph import Graph # noqa: PLC0415
|
||||||
|
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
try:
|
try:
|
||||||
input_node = graph.add_node(self.get_input_schema(config))
|
input_node = graph.add_node(self.get_input_schema(config))
|
||||||
@@ -585,7 +602,8 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> list[BasePromptTemplate]:
|
) -> list[BasePromptTemplate]:
|
||||||
"""Return a list of prompts used by this ``Runnable``."""
|
"""Return a list of prompts used by this ``Runnable``."""
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.prompts.base import BasePromptTemplate # noqa: PLC0415
|
||||||
|
|
||||||
return [
|
return [
|
||||||
node.data
|
node.data
|
||||||
@@ -747,7 +765,8 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
a new ``Runnable``.
|
a new ``Runnable``.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.passthrough import RunnablePick
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.passthrough import RunnablePick # noqa: PLC0415
|
||||||
|
|
||||||
return self | RunnablePick(keys)
|
return self | RunnablePick(keys)
|
||||||
|
|
||||||
@@ -798,7 +817,8 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
A new ``Runnable``.
|
A new ``Runnable``.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.passthrough import RunnableAssign
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.passthrough import RunnableAssign # noqa: PLC0415
|
||||||
|
|
||||||
return self | RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
|
return self | RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
|
||||||
|
|
||||||
@@ -1231,11 +1251,6 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
A ``RunLogPatch`` or ``RunLog`` object.
|
A ``RunLogPatch`` or ``RunLog`` object.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.tracers.log_stream import (
|
|
||||||
LogStreamCallbackHandler,
|
|
||||||
_astream_log_implementation,
|
|
||||||
)
|
|
||||||
|
|
||||||
stream = LogStreamCallbackHandler(
|
stream = LogStreamCallbackHandler(
|
||||||
auto_close=False,
|
auto_close=False,
|
||||||
include_names=include_names,
|
include_names=include_names,
|
||||||
@@ -1489,11 +1504,6 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
NotImplementedError: If the version is not ``'v1'`` or ``'v2'``.
|
NotImplementedError: If the version is not ``'v1'`` or ``'v2'``.
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
from langchain_core.tracers.event_stream import (
|
|
||||||
_astream_events_implementation_v1,
|
|
||||||
_astream_events_implementation_v2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if version == "v2":
|
if version == "v2":
|
||||||
event_stream = _astream_events_implementation_v2(
|
event_stream = _astream_events_implementation_v2(
|
||||||
self,
|
self,
|
||||||
@@ -1740,8 +1750,6 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
chain.invoke(2)
|
chain.invoke(2)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.tracers.root_listeners import RootListenersTracer
|
|
||||||
|
|
||||||
return RunnableBinding(
|
return RunnableBinding(
|
||||||
bound=self,
|
bound=self,
|
||||||
config_factories=[
|
config_factories=[
|
||||||
@@ -1834,8 +1842,6 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
on end callback ends at 2025-03-01T07:05:30.884831+00:00
|
on end callback ends at 2025-03-01T07:05:30.884831+00:00
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
|
|
||||||
|
|
||||||
return RunnableBinding(
|
return RunnableBinding(
|
||||||
bound=self,
|
bound=self,
|
||||||
config_factories=[
|
config_factories=[
|
||||||
@@ -1928,7 +1934,8 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
assert count == 2
|
assert count == 2
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.retry import RunnableRetry
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.retry import RunnableRetry # noqa: PLC0415
|
||||||
|
|
||||||
return RunnableRetry(
|
return RunnableRetry(
|
||||||
bound=self,
|
bound=self,
|
||||||
@@ -2030,7 +2037,10 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
fallback in order, upon failures.
|
fallback in order, upon failures.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.fallbacks import ( # noqa: PLC0415
|
||||||
|
RunnableWithFallbacks,
|
||||||
|
)
|
||||||
|
|
||||||
return RunnableWithFallbacks(
|
return RunnableWithFallbacks(
|
||||||
runnable=self,
|
runnable=self,
|
||||||
@@ -2316,9 +2326,6 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
Use this to implement ``stream`` or ``transform`` in ``Runnable`` subclasses.
|
Use this to implement ``stream`` or ``transform`` in ``Runnable`` subclasses.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Mixin that is used by both astream log and astream events implementation
|
|
||||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
|
||||||
|
|
||||||
# tee the input so we can iterate over it twice
|
# tee the input so we can iterate over it twice
|
||||||
input_for_tracing, input_for_transform = tee(inputs, 2)
|
input_for_tracing, input_for_transform = tee(inputs, 2)
|
||||||
# Start the input iterator to ensure the input Runnable starts before this one
|
# Start the input iterator to ensure the input Runnable starts before this one
|
||||||
@@ -2422,9 +2429,6 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
Use this to implement ``astream`` or ``atransform`` in ``Runnable`` subclasses.
|
Use this to implement ``astream`` or ``atransform`` in ``Runnable`` subclasses.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Mixin that is used by both astream log and astream events implementation
|
|
||||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
|
||||||
|
|
||||||
# tee the input so we can iterate over it twice
|
# tee the input so we can iterate over it twice
|
||||||
input_for_tracing, input_for_transform = atee(inputs, 2)
|
input_for_tracing, input_for_transform = atee(inputs, 2)
|
||||||
# Start the input iterator to ensure the input Runnable starts before this one
|
# Start the input iterator to ensure the input Runnable starts before this one
|
||||||
@@ -2614,7 +2618,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
# Avoid circular import
|
# Avoid circular import
|
||||||
from langchain_core.tools import convert_runnable_to_tool
|
from langchain_core.tools import convert_runnable_to_tool # noqa: PLC0415
|
||||||
|
|
||||||
return convert_runnable_to_tool(
|
return convert_runnable_to_tool(
|
||||||
self,
|
self,
|
||||||
@@ -2690,7 +2694,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.configurable import RunnableConfigurableFields
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.configurable import ( # noqa: PLC0415
|
||||||
|
RunnableConfigurableFields,
|
||||||
|
)
|
||||||
|
|
||||||
model_fields = type(self).model_fields
|
model_fields = type(self).model_fields
|
||||||
for key in kwargs:
|
for key in kwargs:
|
||||||
@@ -2751,7 +2758,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.configurable import (
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.configurable import ( # noqa: PLC0415
|
||||||
RunnableConfigurableAlternatives,
|
RunnableConfigurableAlternatives,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2767,7 +2775,11 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
|||||||
def _seq_input_schema(
|
def _seq_input_schema(
|
||||||
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
|
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
|
||||||
|
RunnableAssign,
|
||||||
|
RunnablePick,
|
||||||
|
)
|
||||||
|
|
||||||
first = steps[0]
|
first = steps[0]
|
||||||
if len(steps) == 1:
|
if len(steps) == 1:
|
||||||
@@ -2793,7 +2805,11 @@ def _seq_input_schema(
|
|||||||
def _seq_output_schema(
|
def _seq_output_schema(
|
||||||
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
|
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
|
||||||
|
RunnableAssign,
|
||||||
|
RunnablePick,
|
||||||
|
)
|
||||||
|
|
||||||
last = steps[-1]
|
last = steps[-1]
|
||||||
if len(steps) == 1:
|
if len(steps) == 1:
|
||||||
@@ -3050,7 +3066,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
The config specs of the ``Runnable``.
|
The config specs of the ``Runnable``.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.beta.runnables.context import (
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
|
||||||
CONTEXT_CONFIG_PREFIX,
|
CONTEXT_CONFIG_PREFIX,
|
||||||
_key_from_id,
|
_key_from_id,
|
||||||
)
|
)
|
||||||
@@ -3108,7 +3125,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
ValueError: If a ``Runnable`` has no first or last node.
|
ValueError: If a ``Runnable`` has no first or last node.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.graph import Graph
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.graph import Graph # noqa: PLC0415
|
||||||
|
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
for step in self.steps:
|
for step in self.steps:
|
||||||
@@ -3196,7 +3214,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
from langchain_core.beta.runnables.context import config_with_context
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
|
||||||
|
config_with_context,
|
||||||
|
)
|
||||||
|
|
||||||
# setup callbacks and context
|
# setup callbacks and context
|
||||||
config = config_with_context(ensure_config(config), self.steps)
|
config = config_with_context(ensure_config(config), self.steps)
|
||||||
@@ -3237,7 +3258,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
|
||||||
|
aconfig_with_context,
|
||||||
|
)
|
||||||
|
|
||||||
# setup callbacks and context
|
# setup callbacks and context
|
||||||
config = aconfig_with_context(ensure_config(config), self.steps)
|
config = aconfig_with_context(ensure_config(config), self.steps)
|
||||||
@@ -3281,8 +3305,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> list[Output]:
|
) -> list[Output]:
|
||||||
from langchain_core.beta.runnables.context import config_with_context
|
# Import locally to prevent circular import
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
|
||||||
|
config_with_context,
|
||||||
|
)
|
||||||
|
|
||||||
if not inputs:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
@@ -3411,8 +3437,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> list[Output]:
|
) -> list[Output]:
|
||||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
# Import locally to prevent circular import
|
||||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
|
||||||
|
aconfig_with_context,
|
||||||
|
)
|
||||||
|
|
||||||
if not inputs:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
@@ -3542,7 +3570,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
from langchain_core.beta.runnables.context import config_with_context
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
|
||||||
|
config_with_context,
|
||||||
|
)
|
||||||
|
|
||||||
steps = [self.first, *self.middle, self.last]
|
steps = [self.first, *self.middle, self.last]
|
||||||
config = config_with_context(config, self.steps)
|
config = config_with_context(config, self.steps)
|
||||||
@@ -3569,7 +3600,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.beta.runnables.context import ( # noqa: PLC0415
|
||||||
|
aconfig_with_context,
|
||||||
|
)
|
||||||
|
|
||||||
steps = [self.first, *self.middle, self.last]
|
steps = [self.first, *self.middle, self.last]
|
||||||
config = aconfig_with_context(config, self.steps)
|
config = aconfig_with_context(config, self.steps)
|
||||||
@@ -3882,7 +3916,8 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
ValueError: If a ``Runnable`` has no first or last node.
|
ValueError: If a ``Runnable`` has no first or last node.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.graph import Graph
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.graph import Graph # noqa: PLC0415
|
||||||
|
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
input_node = graph.add_node(self.get_input_schema(config))
|
input_node = graph.add_node(self.get_input_schema(config))
|
||||||
@@ -3918,8 +3953,6 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
@@ -4767,6 +4800,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
@override
|
@override
|
||||||
def get_graph(self, config: RunnableConfig | None = None) -> Graph:
|
def get_graph(self, config: RunnableConfig | None = None) -> Graph:
|
||||||
if deps := self.deps:
|
if deps := self.deps:
|
||||||
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.graph import Graph # noqa: PLC0415
|
||||||
|
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
input_node = graph.add_node(self.get_input_schema(config))
|
input_node = graph.add_node(self.get_input_schema(config))
|
||||||
output_node = graph.add_node(self.get_output_schema(config))
|
output_node = graph.add_node(self.get_output_schema(config))
|
||||||
@@ -6030,7 +6066,6 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-re
|
|||||||
Returns:
|
Returns:
|
||||||
A new ``Runnable`` with the listeners bound.
|
A new ``Runnable`` with the listeners bound.
|
||||||
"""
|
"""
|
||||||
from langchain_core.tracers.root_listeners import RootListenersTracer
|
|
||||||
|
|
||||||
def listener_config_factory(config: RunnableConfig) -> RunnableConfig:
|
def listener_config_factory(config: RunnableConfig) -> RunnableConfig:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ from typing import (
|
|||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from langchain_core.beta.runnables.context import (
|
||||||
|
CONTEXT_CONFIG_PREFIX,
|
||||||
|
CONTEXT_CONFIG_SUFFIX_SET,
|
||||||
|
)
|
||||||
from langchain_core.runnables.base import (
|
from langchain_core.runnables.base import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableLike,
|
RunnableLike,
|
||||||
@@ -177,11 +181,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
@property
|
@property
|
||||||
@override
|
@override
|
||||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||||
from langchain_core.beta.runnables.context import (
|
|
||||||
CONTEXT_CONFIG_PREFIX,
|
|
||||||
CONTEXT_CONFIG_SUFFIX_SET,
|
|
||||||
)
|
|
||||||
|
|
||||||
specs = get_unique_config_specs(
|
specs = get_unique_config_specs(
|
||||||
spec
|
spec
|
||||||
for step in (
|
for step in (
|
||||||
|
|||||||
@@ -12,21 +12,22 @@ from contextvars import Context, ContextVar, Token, copy_context
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
|
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
|
||||||
|
|
||||||
|
from langsmith.run_helpers import _set_tracing_context, get_tracing_context
|
||||||
from typing_extensions import ParamSpec, TypedDict
|
from typing_extensions import ParamSpec, TypedDict
|
||||||
|
|
||||||
|
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
Input,
|
Input,
|
||||||
Output,
|
Output,
|
||||||
accepts_config,
|
accepts_config,
|
||||||
accepts_run_manager,
|
accepts_run_manager,
|
||||||
)
|
)
|
||||||
|
from langchain_core.tracers.langchain import LangChainTracer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.callbacks.base import BaseCallbackManager, Callbacks
|
from langchain_core.callbacks.base import BaseCallbackManager, Callbacks
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
AsyncCallbackManager,
|
|
||||||
AsyncCallbackManagerForChainRun,
|
AsyncCallbackManagerForChainRun,
|
||||||
CallbackManager,
|
|
||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -129,8 +130,6 @@ def _set_config_context(
|
|||||||
Returns:
|
Returns:
|
||||||
The token to reset the config and the previous tracing context.
|
The token to reset the config and the previous tracing context.
|
||||||
"""
|
"""
|
||||||
from langchain_core.tracers.langchain import LangChainTracer
|
|
||||||
|
|
||||||
config_token = var_child_runnable_config.set(config)
|
config_token = var_child_runnable_config.set(config)
|
||||||
current_context = None
|
current_context = None
|
||||||
if (
|
if (
|
||||||
@@ -150,8 +149,6 @@ def _set_config_context(
|
|||||||
)
|
)
|
||||||
and (run := tracer.run_map.get(str(parent_run_id)))
|
and (run := tracer.run_map.get(str(parent_run_id)))
|
||||||
):
|
):
|
||||||
from langsmith.run_helpers import _set_tracing_context, get_tracing_context
|
|
||||||
|
|
||||||
current_context = get_tracing_context()
|
current_context = get_tracing_context()
|
||||||
_set_tracing_context({"parent": run})
|
_set_tracing_context({"parent": run})
|
||||||
return config_token, current_context
|
return config_token, current_context
|
||||||
@@ -167,8 +164,6 @@ def set_config_context(config: RunnableConfig) -> Generator[Context, None, None]
|
|||||||
Yields:
|
Yields:
|
||||||
The config context.
|
The config context.
|
||||||
"""
|
"""
|
||||||
from langsmith.run_helpers import _set_tracing_context
|
|
||||||
|
|
||||||
ctx = copy_context()
|
ctx = copy_context()
|
||||||
config_token, _ = ctx.run(_set_config_context, config)
|
config_token, _ = ctx.run(_set_config_context, config)
|
||||||
try:
|
try:
|
||||||
@@ -481,8 +476,6 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
|||||||
Returns:
|
Returns:
|
||||||
CallbackManager: The callback manager.
|
CallbackManager: The callback manager.
|
||||||
"""
|
"""
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
return CallbackManager.configure(
|
return CallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
inheritable_tags=config.get("tags"),
|
inheritable_tags=config.get("tags"),
|
||||||
@@ -501,8 +494,6 @@ def get_async_callback_manager_for_config(
|
|||||||
Returns:
|
Returns:
|
||||||
AsyncCallbackManager: The async callback manager.
|
AsyncCallbackManager: The async callback manager.
|
||||||
"""
|
"""
|
||||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
return AsyncCallbackManager.configure(
|
return AsyncCallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
inheritable_tags=config.get("tags"),
|
inheritable_tags=config.get("tags"),
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
@@ -272,8 +273,6 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> list[Output]:
|
) -> list[Output]:
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
if self.exception_key is not None and not all(
|
if self.exception_key is not None and not all(
|
||||||
isinstance(input_, dict) for input_ in inputs
|
isinstance(input_, dict) for input_ in inputs
|
||||||
):
|
):
|
||||||
@@ -366,8 +365,6 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> list[Output]:
|
) -> list[Output]:
|
||||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
if self.exception_key is not None and not all(
|
if self.exception_key is not None and not all(
|
||||||
isinstance(input_, dict) for input_ in inputs
|
isinstance(input_, dict) for input_ in inputs
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from langchain_core.load.serializable import to_json_not_implemented
|
||||||
|
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||||
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -191,8 +193,6 @@ def node_data_str(
|
|||||||
Returns:
|
Returns:
|
||||||
A string representation of the data.
|
A string representation of the data.
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.base import Runnable
|
|
||||||
|
|
||||||
if not is_uuid(id) or data is None:
|
if not is_uuid(id) or data is None:
|
||||||
return id
|
return id
|
||||||
data_str = data.get_name() if isinstance(data, Runnable) else data.__name__
|
data_str = data.get_name() if isinstance(data, Runnable) else data.__name__
|
||||||
@@ -212,9 +212,6 @@ def node_data_json(
|
|||||||
Returns:
|
Returns:
|
||||||
A dictionary with the type of the data and the data itself.
|
A dictionary with the type of the data and the data itself.
|
||||||
"""
|
"""
|
||||||
from langchain_core.load.serializable import to_json_not_implemented
|
|
||||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
|
||||||
|
|
||||||
if node.data is None:
|
if node.data is None:
|
||||||
json: dict[str, Any] = {}
|
json: dict[str, Any] = {}
|
||||||
elif isinstance(node.data, RunnableSerializable):
|
elif isinstance(node.data, RunnableSerializable):
|
||||||
@@ -518,7 +515,8 @@ class Graph:
|
|||||||
Returns:
|
Returns:
|
||||||
The ASCII art string.
|
The ASCII art string.
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.graph_ascii import draw_ascii
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.graph_ascii import draw_ascii # noqa: PLC0415
|
||||||
|
|
||||||
return draw_ascii(
|
return draw_ascii(
|
||||||
{node.id: node.name for node in self.nodes.values()},
|
{node.id: node.name for node in self.nodes.values()},
|
||||||
@@ -562,7 +560,8 @@ class Graph:
|
|||||||
Returns:
|
Returns:
|
||||||
The PNG image as bytes if output_file_path is None, None otherwise.
|
The PNG image as bytes if output_file_path is None, None otherwise.
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.graph_png import PngDrawer
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.graph_png import PngDrawer # noqa: PLC0415
|
||||||
|
|
||||||
default_node_labels = {node.id: node.name for node in self.nodes.values()}
|
default_node_labels = {node.id: node.name for node in self.nodes.values()}
|
||||||
|
|
||||||
@@ -617,7 +616,8 @@ class Graph:
|
|||||||
The Mermaid syntax string.
|
The Mermaid syntax string.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.graph_mermaid import draw_mermaid
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415
|
||||||
|
|
||||||
graph = self.reid()
|
graph = self.reid()
|
||||||
first_node = graph.first_node()
|
first_node = graph.first_node()
|
||||||
@@ -688,7 +688,10 @@ class Graph:
|
|||||||
The PNG image as bytes.
|
The PNG image as bytes.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.graph_mermaid import draw_mermaid_png
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.runnables.graph_mermaid import ( # noqa: PLC0415
|
||||||
|
draw_mermaid_png,
|
||||||
|
)
|
||||||
|
|
||||||
mermaid_syntax = self.draw_mermaid(
|
mermaid_syntax = self.draw_mermaid(
|
||||||
curve_style=curve_style,
|
curve_style=curve_style,
|
||||||
|
|||||||
@@ -3,12 +3,24 @@
|
|||||||
Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py.
|
Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from langchain_core.runnables.graph import Edge as LangEdge
|
try:
|
||||||
|
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
|
||||||
|
from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
|
||||||
|
from grandalf.routing import route_with_lines # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
_HAS_GRANDALF = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_GRANDALF = False
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.runnables.graph import Edge as LangEdge
|
||||||
|
|
||||||
|
|
||||||
class VertexViewer:
|
class VertexViewer:
|
||||||
@@ -185,13 +197,9 @@ class _EdgeViewer:
|
|||||||
def _build_sugiyama_layout(
|
def _build_sugiyama_layout(
|
||||||
vertices: Mapping[str, str], edges: Sequence[LangEdge]
|
vertices: Mapping[str, str], edges: Sequence[LangEdge]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
try:
|
if not _HAS_GRANDALF:
|
||||||
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
|
|
||||||
from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
|
|
||||||
from grandalf.routing import route_with_lines # type: ignore[import-untyped]
|
|
||||||
except ImportError as exc:
|
|
||||||
msg = "Install grandalf to draw graphs: `pip install grandalf`."
|
msg = "Install grandalf to draw graphs: `pip install grandalf`."
|
||||||
raise ImportError(msg) from exc
|
raise ImportError(msg)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Just a reminder about naming conventions:
|
# Just a reminder about naming conventions:
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Mermaid graph drawing utilities."""
|
"""Mermaid graph drawing utilities."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import random
|
import random
|
||||||
@@ -7,18 +9,34 @@ import re
|
|||||||
import time
|
import time
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from langchain_core.runnables.graph import (
|
from langchain_core.runnables.graph import (
|
||||||
CurveStyle,
|
CurveStyle,
|
||||||
Edge,
|
|
||||||
MermaidDrawMethod,
|
MermaidDrawMethod,
|
||||||
Node,
|
|
||||||
NodeStyles,
|
NodeStyles,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.runnables.graph import Edge, Node
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
|
||||||
|
_HAS_REQUESTS = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_REQUESTS = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyppeteer import launch # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
_HAS_PYPPETEER = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_PYPPETEER = False
|
||||||
|
|
||||||
MARKDOWN_SPECIAL_CHARS = "*_`"
|
MARKDOWN_SPECIAL_CHARS = "*_`"
|
||||||
|
|
||||||
|
|
||||||
@@ -283,8 +301,6 @@ def draw_mermaid_png(
|
|||||||
ValueError: If an invalid draw method is provided.
|
ValueError: If an invalid draw method is provided.
|
||||||
"""
|
"""
|
||||||
if draw_method == MermaidDrawMethod.PYPPETEER:
|
if draw_method == MermaidDrawMethod.PYPPETEER:
|
||||||
import asyncio
|
|
||||||
|
|
||||||
img_bytes = asyncio.run(
|
img_bytes = asyncio.run(
|
||||||
_render_mermaid_using_pyppeteer(
|
_render_mermaid_using_pyppeteer(
|
||||||
mermaid_syntax, output_file_path, background_color, padding
|
mermaid_syntax, output_file_path, background_color, padding
|
||||||
@@ -317,11 +333,9 @@ async def _render_mermaid_using_pyppeteer(
|
|||||||
device_scale_factor: int = 3,
|
device_scale_factor: int = 3,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Renders Mermaid graph using Pyppeteer."""
|
"""Renders Mermaid graph using Pyppeteer."""
|
||||||
try:
|
if not _HAS_PYPPETEER:
|
||||||
from pyppeteer import launch # type: ignore[import-not-found]
|
|
||||||
except ImportError as e:
|
|
||||||
msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
|
msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
|
|
||||||
browser = await launch()
|
browser = await launch()
|
||||||
page = await browser.newPage()
|
page = await browser.newPage()
|
||||||
@@ -392,14 +406,12 @@ def _render_mermaid_using_api(
|
|||||||
retry_delay: float = 1.0,
|
retry_delay: float = 1.0,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Renders Mermaid graph using the Mermaid.INK API."""
|
"""Renders Mermaid graph using the Mermaid.INK API."""
|
||||||
try:
|
if not _HAS_REQUESTS:
|
||||||
import requests
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"Install the `requests` module to use the Mermaid.INK API: "
|
"Install the `requests` module to use the Mermaid.INK API: "
|
||||||
"`pip install requests`."
|
"`pip install requests`."
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
|
|
||||||
# Use Mermaid API to render the image
|
# Use Mermaid API to render the image
|
||||||
mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
|
mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
|
||||||
|
|||||||
@@ -4,6 +4,13 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from langchain_core.runnables.graph import Graph, LabelsDict
|
from langchain_core.runnables.graph import Graph, LabelsDict
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pygraphviz as pgv # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
_HAS_PYGRAPHVIZ = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_PYGRAPHVIZ = False
|
||||||
|
|
||||||
|
|
||||||
class PngDrawer:
|
class PngDrawer:
|
||||||
"""Helper class to draw a state graph into a PNG file.
|
"""Helper class to draw a state graph into a PNG file.
|
||||||
@@ -125,11 +132,9 @@ class PngDrawer:
|
|||||||
Returns:
|
Returns:
|
||||||
The PNG bytes if ``output_path`` is None, else None.
|
The PNG bytes if ``output_path`` is None, else None.
|
||||||
"""
|
"""
|
||||||
try:
|
if not _HAS_PYGRAPHVIZ:
|
||||||
import pygraphviz as pgv # type: ignore[import-not-found]
|
|
||||||
except ImportError as exc:
|
|
||||||
msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`."
|
msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`."
|
||||||
raise ImportError(msg) from exc
|
raise ImportError(msg)
|
||||||
|
|
||||||
# Create a directed graph
|
# Create a directed graph
|
||||||
viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)
|
viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
from langchain_core.load.load import load
|
from langchain_core.load.load import load
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
|
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
|
||||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
@@ -29,7 +30,6 @@ from langchain_core.utils.pydantic import create_model_v2
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.language_models.base import LanguageModelLike
|
from langchain_core.language_models.base import LanguageModelLike
|
||||||
from langchain_core.messages.base import BaseMessage
|
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
from langchain_core.tracers.schemas import Run
|
from langchain_core.tracers.schemas import Run
|
||||||
|
|
||||||
@@ -384,8 +384,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
|||||||
def get_input_schema(
|
def get_input_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
from langchain_core.messages import BaseMessage
|
|
||||||
|
|
||||||
fields: dict = {}
|
fields: dict = {}
|
||||||
if self.input_messages_key and self.history_messages_key:
|
if self.input_messages_key and self.history_messages_key:
|
||||||
fields[self.input_messages_key] = (
|
fields[self.input_messages_key] = (
|
||||||
@@ -447,8 +445,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
|||||||
def _get_input_messages(
|
def _get_input_messages(
|
||||||
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||||
) -> list[BaseMessage]:
|
) -> list[BaseMessage]:
|
||||||
from langchain_core.messages import BaseMessage
|
|
||||||
|
|
||||||
# If dictionary, try to pluck the single key representing messages
|
# If dictionary, try to pluck the single key representing messages
|
||||||
if isinstance(input_val, dict):
|
if isinstance(input_val, dict):
|
||||||
if self.input_messages_key:
|
if self.input_messages_key:
|
||||||
@@ -461,8 +457,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
|||||||
|
|
||||||
# If value is a string, convert to a human message
|
# If value is a string, convert to a human message
|
||||||
if isinstance(input_val, str):
|
if isinstance(input_val, str):
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
|
|
||||||
return [HumanMessage(content=input_val)]
|
return [HumanMessage(content=input_val)]
|
||||||
# If value is a single message, convert to a list
|
# If value is a single message, convert to a list
|
||||||
if isinstance(input_val, BaseMessage):
|
if isinstance(input_val, BaseMessage):
|
||||||
@@ -489,8 +483,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
|||||||
def _get_output_messages(
|
def _get_output_messages(
|
||||||
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||||
) -> list[BaseMessage]:
|
) -> list[BaseMessage]:
|
||||||
from langchain_core.messages import BaseMessage
|
|
||||||
|
|
||||||
# If dictionary, try to pluck the single key representing messages
|
# If dictionary, try to pluck the single key representing messages
|
||||||
if isinstance(output_val, dict):
|
if isinstance(output_val, dict):
|
||||||
if self.output_messages_key:
|
if self.output_messages_key:
|
||||||
@@ -507,8 +499,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
|||||||
output_val = output_val[key]
|
output_val = output_val[key]
|
||||||
|
|
||||||
if isinstance(output_val, str):
|
if isinstance(output_val, str):
|
||||||
from langchain_core.messages import AIMessage
|
|
||||||
|
|
||||||
return [AIMessage(content=output_val)]
|
return [AIMessage(content=output_val)]
|
||||||
# If value is a single message, convert to a list
|
# If value is a single message, convert to a list
|
||||||
if isinstance(output_val, BaseMessage):
|
if isinstance(output_val, BaseMessage):
|
||||||
|
|||||||
@@ -4,13 +4,15 @@ sys_info prints information about the system and langchain packages for
|
|||||||
debugging purposes.
|
debugging purposes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import pkgutil
|
||||||
|
import platform
|
||||||
|
import sys
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from importlib import metadata, util
|
||||||
|
|
||||||
|
|
||||||
def _get_sub_deps(packages: Sequence[str]) -> list[str]:
|
def _get_sub_deps(packages: Sequence[str]) -> list[str]:
|
||||||
"""Get any specified sub-dependencies."""
|
"""Get any specified sub-dependencies."""
|
||||||
from importlib import metadata
|
|
||||||
|
|
||||||
sub_deps = set()
|
sub_deps = set()
|
||||||
underscored_packages = {pkg.replace("-", "_") for pkg in packages}
|
underscored_packages = {pkg.replace("-", "_") for pkg in packages}
|
||||||
|
|
||||||
@@ -37,11 +39,6 @@ def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None:
|
|||||||
Args:
|
Args:
|
||||||
additional_pkgs: Additional packages to include in the output.
|
additional_pkgs: Additional packages to include in the output.
|
||||||
"""
|
"""
|
||||||
import pkgutil
|
|
||||||
import platform
|
|
||||||
import sys
|
|
||||||
from importlib import metadata, util
|
|
||||||
|
|
||||||
# Packages that do not start with "langchain" prefix.
|
# Packages that do not start with "langchain" prefix.
|
||||||
other_langchain_packages = [
|
other_langchain_packages = [
|
||||||
"langserve",
|
"langserve",
|
||||||
|
|||||||
@@ -18,13 +18,14 @@ from uuid import UUID, uuid4
|
|||||||
|
|
||||||
from typing_extensions import NotRequired, override
|
from typing_extensions import NotRequired, override
|
||||||
|
|
||||||
from langchain_core.callbacks.base import AsyncCallbackHandler
|
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
|
||||||
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
||||||
from langchain_core.outputs import (
|
from langchain_core.outputs import (
|
||||||
ChatGenerationChunk,
|
ChatGenerationChunk,
|
||||||
GenerationChunk,
|
GenerationChunk,
|
||||||
LLMResult,
|
LLMResult,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables import ensure_config
|
||||||
from langchain_core.runnables.schema import (
|
from langchain_core.runnables.schema import (
|
||||||
CustomStreamEvent,
|
CustomStreamEvent,
|
||||||
EventData,
|
EventData,
|
||||||
@@ -37,6 +38,11 @@ from langchain_core.runnables.utils import (
|
|||||||
_RootEventFilter,
|
_RootEventFilter,
|
||||||
)
|
)
|
||||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||||
|
from langchain_core.tracers.log_stream import (
|
||||||
|
LogStreamCallbackHandler,
|
||||||
|
RunLog,
|
||||||
|
_astream_log_implementation,
|
||||||
|
)
|
||||||
from langchain_core.tracers.memory_stream import _MemoryStream
|
from langchain_core.tracers.memory_stream import _MemoryStream
|
||||||
from langchain_core.utils.aiter import aclosing, py_anext
|
from langchain_core.utils.aiter import aclosing, py_anext
|
||||||
|
|
||||||
@@ -769,14 +775,6 @@ async def _astream_events_implementation_v1(
|
|||||||
exclude_tags: Optional[Sequence[str]] = None,
|
exclude_tags: Optional[Sequence[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[StandardStreamEvent]:
|
) -> AsyncIterator[StandardStreamEvent]:
|
||||||
from langchain_core.runnables import ensure_config
|
|
||||||
from langchain_core.runnables.utils import _RootEventFilter
|
|
||||||
from langchain_core.tracers.log_stream import (
|
|
||||||
LogStreamCallbackHandler,
|
|
||||||
RunLog,
|
|
||||||
_astream_log_implementation,
|
|
||||||
)
|
|
||||||
|
|
||||||
stream = LogStreamCallbackHandler(
|
stream = LogStreamCallbackHandler(
|
||||||
auto_close=False,
|
auto_close=False,
|
||||||
include_names=include_names,
|
include_names=include_names,
|
||||||
@@ -954,9 +952,6 @@ async def _astream_events_implementation_v2(
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[StandardStreamEvent]:
|
) -> AsyncIterator[StandardStreamEvent]:
|
||||||
"""Implementation of the astream events API for V2 runnables."""
|
"""Implementation of the astream events API for V2 runnables."""
|
||||||
from langchain_core.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain_core.runnables import ensure_config
|
|
||||||
|
|
||||||
event_streamer = _AstreamEventsCallbackHandler(
|
event_streamer = _AstreamEventsCallbackHandler(
|
||||||
include_names=include_names,
|
include_names=include_names,
|
||||||
include_types=include_types,
|
include_types=include_types,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import contextlib
|
|||||||
import copy
|
import copy
|
||||||
import threading
|
import threading
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from pprint import pformat
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@@ -20,10 +21,11 @@ from typing import (
|
|||||||
import jsonpatch # type: ignore[import-untyped]
|
import jsonpatch # type: ignore[import-untyped]
|
||||||
from typing_extensions import NotRequired, TypedDict, override
|
from typing_extensions import NotRequired, TypedDict, override
|
||||||
|
|
||||||
|
from langchain_core.callbacks.base import BaseCallbackManager
|
||||||
from langchain_core.load import dumps
|
from langchain_core.load import dumps
|
||||||
from langchain_core.load.load import load
|
from langchain_core.load.load import load
|
||||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||||
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
from langchain_core.runnables import RunnableConfig, ensure_config
|
||||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||||
from langchain_core.tracers.base import BaseTracer
|
from langchain_core.tracers.base import BaseTracer
|
||||||
from langchain_core.tracers.memory_stream import _MemoryStream
|
from langchain_core.tracers.memory_stream import _MemoryStream
|
||||||
@@ -32,6 +34,7 @@ if TYPE_CHECKING:
|
|||||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
from langchain_core.runnables.utils import Input, Output
|
from langchain_core.runnables.utils import Input, Output
|
||||||
from langchain_core.tracers.schemas import Run
|
from langchain_core.tracers.schemas import Run
|
||||||
|
|
||||||
@@ -131,8 +134,6 @@ class RunLogPatch:
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
from pprint import pformat
|
|
||||||
|
|
||||||
# 1:-1 to get rid of the [] around the list
|
# 1:-1 to get rid of the [] around the list
|
||||||
return f"RunLogPatch({pformat(self.ops)[1:-1]})"
|
return f"RunLogPatch({pformat(self.ops)[1:-1]})"
|
||||||
|
|
||||||
@@ -181,8 +182,6 @@ class RunLog(RunLogPatch):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
from pprint import pformat
|
|
||||||
|
|
||||||
return f"RunLog({pformat(self.state)})"
|
return f"RunLog({pformat(self.state)})"
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@@ -672,14 +671,6 @@ async def _astream_log_implementation(
|
|||||||
Yields:
|
Yields:
|
||||||
The run log patches or states, depending on the value of ``diff``.
|
The run log patches or states, depending on the value of ``diff``.
|
||||||
"""
|
"""
|
||||||
import jsonpatch
|
|
||||||
|
|
||||||
from langchain_core.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain_core.tracers.log_stream import (
|
|
||||||
RunLog,
|
|
||||||
RunLogPatch,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign the stream handler to the config
|
# Assign the stream handler to the config
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callbacks = config.get("callbacks")
|
callbacks = config.get("callbacks")
|
||||||
|
|||||||
@@ -21,8 +21,10 @@ from typing import (
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
from pydantic.v1 import Field, create_model
|
||||||
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
|
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
|
||||||
|
|
||||||
|
import langchain_core
|
||||||
from langchain_core._api import beta, deprecated
|
from langchain_core._api import beta, deprecated
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||||
from langchain_core.utils.json_schema import dereference_refs
|
from langchain_core.utils.json_schema import dereference_refs
|
||||||
@@ -220,10 +222,8 @@ def _convert_python_function_to_openai_function(
|
|||||||
Returns:
|
Returns:
|
||||||
The OpenAI function description.
|
The OpenAI function description.
|
||||||
"""
|
"""
|
||||||
from langchain_core.tools.base import create_schema_from_function
|
|
||||||
|
|
||||||
func_name = _get_python_function_name(function)
|
func_name = _get_python_function_name(function)
|
||||||
model = create_schema_from_function(
|
model = langchain_core.tools.base.create_schema_from_function(
|
||||||
func_name,
|
func_name,
|
||||||
function,
|
function,
|
||||||
filter_args=(),
|
filter_args=(),
|
||||||
@@ -264,9 +264,6 @@ def _convert_any_typed_dicts_to_pydantic(
|
|||||||
visited: dict,
|
visited: dict,
|
||||||
depth: int = 0,
|
depth: int = 0,
|
||||||
) -> type:
|
) -> type:
|
||||||
from pydantic.v1 import Field as Field_v1
|
|
||||||
from pydantic.v1 import create_model as create_model_v1
|
|
||||||
|
|
||||||
if type_ in visited:
|
if type_ in visited:
|
||||||
return visited[type_]
|
return visited[type_]
|
||||||
if depth >= _MAX_TYPED_DICT_RECURSION:
|
if depth >= _MAX_TYPED_DICT_RECURSION:
|
||||||
@@ -297,7 +294,7 @@ def _convert_any_typed_dicts_to_pydantic(
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
if arg_desc := arg_descriptions.get(arg):
|
if arg_desc := arg_descriptions.get(arg):
|
||||||
field_kwargs["description"] = arg_desc
|
field_kwargs["description"] = arg_desc
|
||||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
fields[arg] = (new_arg_type, Field(**field_kwargs))
|
||||||
else:
|
else:
|
||||||
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
||||||
arg_type, depth=depth + 1, visited=visited
|
arg_type, depth=depth + 1, visited=visited
|
||||||
@@ -305,8 +302,8 @@ def _convert_any_typed_dicts_to_pydantic(
|
|||||||
field_kwargs = {"default": ...}
|
field_kwargs = {"default": ...}
|
||||||
if arg_desc := arg_descriptions.get(arg):
|
if arg_desc := arg_descriptions.get(arg):
|
||||||
field_kwargs["description"] = arg_desc
|
field_kwargs["description"] = arg_desc
|
||||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
fields[arg] = (new_arg_type, Field(**field_kwargs))
|
||||||
model = create_model_v1(typed_dict.__name__, **fields)
|
model = create_model(typed_dict.__name__, **fields)
|
||||||
model.__doc__ = description
|
model.__doc__ = description
|
||||||
visited[typed_dict] = model
|
visited[typed_dict] = model
|
||||||
return model
|
return model
|
||||||
@@ -332,9 +329,9 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
|||||||
Returns:
|
Returns:
|
||||||
The function description.
|
The function description.
|
||||||
"""
|
"""
|
||||||
from langchain_core.tools import simple
|
is_simple_oai_tool = (
|
||||||
|
isinstance(tool, langchain_core.tools.simple.Tool) and not tool.args_schema
|
||||||
is_simple_oai_tool = isinstance(tool, simple.Tool) and not tool.args_schema
|
)
|
||||||
if tool.tool_call_schema and not is_simple_oai_tool:
|
if tool.tool_call_schema and not is_simple_oai_tool:
|
||||||
if isinstance(tool.tool_call_schema, dict):
|
if isinstance(tool.tool_call_schema, dict):
|
||||||
return _convert_json_schema_to_openai_function(
|
return _convert_json_schema_to_openai_function(
|
||||||
@@ -435,8 +432,6 @@ def convert_to_openai_function(
|
|||||||
'description' and 'parameters' keys are now optional. Only 'name' is
|
'description' and 'parameters' keys are now optional. Only 'name' is
|
||||||
required and guaranteed to be part of the output.
|
required and guaranteed to be part of the output.
|
||||||
"""
|
"""
|
||||||
from langchain_core.tools import BaseTool
|
|
||||||
|
|
||||||
# an Anthropic format tool
|
# an Anthropic format tool
|
||||||
if isinstance(function, dict) and all(
|
if isinstance(function, dict) and all(
|
||||||
k in function for k in ("name", "input_schema")
|
k in function for k in ("name", "input_schema")
|
||||||
@@ -476,7 +471,7 @@ def convert_to_openai_function(
|
|||||||
oai_function = cast(
|
oai_function = cast(
|
||||||
"dict", _convert_typed_dict_to_openai_function(cast("type", function))
|
"dict", _convert_typed_dict_to_openai_function(cast("type", function))
|
||||||
)
|
)
|
||||||
elif isinstance(function, BaseTool):
|
elif isinstance(function, langchain_core.tools.base.BaseTool):
|
||||||
oai_function = cast("dict", _format_tool_to_openai_function(function))
|
oai_function = cast("dict", _format_tool_to_openai_function(function))
|
||||||
elif callable(function):
|
elif callable(function):
|
||||||
oai_function = cast(
|
oai_function = cast(
|
||||||
@@ -582,7 +577,8 @@ def convert_to_openai_tool(
|
|||||||
|
|
||||||
Added support for OpenAI's image generation built-in tool.
|
Added support for OpenAI's image generation built-in tool.
|
||||||
"""
|
"""
|
||||||
from langchain_core.tools import Tool
|
# Import locally to prevent circular import
|
||||||
|
from langchain_core.tools import Tool # noqa: PLC0415
|
||||||
|
|
||||||
if isinstance(tool, dict):
|
if isinstance(tool, dict):
|
||||||
if tool.get("type") in _WellKnownOpenAITools:
|
if tool.get("type") in _WellKnownOpenAITools:
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Utilities for working with interactive environments."""
|
"""Utilities for working with interactive environments."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
def is_interactive_env() -> bool:
|
def is_interactive_env() -> bool:
|
||||||
"""Determine if running within IPython or Jupyter.
|
"""Determine if running within IPython or Jupyter.
|
||||||
@@ -7,6 +9,4 @@ def is_interactive_env() -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
True if running in an interactive environment, False otherwise.
|
True if running in an interactive environment, False otherwise.
|
||||||
"""
|
"""
|
||||||
import sys
|
|
||||||
|
|
||||||
return hasattr(sys, "ps2")
|
return hasattr(sys, "ps2")
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from typing import (
|
|||||||
from pydantic import ConfigDict, Field, model_validator
|
from pydantic import ConfigDict, Field, model_validator
|
||||||
from typing_extensions import Self, override
|
from typing_extensions import Self, override
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
||||||
from langchain_core.runnables.config import run_in_executor
|
from langchain_core.runnables.config import run_in_executor
|
||||||
@@ -49,7 +50,6 @@ if TYPE_CHECKING:
|
|||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
CallbackManagerForRetrieverRun,
|
CallbackManagerForRetrieverRun,
|
||||||
)
|
)
|
||||||
from langchain_core.documents import Document
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -85,9 +85,6 @@ class VectorStore(ABC):
|
|||||||
ValueError: If the number of ids does not match the number of texts.
|
ValueError: If the number of ids does not match the number of texts.
|
||||||
"""
|
"""
|
||||||
if type(self).add_documents != VectorStore.add_documents:
|
if type(self).add_documents != VectorStore.add_documents:
|
||||||
# Import document in local scope to avoid circular imports
|
|
||||||
from langchain_core.documents import Document
|
|
||||||
|
|
||||||
# This condition is triggered if the subclass has provided
|
# This condition is triggered if the subclass has provided
|
||||||
# an implementation of the upsert method.
|
# an implementation of the upsert method.
|
||||||
# The existing add_texts
|
# The existing add_texts
|
||||||
@@ -234,9 +231,6 @@ class VectorStore(ABC):
|
|||||||
# For backward compatibility
|
# For backward compatibility
|
||||||
kwargs["ids"] = ids
|
kwargs["ids"] = ids
|
||||||
if type(self).aadd_documents != VectorStore.aadd_documents:
|
if type(self).aadd_documents != VectorStore.aadd_documents:
|
||||||
# Import document in local scope to avoid circular imports
|
|
||||||
from langchain_core.documents import Document
|
|
||||||
|
|
||||||
# This condition is triggered if the subclass has provided
|
# This condition is triggered if the subclass has provided
|
||||||
# an implementation of the upsert method.
|
# an implementation of the upsert method.
|
||||||
# The existing add_texts
|
# The existing add_texts
|
||||||
|
|||||||
@@ -27,6 +27,13 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.indexing import UpsertResponse
|
from langchain_core.indexing import UpsertResponse
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
_HAS_NUMPY = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_NUMPY = False
|
||||||
|
|
||||||
|
|
||||||
class InMemoryVectorStore(VectorStore):
|
class InMemoryVectorStore(VectorStore):
|
||||||
"""In-memory vector store implementation.
|
"""In-memory vector store implementation.
|
||||||
@@ -496,14 +503,12 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
filter=filter,
|
filter=filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
if not _HAS_NUMPY:
|
||||||
import numpy as np
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"numpy must be installed to use max_marginal_relevance_search "
|
"numpy must be installed to use max_marginal_relevance_search "
|
||||||
"pip install numpy"
|
"pip install numpy"
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
|
|
||||||
mmr_chosen_indices = maximal_marginal_relevance(
|
mmr_chosen_indices = maximal_marginal_relevance(
|
||||||
np.array(embedding, dtype=np.float32),
|
np.array(embedding, dtype=np.float32),
|
||||||
|
|||||||
@@ -10,9 +10,21 @@ import logging
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
_HAS_NUMPY = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_NUMPY = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import simsimd as simd # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
_HAS_SIMSIMD = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_SIMSIMD = False
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray]
|
Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -33,14 +45,12 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
|
|||||||
ValueError: If the number of columns in X and Y are not the same.
|
ValueError: If the number of columns in X and Y are not the same.
|
||||||
ImportError: If numpy is not installed.
|
ImportError: If numpy is not installed.
|
||||||
"""
|
"""
|
||||||
try:
|
if not _HAS_NUMPY:
|
||||||
import numpy as np
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"cosine_similarity requires numpy to be installed. "
|
"cosine_similarity requires numpy to be installed. "
|
||||||
"Please install numpy with `pip install numpy`."
|
"Please install numpy with `pip install numpy`."
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
|
|
||||||
if len(x) == 0 or len(y) == 0:
|
if len(x) == 0 or len(y) == 0:
|
||||||
return np.array([[]])
|
return np.array([[]])
|
||||||
@@ -70,9 +80,7 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
|
|||||||
f"and Y has shape {y.shape}."
|
f"and Y has shape {y.shape}."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
try:
|
if not _HAS_SIMSIMD:
|
||||||
import simsimd as simd # type: ignore[import-not-found]
|
|
||||||
except ImportError:
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Unable to import simsimd, defaulting to NumPy implementation. If you want "
|
"Unable to import simsimd, defaulting to NumPy implementation. If you want "
|
||||||
"to use simsimd please install with `pip install simsimd`."
|
"to use simsimd please install with `pip install simsimd`."
|
||||||
@@ -113,14 +121,12 @@ def maximal_marginal_relevance(
|
|||||||
Raises:
|
Raises:
|
||||||
ImportError: If numpy is not installed.
|
ImportError: If numpy is not installed.
|
||||||
"""
|
"""
|
||||||
try:
|
if not _HAS_NUMPY:
|
||||||
import numpy as np
|
|
||||||
except ImportError as e:
|
|
||||||
msg = (
|
msg = (
|
||||||
"maximal_marginal_relevance requires numpy to be installed. "
|
"maximal_marginal_relevance requires numpy to be installed. "
|
||||||
"Please install numpy with `pip install numpy`."
|
"Please install numpy with `pip install numpy`."
|
||||||
)
|
)
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg)
|
||||||
|
|
||||||
if min(k, len(embedding_list)) <= 0:
|
if min(k, len(embedding_list)) <= 0:
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -114,7 +114,6 @@ ignore = [
|
|||||||
"BLE", # Blind exceptions
|
"BLE", # Blind exceptions
|
||||||
"DOC", # Docstrings (preview)
|
"DOC", # Docstrings (preview)
|
||||||
"ERA", # No commented-out code
|
"ERA", # No commented-out code
|
||||||
"PLC0415", # Imports outside top level
|
|
||||||
"PLR2004", # Comparison to magic number
|
"PLR2004", # Comparison to magic number
|
||||||
]
|
]
|
||||||
unfixable = ["PLW1510",]
|
unfixable = ["PLW1510",]
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from langchain_core.language_models.fake_chat_models import (
|
|||||||
FakeListChatModel,
|
FakeListChatModel,
|
||||||
GenericFakeChatModel,
|
GenericFakeChatModel,
|
||||||
)
|
)
|
||||||
|
from langchain_core.load import dumps
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from langchain_core.outputs.chat_result import ChatResult
|
from langchain_core.outputs.chat_result import ChatResult
|
||||||
@@ -318,8 +319,6 @@ def test_cache_with_generation_objects() -> None:
|
|||||||
cache = InMemoryCache()
|
cache = InMemoryCache()
|
||||||
|
|
||||||
# Create a simple fake chat model that we can control
|
# Create a simple fake chat model that we can control
|
||||||
from langchain_core.messages import AIMessage
|
|
||||||
|
|
||||||
class SimpleFakeChat:
|
class SimpleFakeChat:
|
||||||
"""Simple fake chat model for testing."""
|
"""Simple fake chat model for testing."""
|
||||||
|
|
||||||
@@ -332,8 +331,6 @@ def test_cache_with_generation_objects() -> None:
|
|||||||
|
|
||||||
def generate_response(self, prompt: str) -> ChatResult:
|
def generate_response(self, prompt: str) -> ChatResult:
|
||||||
"""Simulate the cache lookup and generation logic."""
|
"""Simulate the cache lookup and generation logic."""
|
||||||
from langchain_core.load import dumps
|
|
||||||
|
|
||||||
llm_string = self._get_llm_string()
|
llm_string = self._get_llm_string()
|
||||||
prompt_str = dumps([prompt])
|
prompt_str = dumps([prompt])
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from blockbuster import BlockBuster
|
|||||||
|
|
||||||
from langchain_core.caches import InMemoryCache
|
from langchain_core.caches import InMemoryCache
|
||||||
from langchain_core.language_models import GenericFakeChatModel
|
from langchain_core.language_models import GenericFakeChatModel
|
||||||
|
from langchain_core.load import dumps
|
||||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||||
|
|
||||||
|
|
||||||
@@ -229,8 +230,6 @@ class SerializableModel(GenericFakeChatModel):
|
|||||||
|
|
||||||
def test_serialization_with_rate_limiter() -> None:
|
def test_serialization_with_rate_limiter() -> None:
|
||||||
"""Test model serialization with rate limiter."""
|
"""Test model serialization with rate limiter."""
|
||||||
from langchain_core.load import dumps
|
|
||||||
|
|
||||||
model = SerializableModel(
|
model = SerializableModel(
|
||||||
messages=iter(["hello", "world", "!"]),
|
messages=iter(["hello", "world", "!"]),
|
||||||
rate_limiter=InMemoryRateLimiter(
|
rate_limiter=InMemoryRateLimiter(
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||||
|
|
||||||
from langchain_core.load import Serializable, dumpd, dumps, load
|
from langchain_core.load import Serializable, dumpd, dumps, load
|
||||||
from langchain_core.load.serializable import _is_field_useful
|
from langchain_core.load.serializable import _is_field_useful
|
||||||
@@ -62,9 +62,6 @@ def test_simple_serialization_is_serializable() -> None:
|
|||||||
|
|
||||||
def test_simple_serialization_secret() -> None:
|
def test_simple_serialization_secret() -> None:
|
||||||
"""Test handling of secrets."""
|
"""Test handling of secrets."""
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from langchain_core.load import Serializable
|
|
||||||
|
|
||||||
class Foo(Serializable):
|
class Foo(Serializable):
|
||||||
bar: int
|
bar: int
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from collections.abc import AsyncIterator, Iterator
|
from collections.abc import AsyncIterator, Iterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import pydantic
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
@@ -802,7 +803,6 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
|||||||
|
|
||||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||||
import pydantic
|
|
||||||
|
|
||||||
class Forecast(pydantic.v1.BaseModel):
|
class Forecast(pydantic.v1.BaseModel):
|
||||||
temperature: int
|
temperature: int
|
||||||
@@ -836,9 +836,8 @@ def test_parse_with_different_pydantic_2_v1() -> None:
|
|||||||
|
|
||||||
def test_parse_with_different_pydantic_2_proper() -> None:
|
def test_parse_with_different_pydantic_2_proper() -> None:
|
||||||
"""Test with pydantic.BaseModel from pydantic 2."""
|
"""Test with pydantic.BaseModel from pydantic 2."""
|
||||||
import pydantic
|
|
||||||
|
|
||||||
class Forecast(pydantic.BaseModel):
|
class Forecast(BaseModel):
|
||||||
temperature: int
|
temperature: int
|
||||||
forecast: str
|
forecast: str
|
||||||
|
|
||||||
|
|||||||
@@ -189,8 +189,6 @@ def test_pydantic_output_parser_type_inference() -> None:
|
|||||||
|
|
||||||
def test_format_instructions_preserves_language() -> None:
|
def test_format_instructions_preserves_language() -> None:
|
||||||
"""Test format instructions does not attempt to encode into ascii."""
|
"""Test format instructions does not attempt to encode into ascii."""
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
description = (
|
description = (
|
||||||
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
|
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
|
||||||
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001
|
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Test functionality related to prompts."""
|
"""Test functionality related to prompts."""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
@@ -32,8 +33,6 @@ def test_from_file_encoding() -> None:
|
|||||||
input_variables = ["foo"]
|
input_variables = ["foo"]
|
||||||
|
|
||||||
# First write to a file using CP-1252 encoding.
|
# First write to a file using CP-1252 encoding.
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
|
|
||||||
with NamedTemporaryFile(delete=True, mode="w", encoding="cp1252") as f:
|
with NamedTemporaryFile(delete=True, mode="w", encoding="cp1252") as f:
|
||||||
f.write(template)
|
f.write(template)
|
||||||
f.flush()
|
f.flush()
|
||||||
@@ -434,11 +433,9 @@ Will it get confused{ }?
|
|||||||
assert prompt == expected_prompt
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("jinja2")
|
|
||||||
def test_basic_sandboxing_with_jinja2() -> None:
|
def test_basic_sandboxing_with_jinja2() -> None:
|
||||||
"""Test basic sandboxing with jinja2."""
|
"""Test basic sandboxing with jinja2."""
|
||||||
import jinja2
|
jinja2 = pytest.importorskip("jinja2")
|
||||||
|
|
||||||
template = " {{''.__class__.__bases__[0] }} " # malicious code
|
template = " {{''.__class__.__bases__[0] }} " # malicious code
|
||||||
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||||
with pytest.raises(jinja2.exceptions.SecurityError):
|
with pytest.raises(jinja2.exceptions.SecurityError):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
from threading import Lock
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -80,7 +81,6 @@ def test_batch_concurrency() -> None:
|
|||||||
"""Test that batch respects max_concurrency."""
|
"""Test that batch respects max_concurrency."""
|
||||||
running_tasks = 0
|
running_tasks = 0
|
||||||
max_running_tasks = 0
|
max_running_tasks = 0
|
||||||
from threading import Lock
|
|
||||||
|
|
||||||
lock = Lock()
|
lock = Lock()
|
||||||
|
|
||||||
@@ -112,7 +112,6 @@ def test_batch_as_completed_concurrency() -> None:
|
|||||||
"""Test that batch_as_completed respects max_concurrency."""
|
"""Test that batch_as_completed respects max_concurrency."""
|
||||||
running_tasks = 0
|
running_tasks = 0
|
||||||
max_running_tasks = 0
|
max_running_tasks = 0
|
||||||
from threading import Lock
|
|
||||||
|
|
||||||
lock = Lock()
|
lock = Lock()
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Any, Callable, Optional, Union
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, RootModel
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@@ -20,6 +20,11 @@ from langchain_core.runnables.config import RunnableConfig
|
|||||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||||
from langchain_core.tracers import Run
|
from langchain_core.tracers import Run
|
||||||
|
from langchain_core.tracers.root_listeners import (
|
||||||
|
AsyncListener,
|
||||||
|
AsyncRootListenersTracer,
|
||||||
|
RootListenersTracer,
|
||||||
|
)
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
||||||
from tests.unit_tests.pydantic_utils import _schema
|
from tests.unit_tests.pydantic_utils import _schema
|
||||||
|
|
||||||
@@ -499,8 +504,6 @@ def test_get_output_schema() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_get_input_schema_input_messages() -> None:
|
def test_get_input_schema_input_messages() -> None:
|
||||||
from pydantic import RootModel
|
|
||||||
|
|
||||||
runnable_with_message_history_input = RootModel[Sequence[BaseMessage]]
|
runnable_with_message_history_input = RootModel[Sequence[BaseMessage]]
|
||||||
|
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
@@ -776,8 +779,6 @@ def test_ignore_session_id() -> None:
|
|||||||
|
|
||||||
|
|
||||||
class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
|
class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
|
||||||
from langchain_core.tracers.root_listeners import AsyncListener
|
|
||||||
|
|
||||||
def with_listeners(
|
def with_listeners(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -791,8 +792,6 @@ class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
|
|||||||
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
|
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
|
||||||
] = None,
|
] = None,
|
||||||
) -> Runnable[Input, Output]:
|
) -> Runnable[Input, Output]:
|
||||||
from langchain_core.tracers.root_listeners import RootListenersTracer
|
|
||||||
|
|
||||||
def create_tracer(config: RunnableConfig) -> RunnableConfig:
|
def create_tracer(config: RunnableConfig) -> RunnableConfig:
|
||||||
tracer = RootListenersTracer(
|
tracer = RootListenersTracer(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -817,8 +816,6 @@ class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
|
|||||||
on_end: Optional[AsyncListener] = None,
|
on_end: Optional[AsyncListener] = None,
|
||||||
on_error: Optional[AsyncListener] = None,
|
on_error: Optional[AsyncListener] = None,
|
||||||
) -> Runnable[Input, Output]:
|
) -> Runnable[Input, Output]:
|
||||||
from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
|
|
||||||
|
|
||||||
def create_tracer(config: RunnableConfig) -> RunnableConfig:
|
def create_tracer(config: RunnableConfig) -> RunnableConfig:
|
||||||
tracer = AsyncRootListenersTracer(
|
tracer = AsyncRootListenersTracer(
|
||||||
config=config,
|
config=config,
|
||||||
|
|||||||
@@ -40,6 +40,6 @@ def test_all_imports() -> None:
|
|||||||
def test_imports_for_specific_funcs() -> None:
|
def test_imports_for_specific_funcs() -> None:
|
||||||
"""Test that a few specific imports in more internal namespaces."""
|
"""Test that a few specific imports in more internal namespaces."""
|
||||||
# create_model implementation has been moved to langchain_core.utils.pydantic
|
# create_model implementation has been moved to langchain_core.utils.pydantic
|
||||||
from langchain_core.runnables.utils import ( # type: ignore[attr-defined] # noqa: F401
|
from langchain_core.runnables.utils import ( # type: ignore[attr-defined] # noqa: F401,PLC0415
|
||||||
create_model,
|
create_model,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
||||||
@@ -17,6 +18,7 @@ from pytest_mock import MockerFixture
|
|||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
from typing_extensions import TypedDict, override
|
from typing_extensions import TypedDict, override
|
||||||
|
|
||||||
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
CallbackManagerForRetrieverRun,
|
CallbackManagerForRetrieverRun,
|
||||||
@@ -29,6 +31,7 @@ from langchain_core.language_models import (
|
|||||||
FakeListLLM,
|
FakeListLLM,
|
||||||
FakeStreamingListLLM,
|
FakeStreamingListLLM,
|
||||||
)
|
)
|
||||||
|
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||||
from langchain_core.load import dumpd, dumps
|
from langchain_core.load import dumpd, dumps
|
||||||
from langchain_core.load.load import loads
|
from langchain_core.load.load import loads
|
||||||
from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage
|
||||||
@@ -5516,9 +5519,6 @@ async def test_passthrough_atransform_with_dicts() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_listeners() -> None:
|
def test_listeners() -> None:
|
||||||
from langchain_core.runnables import RunnableLambda
|
|
||||||
from langchain_core.tracers.schemas import Run
|
|
||||||
|
|
||||||
def fake_chain(inputs: dict) -> dict:
|
def fake_chain(inputs: dict) -> dict:
|
||||||
return {**inputs, "key": "extra"}
|
return {**inputs, "key": "extra"}
|
||||||
|
|
||||||
@@ -5546,9 +5546,6 @@ def test_listeners() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_listeners_async() -> None:
|
async def test_listeners_async() -> None:
|
||||||
from langchain_core.runnables import RunnableLambda
|
|
||||||
from langchain_core.tracers.schemas import Run
|
|
||||||
|
|
||||||
def fake_chain(inputs: dict) -> dict:
|
def fake_chain(inputs: dict) -> dict:
|
||||||
return {**inputs, "key": "extra"}
|
return {**inputs, "key": "extra"}
|
||||||
|
|
||||||
@@ -5578,12 +5575,6 @@ async def test_listeners_async() -> None:
|
|||||||
|
|
||||||
def test_closing_iterator_doesnt_raise_error() -> None:
|
def test_closing_iterator_doesnt_raise_error() -> None:
|
||||||
"""Test that closing an iterator calls on_chain_end rather than on_chain_error."""
|
"""Test that closing an iterator calls on_chain_end rather than on_chain_error."""
|
||||||
import time
|
|
||||||
|
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
|
||||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
|
||||||
from langchain_core.output_parsers import StrOutputParser
|
|
||||||
|
|
||||||
on_chain_error_triggered = False
|
on_chain_error_triggered = False
|
||||||
on_chain_end_triggered = False
|
on_chain_end_triggered = False
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from pydantic import BaseModel
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel
|
from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel
|
||||||
@@ -2548,7 +2549,6 @@ async def test_cancel_astream_events() -> None:
|
|||||||
|
|
||||||
async def test_custom_event() -> None:
|
async def test_custom_event() -> None:
|
||||||
"""Test adhoc event."""
|
"""Test adhoc event."""
|
||||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
|
||||||
|
|
||||||
# Ignoring type due to RunnableLamdba being dynamic when it comes to being
|
# Ignoring type due to RunnableLamdba being dynamic when it comes to being
|
||||||
# applied as a decorator to async functions.
|
# applied as a decorator to async functions.
|
||||||
@@ -2625,7 +2625,6 @@ async def test_custom_event() -> None:
|
|||||||
|
|
||||||
async def test_custom_event_nested() -> None:
|
async def test_custom_event_nested() -> None:
|
||||||
"""Test adhoc event in a nested chain."""
|
"""Test adhoc event in a nested chain."""
|
||||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
|
||||||
|
|
||||||
# Ignoring type due to RunnableLamdba being dynamic when it comes to being
|
# Ignoring type due to RunnableLamdba being dynamic when it comes to being
|
||||||
# applied as a decorator to async functions.
|
# applied as a decorator to async functions.
|
||||||
@@ -2736,7 +2735,6 @@ async def test_custom_event_root_dispatch() -> None:
|
|||||||
# This just tests that nothing breaks on the path.
|
# This just tests that nothing breaks on the path.
|
||||||
# It shouldn't do anything at the moment, since the tracer isn't configured
|
# It shouldn't do anything at the moment, since the tracer isn't configured
|
||||||
# to handle adhoc events.
|
# to handle adhoc events.
|
||||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
|
||||||
|
|
||||||
# Expected behavior is that the event cannot be dispatched
|
# Expected behavior is that the event cannot be dispatched
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
@@ -2750,8 +2748,6 @@ IS_GTE_3_11 = sys.version_info >= (3, 11)
|
|||||||
@pytest.mark.skipif(not IS_GTE_3_11, reason="Requires Python >=3.11")
|
@pytest.mark.skipif(not IS_GTE_3_11, reason="Requires Python >=3.11")
|
||||||
async def test_custom_event_root_dispatch_with_in_tool() -> None:
|
async def test_custom_event_root_dispatch_with_in_tool() -> None:
|
||||||
"""Test adhoc event in a nested chain."""
|
"""Test adhoc event in a nested chain."""
|
||||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def foo(x: int) -> int:
|
async def foo(x: int) -> int:
|
||||||
|
|||||||
@@ -1,18 +1,17 @@
|
|||||||
|
import langchain_core
|
||||||
|
from langchain_core.callbacks.manager import _get_debug
|
||||||
from langchain_core.globals import get_debug, set_debug
|
from langchain_core.globals import get_debug, set_debug
|
||||||
|
|
||||||
|
|
||||||
def test_debug_is_settable_via_setter() -> None:
|
def test_debug_is_settable_via_setter() -> None:
|
||||||
from langchain_core import globals as globals_
|
previous_value = langchain_core.globals._debug
|
||||||
from langchain_core.callbacks.manager import _get_debug
|
|
||||||
|
|
||||||
previous_value = globals_._debug
|
|
||||||
previous_fn_reading = _get_debug()
|
previous_fn_reading = _get_debug()
|
||||||
assert previous_value == previous_fn_reading
|
assert previous_value == previous_fn_reading
|
||||||
|
|
||||||
# Flip the value of the flag.
|
# Flip the value of the flag.
|
||||||
set_debug(not previous_value)
|
set_debug(not previous_value)
|
||||||
|
|
||||||
new_value = globals_._debug
|
new_value = langchain_core.globals._debug
|
||||||
new_fn_reading = _get_debug()
|
new_fn_reading = _get_debug()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
from pydantic.v1 import ValidationError as ValidationErrorV1
|
||||||
from typing_extensions import TypedDict, override
|
from typing_extensions import TypedDict, override
|
||||||
@@ -1852,7 +1852,6 @@ def generate_models() -> list[Any]:
|
|||||||
|
|
||||||
def generate_backwards_compatible_v1() -> list[Any]:
|
def generate_backwards_compatible_v1() -> list[Any]:
|
||||||
"""Generate a model with pydantic 2 from the v1 namespace."""
|
"""Generate a model with pydantic 2 from the v1 namespace."""
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
class FooV1Namespace(BaseModelV1):
|
class FooV1Namespace(BaseModelV1):
|
||||||
a: int
|
a: int
|
||||||
@@ -1920,8 +1919,6 @@ def test_args_schema_explicitly_typed() -> None:
|
|||||||
is a pydantic 1 model!
|
is a pydantic 1 model!
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Check with whatever pydantic model is passed in and not via v1 namespace
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
a: int
|
a: int
|
||||||
@@ -1964,7 +1961,6 @@ def test_args_schema_explicitly_typed() -> None:
|
|||||||
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
||||||
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
|
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
|
||||||
"""This should test that one can type the args schema as a pydantic model."""
|
"""This should test that one can type the args schema as a pydantic model."""
|
||||||
from langchain_core.tools import StructuredTool
|
|
||||||
|
|
||||||
def foo(a: int, b: str) -> str:
|
def foo(a: int, b: str) -> str:
|
||||||
"""Hahaha."""
|
"""Hahaha."""
|
||||||
@@ -2063,16 +2059,13 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
|||||||
A = TypeVar("A")
|
A = TypeVar("A")
|
||||||
|
|
||||||
if use_v1_namespace:
|
if use_v1_namespace:
|
||||||
from pydantic.v1 import BaseModel as BaseModel1
|
|
||||||
|
|
||||||
class ModelA(BaseModel1, Generic[A], extra="allow"):
|
class ModelA(BaseModelV1, Generic[A], extra="allow"):
|
||||||
a: A
|
a: A
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from pydantic import BaseModel as BaseModel2
|
|
||||||
from pydantic import ConfigDict
|
|
||||||
|
|
||||||
class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef]
|
class ModelA(BaseModel, Generic[A]): # type: ignore[no-redef]
|
||||||
a: A
|
a: A
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
||||||
|
|
||||||
@@ -2208,12 +2201,8 @@ def test_create_retriever_tool() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
class Foo(BaseModel):
|
||||||
from pydantic import Field as FieldV2
|
x: list[int] = Field(
|
||||||
from pydantic import ValidationError as ValidationErrorV2
|
|
||||||
|
|
||||||
class Foo(BaseModelV2):
|
|
||||||
x: list[int] = FieldV2(
|
|
||||||
description="List of integers", min_length=10, max_length=15
|
description="List of integers", min_length=10, max_length=15
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2240,7 +2229,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert foo.invoke({"x": [0] * 10})
|
assert foo.invoke({"x": [0] * 10})
|
||||||
with pytest.raises(ValidationErrorV2):
|
with pytest.raises(ValidationError):
|
||||||
foo.invoke({"x": [0] * 9})
|
foo.invoke({"x": [0] * 9})
|
||||||
|
|
||||||
|
|
||||||
@@ -2576,8 +2565,6 @@ def test_title_property_preserved() -> None:
|
|||||||
|
|
||||||
https://github.com/langchain-ai/langchain/issues/30456
|
https://github.com/langchain-ai/langchain/issues/30456
|
||||||
"""
|
"""
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
schema_to_be_extracted = {
|
schema_to_be_extracted = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [],
|
"required": [],
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
_create_subset_model_v2,
|
_create_subset_model_v2,
|
||||||
@@ -16,8 +17,6 @@ from langchain_core.utils.pydantic import (
|
|||||||
|
|
||||||
|
|
||||||
def test_pre_init_decorator() -> None:
|
def test_pre_init_decorator() -> None:
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
x: int = 5
|
x: int = 5
|
||||||
y: int
|
y: int
|
||||||
@@ -35,8 +34,6 @@ def test_pre_init_decorator() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_pre_init_decorator_with_more_defaults() -> None:
|
def test_pre_init_decorator_with_more_defaults() -> None:
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
a: int = 1
|
a: int = 1
|
||||||
b: Optional[int] = None
|
b: Optional[int] = None
|
||||||
@@ -56,8 +53,6 @@ def test_pre_init_decorator_with_more_defaults() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_with_aliases() -> None:
|
def test_with_aliases() -> None:
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
x: int = Field(default=1, alias="y")
|
x: int = Field(default=1, alias="y")
|
||||||
z: int
|
z: int
|
||||||
@@ -92,19 +87,14 @@ def test_with_aliases() -> None:
|
|||||||
|
|
||||||
def test_is_basemodel_subclass() -> None:
|
def test_is_basemodel_subclass() -> None:
|
||||||
"""Test pydantic."""
|
"""Test pydantic."""
|
||||||
from pydantic import BaseModel as BaseModelV2
|
assert is_basemodel_subclass(BaseModel)
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
assert is_basemodel_subclass(BaseModelV2)
|
|
||||||
assert is_basemodel_subclass(BaseModelV1)
|
assert is_basemodel_subclass(BaseModelV1)
|
||||||
|
|
||||||
|
|
||||||
def test_is_basemodel_instance() -> None:
|
def test_is_basemodel_instance() -> None:
|
||||||
"""Test pydantic."""
|
"""Test pydantic."""
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
class Foo(BaseModelV2):
|
class Foo(BaseModel):
|
||||||
x: int
|
x: int
|
||||||
|
|
||||||
assert is_basemodel_instance(Foo(x=5))
|
assert is_basemodel_instance(Foo(x=5))
|
||||||
@@ -117,11 +107,9 @@ def test_is_basemodel_instance() -> None:
|
|||||||
|
|
||||||
def test_with_field_metadata() -> None:
|
def test_with_field_metadata() -> None:
|
||||||
"""Test pydantic with field metadata."""
|
"""Test pydantic with field metadata."""
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic import Field as FieldV2
|
|
||||||
|
|
||||||
class Foo(BaseModelV2):
|
class Foo(BaseModel):
|
||||||
x: list[int] = FieldV2(
|
x: list[int] = Field(
|
||||||
description="List of integers", min_length=10, max_length=15
|
description="List of integers", min_length=10, max_length=15
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -144,8 +132,6 @@ def test_with_field_metadata() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_fields_pydantic_v2_proper() -> None:
|
def test_fields_pydantic_v2_proper() -> None:
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
x: int
|
x: int
|
||||||
|
|
||||||
@@ -154,9 +140,7 @@ def test_fields_pydantic_v2_proper() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_fields_pydantic_v1_from_2() -> None:
|
def test_fields_pydantic_v1_from_2() -> None:
|
||||||
from pydantic.v1 import BaseModel
|
class Foo(BaseModelV1):
|
||||||
|
|
||||||
class Foo(BaseModel):
|
|
||||||
x: int
|
x: int
|
||||||
|
|
||||||
fields = get_fields(Foo)
|
fields = get_fields(Foo)
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ from typing import Any, Callable, Optional, Union
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import SecretStr
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
||||||
|
from pydantic.v1 import Field as PydanticV1Field
|
||||||
|
|
||||||
from langchain_core import utils
|
from langchain_core import utils
|
||||||
from langchain_core.outputs import GenerationChunk
|
from langchain_core.outputs import GenerationChunk
|
||||||
@@ -212,13 +214,10 @@ def test_guard_import_failure(
|
|||||||
|
|
||||||
|
|
||||||
def test_get_pydantic_field_names_v1_in_2() -> None:
|
def test_get_pydantic_field_names_v1_in_2() -> None:
|
||||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
|
||||||
from pydantic.v1 import Field
|
|
||||||
|
|
||||||
class PydanticV1Model(PydanticV1BaseModel):
|
class PydanticV1Model(PydanticV1BaseModel):
|
||||||
field1: str
|
field1: str
|
||||||
field2: int
|
field2: int
|
||||||
alias_field: int = Field(alias="aliased_field")
|
alias_field: int = PydanticV1Field(alias="aliased_field")
|
||||||
|
|
||||||
result = get_pydantic_field_names(PydanticV1Model)
|
result = get_pydantic_field_names(PydanticV1Model)
|
||||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||||
@@ -226,8 +225,6 @@ def test_get_pydantic_field_names_v1_in_2() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_get_pydantic_field_names_v2_in_2() -> None:
|
def test_get_pydantic_field_names_v2_in_2() -> None:
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class PydanticModel(BaseModel):
|
class PydanticModel(BaseModel):
|
||||||
field1: str
|
field1: str
|
||||||
field2: int
|
field2: int
|
||||||
@@ -341,8 +338,6 @@ def test_secret_from_env_with_custom_error_message(
|
|||||||
def test_using_secret_from_env_as_default_factory(
|
def test_using_secret_from_env_as_default_factory(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
) -> None:
|
) -> None:
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))
|
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user