chore(langchain): drop Python 3.9 to prep for v1 (#32704)

Python 3.9 reaches end of life in October 2025, so we're dropping it for the v1 alpha release.
Sydney Runkle
2025-08-26 19:16:42 -04:00
committed by GitHub
parent 3d08b6bd11
commit c6c7fce6c9
17 changed files with 1168 additions and 1628 deletions
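
Most of the diff below is mechanical modernization enabled by the new 3.10 floor. For orientation (this sketch is not part of the diff and the names are purely illustrative), these are the idioms the changed files move to:

# PEP 604 (3.10+): `X | None` in annotations instead of `Optional[X]`.
# `Callable` comes from collections.abc rather than typing, and `TypeAlias`
# can be imported from typing directly, so typing_extensions is no longer needed.
from collections.abc import Callable
from typing import TypeAlias

Handler: TypeAlias = Callable[[str], str]

def greet(name: str | None = None, handler: Handler | None = None) -> str:
    text = f"hello {name or 'world'}"
    return handler(text) if handler else text

print(greet())                 # hello world
print(greet("v1", str.upper))  # HELLO V1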

View File

@@ -132,6 +132,8 @@ def _get_configs_for_single_dir(job: str, dir_: str) -> List[Dict[str, str]]:
elif dir_ == "libs/langchain" and job == "extended-tests": elif dir_ == "libs/langchain" and job == "extended-tests":
py_versions = ["3.9", "3.13"] py_versions = ["3.9", "3.13"]
elif dir_ == "libs/langchain_v1":
py_versions = ["3.10", "3.13"]
elif dir_ == ".": elif dir_ == ".":
# unable to install with 3.13 because tokenizers doesn't support 3.13 yet # unable to install with 3.13 because tokenizers doesn't support 3.13 yet
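
For context, a minimal sketch of how a helper like `_get_configs_for_single_dir` can expand these per-directory Python floors into CI matrix entries. This is a simplified reconstruction, not the actual workflow script: the dictionary keys and the fallback branch are assumptions.

from typing import Dict, List

def _get_configs_for_single_dir(job: str, dir_: str) -> List[Dict[str, str]]:
    # Pick the Python versions to test for one package directory (illustrative subset).
    if dir_ == "libs/langchain" and job == "extended-tests":
        py_versions = ["3.9", "3.13"]
    elif dir_ == "libs/langchain_v1":
        # langchain v1 drops 3.9, so its floor is 3.10.
        py_versions = ["3.10", "3.13"]
    else:
        py_versions = ["3.9", "3.13"]  # assumed fallback; the real script has more branches
    # One matrix entry per version (key names are assumptions).
    return [{"working-directory": dir_, "python-version": v} for v in py_versions]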

View File

@@ -32,9 +32,4 @@ def format_document_xml(doc: Document) -> str:
     if doc.metadata:
         metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()]
         metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>"
-    return (
-        f"<document>{id_str}"
-        f"<content>{doc.page_content}</content>"
-        f"{metadata_str}"
-        f"</document>"
-    )
+    return f"<document>{id_str}<content>{doc.page_content}</content>{metadata_str}</document>"

View File

@@ -12,10 +12,10 @@ particularly for summarization chains and other document processing workflows.
 from __future__ import annotations

 import inspect
-from typing import TYPE_CHECKING, Callable, Union
+from typing import TYPE_CHECKING, Union

 if TYPE_CHECKING:
-    from collections.abc import Awaitable
+    from collections.abc import Awaitable, Callable

     from langchain_core.messages import MessageLikeRepresentation
     from langgraph.runtime import Runtime
@@ -92,9 +92,7 @@ async def aresolve_prompt(
         str,
         None,
         Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
-        Callable[
-            [StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]
-        ],
+        Callable[[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]],
     ],
     state: StateT,
     runtime: Runtime[ContextT],

View File

@@ -2,11 +2,10 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, Union
+from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, Union

 from langgraph.graph._node import StateNode
 from pydantic import BaseModel
-from typing_extensions import TypeAlias

 if TYPE_CHECKING:
     from dataclasses import Field

View File

@@ -7,10 +7,8 @@ from typing import (
     TYPE_CHECKING,
     Annotated,
     Any,
-    Callable,
     Generic,
     Literal,
-    Optional,
     Union,
     cast,
 )
@@ -26,6 +24,8 @@ from langchain._internal._utils import RunnableCallable
 from langchain.chat_models import init_chat_model

 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.documents import Document
     from langchain_core.language_models.chat_models import BaseChatModel
@@ -154,7 +154,7 @@ class _MapReduceExtractor(Generic[ContextT]):
             StateNode,
         ] = "default_reducer",
         context_schema: type[ContextT] | None = None,
-        response_format: Optional[type[BaseModel]] = None,
+        response_format: type[BaseModel] | None = None,
     ) -> None:
         """Initialize the MapReduceExtractor.
@@ -190,9 +190,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         if isinstance(model, str):
             model = init_chat_model(model)
-        self.model = (
-            model.with_structured_output(response_format) if response_format else model
-        )
+        self.model = model.with_structured_output(response_format) if response_format else model
         self.map_prompt = map_prompt
         self.reduce_prompt = reduce_prompt
         self.reduce = reduce
@@ -342,9 +340,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         config: RunnableConfig,
     ) -> dict[str, list[ExtractionResult]]:
         prompt = await self._aget_map_prompt(state, runtime)
-        response = cast(
-            "AIMessage", await self.model.ainvoke(prompt, config=config)
-        )
+        response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
         result = response if self.response_format else response.text()
         extraction_result: ExtractionResult = {
             "indexes": state["indexes"],
@@ -375,9 +371,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         config: RunnableConfig,
     ) -> MapReduceNodeUpdate:
         prompt = await self._aget_reduce_prompt(state, runtime)
-        response = cast(
-            "AIMessage", await self.model.ainvoke(prompt, config=config)
-        )
+        response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
         result = response if self.response_format else response.text()
         return {"result": result}
@@ -411,7 +405,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         # Add-conditional edges doesn't explicitly type Send
         builder.add_conditional_edges(
             "continue_to_map",
-            self.continue_to_map,  # type: ignore[arg-type]
+            self.continue_to_map,
             ["map_process"],
         )
@@ -422,7 +416,7 @@ class _MapReduceExtractor(Generic[ContextT]):
builder.add_edge("map_process", "reduce_process") builder.add_edge("map_process", "reduce_process")
builder.add_edge("reduce_process", END) builder.add_edge("reduce_process", END)
else: else:
reduce_node = cast("StateNode", self.reduce) reduce_node = self.reduce
# The type is ignored here. Requires parameterizing with generics. # The type is ignored here. Requires parameterizing with generics.
builder.add_node("reduce_process", reduce_node) # type: ignore[arg-type] builder.add_node("reduce_process", reduce_node) # type: ignore[arg-type]
builder.add_edge("map_process", "reduce_process") builder.add_edge("map_process", "reduce_process")
@@ -450,7 +444,7 @@ def create_map_reduce_chain(
         StateNode,
     ] = "default_reducer",
     context_schema: type[ContextT] | None = None,
-    response_format: Optional[type[BaseModel]] = None,
+    response_format: type[BaseModel] | None = None,
 ) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
     """Create a map-reduce document extraction chain.

View File

@@ -5,9 +5,7 @@ from __future__ import annotations
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     Generic,
-    Optional,
     Union,
     cast,
 )
@@ -24,6 +22,8 @@ from langchain._internal._utils import RunnableCallable
 from langchain.chat_models import init_chat_model

 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.language_models.chat_models import BaseChatModel

     # Used for type checking, but IDEs may not recognize it inside the cast.
@@ -149,7 +149,7 @@ class _Extractor(Generic[ContextT]):
             ],
         ] = None,
         context_schema: type[ContextT] | None = None,
-        response_format: Optional[type[BaseModel]] = None,
+        response_format: type[BaseModel] | None = None,
     ) -> None:
         """Initialize the Extractor.
@@ -173,9 +173,7 @@ class _Extractor(Generic[ContextT]):
         if isinstance(model, str):
             model = init_chat_model(model)
-        self.model = (
-            model.with_structured_output(response_format) if response_format else model
-        )
+        self.model = model.with_structured_output(response_format) if response_format else model
         self.initial_prompt = prompt
         self.refine_prompt = refine_prompt
         self.context_schema = context_schema
@@ -188,9 +186,7 @@ class _Extractor(Generic[ContextT]):
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_INIT_PROMPT
-            if self.response_format
-            else DEFAULT_INIT_PROMPT
+            DEFAULT_STRUCTURED_INIT_PROMPT if self.response_format else DEFAULT_INIT_PROMPT
         )
         return resolve_prompt(
@@ -209,9 +205,7 @@ class _Extractor(Generic[ContextT]):
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_INIT_PROMPT
-            if self.response_format
-            else DEFAULT_INIT_PROMPT
+            DEFAULT_STRUCTURED_INIT_PROMPT if self.response_format else DEFAULT_INIT_PROMPT
         )
         return await aresolve_prompt(
@@ -246,9 +240,7 @@ class _Extractor(Generic[ContextT]):
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_REFINE_PROMPT
-            if self.response_format
-            else DEFAULT_REFINE_PROMPT
+            DEFAULT_STRUCTURED_REFINE_PROMPT if self.response_format else DEFAULT_REFINE_PROMPT
         )
         return resolve_prompt(
@@ -283,9 +275,7 @@ class _Extractor(Generic[ContextT]):
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_REFINE_PROMPT
-            if self.response_format
-            else DEFAULT_REFINE_PROMPT
+            DEFAULT_STRUCTURED_REFINE_PROMPT if self.response_format else DEFAULT_REFINE_PROMPT
         )
         return await aresolve_prompt(
@@ -340,16 +330,12 @@ class _Extractor(Generic[ContextT]):
if "result" not in state or state["result"] == "": if "result" not in state or state["result"] == "":
# Initial processing # Initial processing
prompt = await self._aget_initial_prompt(state, runtime) prompt = await self._aget_initial_prompt(state, runtime)
response = cast( response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text() result = response if self.response_format else response.text()
return {"result": result} return {"result": result}
# Refinement # Refinement
prompt = await self._aget_refine_prompt(state, runtime) prompt = await self._aget_refine_prompt(state, runtime)
response = cast( response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text() result = response if self.response_format else response.text()
return {"result": result} return {"result": result}
@@ -388,7 +374,7 @@ def create_stuff_documents_chain(
         Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
     ] = None,
     context_schema: type[ContextT] | None = None,
-    response_format: Optional[type[BaseModel]] = None,
+    response_format: type[BaseModel] | None = None,
 ) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
     """Create a stuff documents chain for processing documents.

View File

@@ -5,9 +5,8 @@ from importlib import util
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     Literal,
-    Optional,
+    TypeAlias,
     Union,
     cast,
     overload,
@@ -16,10 +15,10 @@ from typing import (
 from langchain_core.language_models import BaseChatModel, LanguageModelInput
 from langchain_core.messages import AnyMessage, BaseMessage
 from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
-from typing_extensions import TypeAlias, override
+from typing_extensions import override

 if TYPE_CHECKING:
-    from collections.abc import AsyncIterator, Iterator, Sequence
+    from collections.abc import AsyncIterator, Callable, Iterator, Sequence

     from langchain_core.runnables.schema import StreamEvent
     from langchain_core.tools import BaseTool
@@ -31,9 +30,9 @@ if TYPE_CHECKING:
 def init_chat_model(
     model: str,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     configurable_fields: Literal[None] = None,
-    config_prefix: Optional[str] = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> BaseChatModel: ...
@@ -42,20 +41,20 @@ def init_chat_model(
 def init_chat_model(
     model: Literal[None] = None,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     configurable_fields: Literal[None] = None,
-    config_prefix: Optional[str] = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> _ConfigurableModel: ...


 @overload
 def init_chat_model(
-    model: Optional[str] = None,
+    model: str | None = None,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
-    config_prefix: Optional[str] = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> _ConfigurableModel: ...
@@ -64,13 +63,11 @@ def init_chat_model(
 # name to the supported list in the docstring below. Do *not* change the order of the
 # existing providers.
 def init_chat_model(
-    model: Optional[str] = None,
+    model: str | None = None,
     *,
-    model_provider: Optional[str] = None,
-    configurable_fields: Optional[
-        Union[Literal["any"], list[str], tuple[str, ...]]
-    ] = None,
-    config_prefix: Optional[str] = None,
+    model_provider: str | None = None,
+    configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] | None = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> Union[BaseChatModel, _ConfigurableModel]:
     """Initialize a ChatModel from the model name and provider.
@@ -326,7 +323,7 @@ def init_chat_model(
 def _init_chat_model_helper(
     model: str,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     **kwargs: Any,
 ) -> BaseChatModel:
     model, model_provider = _parse_model(model, model_provider)
@@ -446,9 +443,7 @@ def _init_chat_model_helper(
         return ChatPerplexity(model=model, **kwargs)
     supported = ", ".join(_SUPPORTED_PROVIDERS)
-    msg = (
-        f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
-    )
+    msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
     raise ValueError(msg)
@@ -476,7 +471,7 @@ _SUPPORTED_PROVIDERS = {
 }


-def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
+def _attempt_infer_model_provider(model_name: str) -> str | None:
     if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
         return "openai"
     if model_name.startswith("claude"):
@@ -500,31 +495,24 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
     return None


-def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
-    if (
-        not model_provider
-        and ":" in model
-        and model.split(":")[0] in _SUPPORTED_PROVIDERS
-    ):
+def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
+    if not model_provider and ":" in model and model.split(":")[0] in _SUPPORTED_PROVIDERS:
         model_provider = model.split(":")[0]
         model = ":".join(model.split(":")[1:])
     model_provider = model_provider or _attempt_infer_model_provider(model)
     if not model_provider:
         msg = (
-            f"Unable to infer model provider for {model=}, please specify "
-            f"model_provider directly."
+            f"Unable to infer model provider for {model=}, please specify model_provider directly."
         )
         raise ValueError(msg)
     model_provider = model_provider.replace("-", "_").lower()
     return model, model_provider


-def _check_pkg(pkg: str, *, pkg_kebab: Optional[str] = None) -> None:
+def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
     if not util.find_spec(pkg):
         pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
-        msg = (
-            f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
-        )
+        msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
         raise ImportError(msg)
@@ -539,16 +527,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def __init__(
         self,
         *,
-        default_config: Optional[dict] = None,
+        default_config: dict | None = None,
         configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
         config_prefix: str = "",
         queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
     ) -> None:
         self._default_config: dict = default_config or {}
         self._configurable_fields: Union[Literal["any"], list[str]] = (
-            configurable_fields
-            if configurable_fields == "any"
-            else list(configurable_fields)
+            configurable_fields if configurable_fields == "any" else list(configurable_fields)
         )
         self._config_prefix = (
             config_prefix + "_"
@@ -589,14 +575,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
msg += "." msg += "."
raise AttributeError(msg) raise AttributeError(msg)
def _model(self, config: Optional[RunnableConfig] = None) -> Runnable: def _model(self, config: RunnableConfig | None = None) -> Runnable:
params = {**self._default_config, **self._model_params(config)} params = {**self._default_config, **self._model_params(config)}
model = _init_chat_model_helper(**params) model = _init_chat_model_helper(**params)
for name, args, kwargs in self._queued_declarative_operations: for name, args, kwargs in self._queued_declarative_operations:
model = getattr(model, name)(*args, **kwargs) model = getattr(model, name)(*args, **kwargs)
return model return model
def _model_params(self, config: Optional[RunnableConfig]) -> dict: def _model_params(self, config: RunnableConfig | None) -> dict:
config = ensure_config(config) config = ensure_config(config)
model_params = { model_params = {
_remove_prefix(k, self._config_prefix): v _remove_prefix(k, self._config_prefix): v
@@ -604,14 +590,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             if k.startswith(self._config_prefix)
         }
         if self._configurable_fields != "any":
-            model_params = {
-                k: v for k, v in model_params.items() if k in self._configurable_fields
-            }
+            model_params = {k: v for k, v in model_params.items() if k in self._configurable_fields}
         return model_params

     def with_config(
         self,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         **kwargs: Any,
     ) -> _ConfigurableModel:
         """Bind config to a Runnable, returning a new Runnable."""
@@ -662,7 +646,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def invoke(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         **kwargs: Any,
     ) -> Any:
         return self._model(config).invoke(input, config=config, **kwargs)
@@ -671,7 +655,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def ainvoke(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         **kwargs: Any,
     ) -> Any:
         return await self._model(config).ainvoke(input, config=config, **kwargs)
@@ -680,8 +664,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def stream(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> Iterator[Any]:
         yield from self._model(config).stream(input, config=config, **kwargs)
@@ -689,8 +673,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def astream(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> AsyncIterator[Any]:
         async for x in self._model(config).astream(input, config=config, **kwargs):
             yield x
@@ -698,10 +682,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch(
         self,
         inputs: list[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
-        **kwargs: Optional[Any],
+        **kwargs: Any | None,
     ) -> list[Any]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
@@ -726,10 +710,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch(
         self,
         inputs: list[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
-        **kwargs: Optional[Any],
+        **kwargs: Any | None,
     ) -> list[Any]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
@@ -754,7 +738,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -783,7 +767,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -817,8 +801,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def transform(
         self,
         input: Iterator[LanguageModelInput],
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> Iterator[Any]:
         yield from self._model(config).transform(input, config=config, **kwargs)
@@ -826,8 +810,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def atransform(
         self,
         input: AsyncIterator[LanguageModelInput],
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> AsyncIterator[Any]:
         async for x in self._model(config).atransform(input, config=config, **kwargs):
             yield x
@@ -836,16 +820,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def astream_log(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         diff: Literal[True] = True,
         with_streamed_output_list: bool = True,
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[RunLogPatch]: ...
@@ -853,16 +837,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def astream_log(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         diff: Literal[False],
         with_streamed_output_list: bool = True,
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[RunLog]: ...
@@ -870,16 +854,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def astream_log(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         diff: bool = True,
         with_streamed_output_list: bool = True,
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
         async for x in self._model(config).astream_log(  # type: ignore[call-overload, misc]
@@ -901,15 +885,15 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def astream_events(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         version: Literal["v1", "v2"] = "v2",
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[StreamEvent]:
         async for x in self._model(config).astream_events(
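
The `_parse_model` helper above is what lets `init_chat_model` accept a `provider:model` string or infer the provider from a well-known model-name prefix. A short usage sketch (model names are illustrative; each provider still needs its integration package installed, e.g. langchain-openai):

from langchain.chat_models import init_chat_model

# Explicit provider prefix: the string is split on the first ":".
gpt = init_chat_model("openai:gpt-4o", temperature=0)

# Provider inferred from the model name (e.g. "claude..." -> anthropic).
claude = init_chat_model("claude-3-5-sonnet-latest")

# No model given: returns a configurable placeholder resolved per call.
configurable = init_chat_model(configurable_fields="any", config_prefix="foo")
response = configurable.invoke(
    "hello",
    config={"configurable": {"foo_model": "openai:gpt-4o-mini"}},
)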

View File

@@ -1,6 +1,6 @@
 import functools
 from importlib import util
-from typing import Any, Optional, Union
+from typing import Any, Union

 from langchain_core.embeddings import Embeddings
 from langchain_core.runnables import Runnable
@@ -19,9 +19,7 @@ _SUPPORTED_PROVIDERS = {
 def _get_provider_list() -> str:
     """Get formatted list of providers and their packages."""
-    return "\n".join(
-        f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
-    )
+    return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items())


 def _parse_model_string(model_name: str) -> tuple[str, str]:
@@ -82,7 +80,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
 def _infer_model_and_provider(
     model: str,
     *,
-    provider: Optional[str] = None,
+    provider: str | None = None,
 ) -> tuple[str, str]:
     if not model.strip():
         msg = "Model name cannot be empty"
@@ -117,17 +115,14 @@ def _infer_model_and_provider(
 def _check_pkg(pkg: str) -> None:
     """Check if a package is installed."""
     if not util.find_spec(pkg):
-        msg = (
-            f"Could not import {pkg} python package. "
-            f"Please install it with `pip install {pkg}`"
-        )
+        msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
         raise ImportError(msg)


 def init_embeddings(
     model: str,
     *,
-    provider: Optional[str] = None,
+    provider: str | None = None,
     **kwargs: Any,
 ) -> Union[Embeddings, Runnable[Any, list[float]]]:
     """Initialize an embeddings model from a model name and optional provider.
@@ -182,9 +177,7 @@ def init_embeddings(
""" """
if not model: if not model:
providers = _SUPPORTED_PROVIDERS.keys() providers = _SUPPORTED_PROVIDERS.keys()
msg = ( msg = f"Must specify model name. Supported providers are: {', '.join(providers)}"
f"Must specify model name. Supported providers are: {', '.join(providers)}"
)
raise ValueError(msg) raise ValueError(msg)
provider, model_name = _infer_model_and_provider(model, provider=provider) provider, model_name = _infer_model_and_provider(model, provider=provider)

View File

@@ -13,7 +13,7 @@ import hashlib
 import json
 import uuid
 import warnings
-from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast
+from typing import TYPE_CHECKING, Literal, Union, cast

 from langchain_core.embeddings import Embeddings
 from langchain_core.utils.iter import batch_iterate
@@ -21,7 +21,7 @@ from langchain_core.utils.iter import batch_iterate
 from langchain.storage.encoder_backed import EncoderBackedStore

 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable, Sequence

     from langchain_core.stores import BaseStore, ByteStore
@@ -147,8 +147,8 @@ class CacheBackedEmbeddings(Embeddings):
         underlying_embeddings: Embeddings,
         document_embedding_store: BaseStore[str, list[float]],
         *,
-        batch_size: Optional[int] = None,
-        query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
+        batch_size: int | None = None,
+        query_embedding_store: BaseStore[str, list[float]] | None = None,
     ) -> None:
         """Initialize the embedder.
@@ -181,17 +181,15 @@ class CacheBackedEmbeddings(Embeddings):
         vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
             texts,
         )
-        all_missing_indices: list[int] = [
-            i for i, vector in enumerate(vectors) if vector is None
-        ]
+        all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]

         for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
             missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
             self.document_embedding_store.mset(
-                list(zip(missing_texts, missing_vectors)),
+                list(zip(missing_texts, missing_vectors, strict=False)),
             )
-            for index, updated_vector in zip(missing_indices, missing_vectors):
+            for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
                 vectors[index] = updated_vector

         return cast(
@@ -212,12 +210,8 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             A list of embeddings for the given texts.
         """
-        vectors: list[
-            Union[list[float], None]
-        ] = await self.document_embedding_store.amget(texts)
-        all_missing_indices: list[int] = [
-            i for i, vector in enumerate(vectors) if vector is None
-        ]
+        vectors: list[Union[list[float], None]] = await self.document_embedding_store.amget(texts)
+        all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]

         # batch_iterate supports None batch_size which returns all elements at once
         # as a single batch.
@@ -227,9 +221,9 @@ class CacheBackedEmbeddings(Embeddings):
                 missing_texts,
             )
             await self.document_embedding_store.amset(
-                list(zip(missing_texts, missing_vectors)),
+                list(zip(missing_texts, missing_vectors, strict=False)),
             )
-            for index, updated_vector in zip(missing_indices, missing_vectors):
+            for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
                 vectors[index] = updated_vector

         return cast(
@@ -290,7 +284,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_cache: ByteStore,
         *,
         namespace: str = "",
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
         query_embedding_cache: Union[bool, ByteStore] = False,
         key_encoder: Union[
             Callable[[str], str],
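
A note on the `strict=False` arguments added above: `zip()` gained the `strict` keyword in Python 3.10, and lint rules such as ruff's B905 flag calls that leave it implicit. Passing `strict=False` keeps the long-standing behavior of stopping at the shortest iterable, while `strict=True` would raise on a length mismatch:

texts = ["a", "b", "c"]
vectors = [[1.0], [2.0]]  # one vector missing

# strict=False: the extra text is silently dropped (same as plain zip()).
assert list(zip(texts, vectors, strict=False)) == [("a", [1.0]), ("b", [2.0])]

# strict=True (3.10+): mismatched lengths raise ValueError instead of truncating.
try:
    list(zip(texts, vectors, strict=True))
except ValueError as exc:
    print(exc)  # e.g. "zip() argument 2 is shorter than argument 1"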

View File

@@ -1,8 +1,6 @@
-from collections.abc import AsyncIterator, Iterator, Sequence
+from collections.abc import AsyncIterator, Callable, Iterator, Sequence
 from typing import (
     Any,
-    Callable,
-    Optional,
     TypeVar,
     Union,
 )
@@ -62,37 +60,29 @@ class EncoderBackedStore(BaseStore[K, V]):
         self.value_serializer = value_serializer
         self.value_deserializer = value_deserializer

-    def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
+    def mget(self, keys: Sequence[K]) -> list[V | None]:
         """Get the values associated with the given keys."""
         encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
         values = self.store.mget(encoded_keys)
-        return [
-            self.value_deserializer(value) if value is not None else value
-            for value in values
-        ]
+        return [self.value_deserializer(value) if value is not None else value for value in values]

-    async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
+    async def amget(self, keys: Sequence[K]) -> list[V | None]:
         """Get the values associated with the given keys."""
         encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
         values = await self.store.amget(encoded_keys)
-        return [
-            self.value_deserializer(value) if value is not None else value
-            for value in values
-        ]
+        return [self.value_deserializer(value) if value is not None else value for value in values]

     def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
         """Set the values for the given keys."""
         encoded_pairs = [
-            (self.key_encoder(key), self.value_serializer(value))
-            for key, value in key_value_pairs
+            (self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
         ]
         self.store.mset(encoded_pairs)

     async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
         """Set the values for the given keys."""
         encoded_pairs = [
-            (self.key_encoder(key), self.value_serializer(value))
-            for key, value in key_value_pairs
+            (self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
         ]
         await self.store.amset(encoded_pairs)
@@ -109,7 +99,7 @@ class EncoderBackedStore(BaseStore[K, V]):
     def yield_keys(
         self,
         *,
-        prefix: Optional[str] = None,
+        prefix: str | None = None,
     ) -> Union[Iterator[K], Iterator[str]]:
         """Get an iterator over keys that match the given prefix."""
         # For the time being this does not return K, but str
@@ -119,7 +109,7 @@ class EncoderBackedStore(BaseStore[K, V]):
     async def ayield_keys(
         self,
         *,
-        prefix: Optional[str] = None,
+        prefix: str | None = None,
     ) -> Union[AsyncIterator[K], AsyncIterator[str]]:
         """Get an iterator over keys that match the given prefix."""
         # For the time being this does not return K, but str
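
For orientation, `EncoderBackedStore` is what lets `CacheBackedEmbeddings` keep vectors in a plain `ByteStore`: keys are encoded to namespaced strings and values are serialized to bytes. A rough sketch of that wiring under assumed hashing and serialization choices (not the library's exact implementation):

import json
import uuid

from langchain_core.stores import InMemoryByteStore

def key_encoder(text: str, namespace: str = "test_namespace") -> str:
    # Derive a stable, namespaced key from the text (assumed scheme).
    return namespace + str(uuid.uuid5(uuid.NAMESPACE_DNS, text))

def value_serializer(vector: list[float]) -> bytes:
    return json.dumps(vector).encode("utf-8")

def value_deserializer(raw: bytes) -> list[float]:
    return json.loads(raw)

store = InMemoryByteStore()
store.mset([(key_encoder("hello"), value_serializer([0.1, 0.2]))])
cached = store.mget([key_encoder("hello")])[0]
assert cached is not None and value_deserializer(cached) == [0.1, 0.2]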

View File

@@ -5,7 +5,7 @@ build-backend = "pdm.backend"
 [project]
 authors = []
 license = { text = "MIT" }
-requires-python = ">=3.9, <4.0"
+requires-python = ">=3.10"
 dependencies = [
     "langchain-core<1.0.0,>=0.3.66",
     "langchain-text-splitters<1.0.0,>=0.3.8",
@@ -46,23 +46,22 @@ repository = "https://github.com/langchain-ai/langchain"
 [dependency-groups]
 test = [
     "pytest<9,>=8",
-    "pytest-cov<5.0.0,>=4.0.0",
-    "pytest-watcher<1.0.0,>=0.2.6",
-    "pytest-asyncio<1.0.0,>=0.23.2",
-    "pytest-socket<1.0.0,>=0.6.0",
-    "syrupy<5.0.0,>=4.0.2",
-    "pytest-xdist<4.0.0,>=3.6.1",
-    "blockbuster<1.6,>=1.5.18",
+    "pytest-cov>=4.0.0",
+    "pytest-watcher>=0.2.6",
+    "pytest-asyncio>=0.23.2",
+    "pytest-socket>=0.6.0",
+    "syrupy>=4.0.2",
+    "pytest-xdist>=3.6.1",
+    "blockbuster>=1.5.18",
     "langchain-tests",
-    "langchain-core",
     "langchain-text-splitters",
     "langchain-openai",
     "toml>=0.10.2",
 ]
 codespell = ["codespell<3.0.0,>=2.2.0"]
 lint = [
-    "ruff<0.13,>=0.12.2",
-    "mypy<1.16,>=1.15",
+    "ruff>=0.12.2",
+    "mypy>=1.15",
 ]
 typing = [
     "types-toml>=0.10.8.20240310",
@@ -70,11 +69,10 @@ typing = [
 test_integration = [
     "vcrpy>=7.0",
-    "urllib3<2; python_version < \"3.10\"",
-    "wrapt<2.0.0,>=1.15.0",
-    "python-dotenv<2.0.0,>=1.0.0",
-    "cassio<1.0.0,>=0.1.0",
-    "langchainhub<1.0.0,>=0.1.16",
+    "wrapt>=1.15.0",
+    "python-dotenv>=1.0.0",
+    "cassio>=0.1.0",
+    "langchainhub>=0.1.16",
     "langchain-core",
     "langchain-text-splitters",
 ]
@@ -86,8 +84,9 @@ langchain-text-splitters = { path = "../text-splitters", editable = true }
 langchain-openai = { path = "../partners/openai", editable = true }

 [tool.ruff]
-target-version = "py39"
+target-version = "py310"
 exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
+line-length = 100

 [tool.mypy]
 strict = "True"

View File

@@ -31,9 +31,7 @@ async def test_init_chat_model_chain() -> None:
     chain = prompt | model_with_config
     output = chain.invoke({"input": "bar"})
     assert isinstance(output, AIMessage)
-    events = [
-        event async for event in chain.astream_events({"input": "bar"}, version="v2")
-    ]
+    events = [event async for event in chain.astream_events({"input": "bar"}, version="v2")]
     assert events

View File

@@ -1,5 +1,5 @@
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 from unittest import mock

 import pytest
@@ -38,7 +38,7 @@ def test_all_imports() -> None:
("mixtral-8x7b-32768", "groq"), ("mixtral-8x7b-32768", "groq"),
], ],
) )
def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None: def test_init_chat_model(model_name: str, model_provider: str | None) -> None:
llm1: BaseChatModel = init_chat_model( llm1: BaseChatModel = init_chat_model(
model_name, model_name,
model_provider=model_provider, model_provider=model_provider,
@@ -222,7 +222,7 @@ def test_configurable_with_default() -> None:
config={"configurable": {"my_model_model": "claude-3-sonnet-20240229"}} config={"configurable": {"my_model_model": "claude-3-sonnet-20240229"}}
) )
""" # noqa: E501 """
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar") model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
for method in ( for method in (
"invoke", "invoke",

View File

@@ -53,9 +53,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
     )


-def pytest_collection_modifyitems(
-    config: pytest.Config, items: Sequence[pytest.Function]
-) -> None:
+def pytest_collection_modifyitems(config: pytest.Config, items: Sequence[pytest.Function]) -> None:
     """Add implementations for handling custom markers.

     At the moment, this adds support for a custom `requires` marker.

View File

@@ -113,9 +113,7 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None
     vectors = await cache_embeddings.aembed_documents(texts)
     expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
     assert vectors == expected_vectors
-    keys = [
-        key async for key in cache_embeddings.document_embedding_store.ayield_keys()
-    ]
+    keys = [key async for key in cache_embeddings.document_embedding_store.ayield_keys()]
     assert len(keys) == 4
     # UUID is expected to be the same for the same text
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
@@ -128,10 +126,7 @@ async def test_aembed_documents_batch(
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"] texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
with contextlib.suppress(ValueError): with contextlib.suppress(ValueError):
await cache_embeddings_batch.aembed_documents(texts) await cache_embeddings_batch.aembed_documents(texts)
keys = [ keys = [key async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()]
key
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
]
# only the first batch of three embeddings should exist # only the first batch of three embeddings should exist
assert len(keys) == 3 assert len(keys) == 3
# UUID is expected to be the same for the same text # UUID is expected to be the same for the same text

View File

@@ -13,9 +13,7 @@ def test_import_all() -> None:
library_code = PKG_ROOT / "langchain" library_code = PKG_ROOT / "langchain"
for path in library_code.rglob("*.py"): for path in library_code.rglob("*.py"):
# Calculate the relative path to the module # Calculate the relative path to the module
module_name = ( module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
)
if module_name.endswith("__init__"): if module_name.endswith("__init__"):
# Without init # Without init
module_name = module_name.rsplit(".", 1)[0] module_name = module_name.rsplit(".", 1)[0]
@@ -39,9 +37,7 @@ def test_import_all_using_dir() -> None:
library_code = PKG_ROOT / "langchain" library_code = PKG_ROOT / "langchain"
for path in library_code.rglob("*.py"): for path in library_code.rglob("*.py"):
# Calculate the relative path to the module # Calculate the relative path to the module
module_name = ( module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
)
if module_name.endswith("__init__"): if module_name.endswith("__init__"):
# Without init # Without init
module_name = module_name.rsplit(".", 1)[0] module_name = module_name.rsplit(".", 1)[0]

View File

libs/langchain_v1/uv.lock (generated, 2415 lines changed)

File diff suppressed because it is too large.