chore(langchain): drop Python 3.9 to prep for v1 (#32704)

Python 3.9 reaches end of life in October 2025, so we're dropping it for the v1 alpha release.
Author: Sydney Runkle
Date: 2025-08-26 19:16:42 -04:00 (committed by GitHub)
Parent: 3d08b6bd11
Commit: c6c7fce6c9
17 changed files with 1168 additions and 1628 deletions
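
Most of the diff below is mechanical cleanup enabled by the new 3.10 floor: `Optional[X]`/`Union[X, None]` become `X | None`, `Callable` is imported from `collections.abc` instead of `typing`, `TypeAlias` comes from `typing` rather than `typing_extensions`, and `zip()` gains an explicit `strict=` argument. A minimal sketch of those idioms (illustrative names only, not code from this commit):

```python
# Illustrative only -- the Python 3.10+ idioms adopted throughout this commit.
from collections.abc import Callable  # preferred over typing.Callable
from typing import TypeAlias  # available in the stdlib typing module since 3.10

ConfigMap: TypeAlias = dict[str, str]


def lookup(table: ConfigMap, key: str, default: str | None = None) -> str | None:
    # PEP 604 unions replace Optional[str] and Union[str, None]
    return table.get(key, default)


def rename(keys: list[str], transform: Callable[[str], str]) -> list[str]:
    return [transform(key) for key in keys]


def pair(keys: list[str], values: list[float]) -> list[tuple[str, float]]:
    # zip(..., strict=...) is new in 3.10; the diff passes strict=False explicitly
    return list(zip(keys, values, strict=False))
```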

View File

@@ -32,9 +32,4 @@ def format_document_xml(doc: Document) -> str:
if doc.metadata:
metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()]
metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>"
return (
f"<document>{id_str}"
f"<content>{doc.page_content}</content>"
f"{metadata_str}"
f"</document>"
)
return f"<document>{id_str}<content>{doc.page_content}</content>{metadata_str}</document>"

View File

@@ -12,10 +12,10 @@ particularly for summarization chains and other document processing workflows.
from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Callable, Union
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from collections.abc import Awaitable
from collections.abc import Awaitable, Callable
from langchain_core.messages import MessageLikeRepresentation
from langgraph.runtime import Runtime
@@ -92,9 +92,7 @@ async def aresolve_prompt(
str,
None,
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
Callable[
[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]
],
Callable[[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]],
],
state: StateT,
runtime: Runtime[ContextT],

View File

@@ -2,11 +2,10 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, Union
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, Union
from langgraph.graph._node import StateNode
from pydantic import BaseModel
from typing_extensions import TypeAlias
if TYPE_CHECKING:
from dataclasses import Field

View File

@@ -7,10 +7,8 @@ from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Generic,
Literal,
Optional,
Union,
cast,
)
@@ -26,6 +24,8 @@ from langchain._internal._utils import RunnableCallable
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
@@ -154,7 +154,7 @@ class _MapReduceExtractor(Generic[ContextT]):
StateNode,
] = "default_reducer",
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
response_format: type[BaseModel] | None = None,
) -> None:
"""Initialize the MapReduceExtractor.
@@ -190,9 +190,7 @@ class _MapReduceExtractor(Generic[ContextT]):
if isinstance(model, str):
model = init_chat_model(model)
self.model = (
model.with_structured_output(response_format) if response_format else model
)
self.model = model.with_structured_output(response_format) if response_format else model
self.map_prompt = map_prompt
self.reduce_prompt = reduce_prompt
self.reduce = reduce
@@ -342,9 +340,7 @@ class _MapReduceExtractor(Generic[ContextT]):
config: RunnableConfig,
) -> dict[str, list[ExtractionResult]]:
prompt = await self._aget_map_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
result = response if self.response_format else response.text()
extraction_result: ExtractionResult = {
"indexes": state["indexes"],
@@ -375,9 +371,7 @@ class _MapReduceExtractor(Generic[ContextT]):
config: RunnableConfig,
) -> MapReduceNodeUpdate:
prompt = await self._aget_reduce_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
@@ -411,7 +405,7 @@ class _MapReduceExtractor(Generic[ContextT]):
# Add-conditional edges doesn't explicitly type Send
builder.add_conditional_edges(
"continue_to_map",
self.continue_to_map, # type: ignore[arg-type]
self.continue_to_map,
["map_process"],
)
@@ -422,7 +416,7 @@ class _MapReduceExtractor(Generic[ContextT]):
builder.add_edge("map_process", "reduce_process")
builder.add_edge("reduce_process", END)
else:
reduce_node = cast("StateNode", self.reduce)
reduce_node = self.reduce
# The type is ignored here. Requires parameterizing with generics.
builder.add_node("reduce_process", reduce_node) # type: ignore[arg-type]
builder.add_edge("map_process", "reduce_process")
@@ -450,7 +444,7 @@ def create_map_reduce_chain(
StateNode,
] = "default_reducer",
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
response_format: type[BaseModel] | None = None,
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
"""Create a map-reduce document extraction chain.

View File

@@ -5,9 +5,7 @@ from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
Union,
cast,
)
@@ -24,6 +22,8 @@ from langchain._internal._utils import RunnableCallable
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.language_models.chat_models import BaseChatModel
# Used for type checking, but IDEs may not recognize it inside the cast.
@@ -149,7 +149,7 @@ class _Extractor(Generic[ContextT]):
],
] = None,
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
response_format: type[BaseModel] | None = None,
) -> None:
"""Initialize the Extractor.
@@ -173,9 +173,7 @@ class _Extractor(Generic[ContextT]):
if isinstance(model, str):
model = init_chat_model(model)
self.model = (
model.with_structured_output(response_format) if response_format else model
)
self.model = model.with_structured_output(response_format) if response_format else model
self.initial_prompt = prompt
self.refine_prompt = refine_prompt
self.context_schema = context_schema
@@ -188,9 +186,7 @@ class _Extractor(Generic[ContextT]):
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_INIT_PROMPT
if self.response_format
else DEFAULT_INIT_PROMPT
DEFAULT_STRUCTURED_INIT_PROMPT if self.response_format else DEFAULT_INIT_PROMPT
)
return resolve_prompt(
@@ -209,9 +205,7 @@ class _Extractor(Generic[ContextT]):
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_INIT_PROMPT
if self.response_format
else DEFAULT_INIT_PROMPT
DEFAULT_STRUCTURED_INIT_PROMPT if self.response_format else DEFAULT_INIT_PROMPT
)
return await aresolve_prompt(
@@ -246,9 +240,7 @@ class _Extractor(Generic[ContextT]):
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_REFINE_PROMPT
if self.response_format
else DEFAULT_REFINE_PROMPT
DEFAULT_STRUCTURED_REFINE_PROMPT if self.response_format else DEFAULT_REFINE_PROMPT
)
return resolve_prompt(
@@ -283,9 +275,7 @@ class _Extractor(Generic[ContextT]):
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_REFINE_PROMPT
if self.response_format
else DEFAULT_REFINE_PROMPT
DEFAULT_STRUCTURED_REFINE_PROMPT if self.response_format else DEFAULT_REFINE_PROMPT
)
return await aresolve_prompt(
@@ -340,16 +330,12 @@ class _Extractor(Generic[ContextT]):
if "result" not in state or state["result"] == "":
# Initial processing
prompt = await self._aget_initial_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
# Refinement
prompt = await self._aget_refine_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
@@ -388,7 +374,7 @@ def create_stuff_documents_chain(
Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
] = None,
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
response_format: type[BaseModel] | None = None,
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
"""Create a stuff documents chain for processing documents.

View File

@@ -5,9 +5,8 @@ from importlib import util
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
TypeAlias,
Union,
cast,
overload,
@@ -16,10 +15,10 @@ from typing import (
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from typing_extensions import TypeAlias, override
from typing_extensions import override
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
@@ -31,9 +30,9 @@ if TYPE_CHECKING:
def init_chat_model(
model: str,
*,
model_provider: Optional[str] = None,
model_provider: str | None = None,
configurable_fields: Literal[None] = None,
config_prefix: Optional[str] = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> BaseChatModel: ...
@@ -42,20 +41,20 @@ def init_chat_model(
def init_chat_model(
model: Literal[None] = None,
*,
model_provider: Optional[str] = None,
model_provider: str | None = None,
configurable_fields: Literal[None] = None,
config_prefix: Optional[str] = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
@overload
def init_chat_model(
model: Optional[str] = None,
model: str | None = None,
*,
model_provider: Optional[str] = None,
model_provider: str | None = None,
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
config_prefix: Optional[str] = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
@@ -64,13 +63,11 @@ def init_chat_model(
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
def init_chat_model(
model: Optional[str] = None,
model: str | None = None,
*,
model_provider: Optional[str] = None,
configurable_fields: Optional[
Union[Literal["any"], list[str], tuple[str, ...]]
] = None,
config_prefix: Optional[str] = None,
model_provider: str | None = None,
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] | None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> Union[BaseChatModel, _ConfigurableModel]:
"""Initialize a ChatModel from the model name and provider.
@@ -326,7 +323,7 @@ def init_chat_model(
def _init_chat_model_helper(
model: str,
*,
model_provider: Optional[str] = None,
model_provider: str | None = None,
**kwargs: Any,
) -> BaseChatModel:
model, model_provider = _parse_model(model, model_provider)
@@ -446,9 +443,7 @@ def _init_chat_model_helper(
return ChatPerplexity(model=model, **kwargs)
supported = ", ".join(_SUPPORTED_PROVIDERS)
msg = (
f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
)
msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
raise ValueError(msg)
@@ -476,7 +471,7 @@ _SUPPORTED_PROVIDERS = {
}
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
def _attempt_infer_model_provider(model_name: str) -> str | None:
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
return "openai"
if model_name.startswith("claude"):
@@ -500,31 +495,24 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
return None
def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
if (
not model_provider
and ":" in model
and model.split(":")[0] in _SUPPORTED_PROVIDERS
):
def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
if not model_provider and ":" in model and model.split(":")[0] in _SUPPORTED_PROVIDERS:
model_provider = model.split(":")[0]
model = ":".join(model.split(":")[1:])
model_provider = model_provider or _attempt_infer_model_provider(model)
if not model_provider:
msg = (
f"Unable to infer model provider for {model=}, please specify "
f"model_provider directly."
f"Unable to infer model provider for {model=}, please specify model_provider directly."
)
raise ValueError(msg)
model_provider = model_provider.replace("-", "_").lower()
return model, model_provider
def _check_pkg(pkg: str, *, pkg_kebab: Optional[str] = None) -> None:
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
if not util.find_spec(pkg):
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
msg = (
f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
)
msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
raise ImportError(msg)
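
The `_parse_model` rewrite above is behavior-preserving: a recognized `provider:model` prefix sets the provider, otherwise it must be supplied or inferred. A standalone sketch of that parsing logic (illustrative, with the provider set abbreviated; not the library's private helper itself):

```python
# Abbreviated stand-in for langchain's _SUPPORTED_PROVIDERS mapping.
SUPPORTED_PROVIDERS = {"openai", "anthropic", "groq"}


def parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
    # A "provider:model" prefix wins when no provider was given explicitly.
    if not model_provider and ":" in model and model.split(":")[0] in SUPPORTED_PROVIDERS:
        model_provider = model.split(":")[0]
        model = ":".join(model.split(":")[1:])
    # (The real helper also falls back to inferring the provider from the
    # model name, e.g. "gpt-4o" -> "openai"; omitted here for brevity.)
    if not model_provider:
        msg = f"Unable to infer model provider for {model=}, please specify model_provider."
        raise ValueError(msg)
    return model, model_provider.replace("-", "_").lower()


assert parse_model("openai:gpt-4o", None) == ("gpt-4o", "openai")
assert parse_model("claude-3-sonnet", "anthropic") == ("claude-3-sonnet", "anthropic")
```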
@@ -539,16 +527,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def __init__(
self,
*,
default_config: Optional[dict] = None,
default_config: dict | None = None,
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
config_prefix: str = "",
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
) -> None:
self._default_config: dict = default_config or {}
self._configurable_fields: Union[Literal["any"], list[str]] = (
configurable_fields
if configurable_fields == "any"
else list(configurable_fields)
configurable_fields if configurable_fields == "any" else list(configurable_fields)
)
self._config_prefix = (
config_prefix + "_"
@@ -589,14 +575,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
msg += "."
raise AttributeError(msg)
def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
def _model(self, config: RunnableConfig | None = None) -> Runnable:
params = {**self._default_config, **self._model_params(config)}
model = _init_chat_model_helper(**params)
for name, args, kwargs in self._queued_declarative_operations:
model = getattr(model, name)(*args, **kwargs)
return model
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
def _model_params(self, config: RunnableConfig | None) -> dict:
config = ensure_config(config)
model_params = {
_remove_prefix(k, self._config_prefix): v
@@ -604,14 +590,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
if k.startswith(self._config_prefix)
}
if self._configurable_fields != "any":
model_params = {
k: v for k, v in model_params.items() if k in self._configurable_fields
}
model_params = {k: v for k, v in model_params.items() if k in self._configurable_fields}
return model_params
def with_config(
self,
config: Optional[RunnableConfig] = None,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> _ConfigurableModel:
"""Bind config to a Runnable, returning a new Runnable."""
@@ -662,7 +646,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
return self._model(config).invoke(input, config=config, **kwargs)
@@ -671,7 +655,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
return await self._model(config).ainvoke(input, config=config, **kwargs)
@@ -680,8 +664,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
yield from self._model(config).stream(input, config=config, **kwargs)
@@ -689,8 +673,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
async for x in self._model(config).astream(input, config=config, **kwargs):
yield x
@@ -698,10 +682,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def batch(
self,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
**kwargs: Any | None,
) -> list[Any]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
@@ -726,10 +710,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
async def abatch(
self,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
**kwargs: Any | None,
) -> list[Any]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
@@ -754,7 +738,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def batch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
@@ -783,7 +767,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
async def abatch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
@@ -817,8 +801,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def transform(
self,
input: Iterator[LanguageModelInput],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
yield from self._model(config).transform(input, config=config, **kwargs)
@@ -826,8 +810,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
async def atransform(
self,
input: AsyncIterator[LanguageModelInput],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
async for x in self._model(config).atransform(input, config=config, **kwargs):
yield x
@@ -836,16 +820,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
config: RunnableConfig | None = None,
*,
diff: Literal[True] = True,
with_streamed_output_list: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch]: ...
@@ -853,16 +837,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
config: RunnableConfig | None = None,
*,
diff: Literal[False],
with_streamed_output_list: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLog]: ...
@@ -870,16 +854,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
async def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
config: RunnableConfig | None = None,
*,
diff: bool = True,
with_streamed_output_list: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
@@ -901,15 +885,15 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
async def astream_events(
self,
input: Any,
config: Optional[RunnableConfig] = None,
config: RunnableConfig | None = None,
*,
version: Literal["v1", "v2"] = "v2",
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
async for x in self._model(config).astream_events(

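For orientation, the public entry point whose overloads change above is used roughly like this (a hedged sketch; the model names are examples and the matching `langchain-*` provider packages are assumed to be installed):

```python
from langchain.chat_models import init_chat_model

# Fixed model up front; the provider can be explicit or a "provider:" prefix.
gpt = init_chat_model("openai:gpt-4o")  # needs langchain-openai installed
claude = init_chat_model("claude-3-sonnet-20240229", model_provider="anthropic")

# No model up front: returns a configurable wrapper; the model is chosen per call
# via config keys built from config_prefix (here "foo_" + "model").
configurable = init_chat_model(configurable_fields="any", config_prefix="foo")
reply = configurable.invoke(
    "hello",
    config={"configurable": {"foo_model": "openai:gpt-4o"}},
)
```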
View File

@@ -1,6 +1,6 @@
import functools
from importlib import util
from typing import Any, Optional, Union
from typing import Any, Union
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import Runnable
@@ -19,9 +19,7 @@ _SUPPORTED_PROVIDERS = {
def _get_provider_list() -> str:
"""Get formatted list of providers and their packages."""
return "\n".join(
f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
)
return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items())
def _parse_model_string(model_name: str) -> tuple[str, str]:
@@ -82,7 +80,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
def _infer_model_and_provider(
model: str,
*,
provider: Optional[str] = None,
provider: str | None = None,
) -> tuple[str, str]:
if not model.strip():
msg = "Model name cannot be empty"
@@ -117,17 +115,14 @@ def _infer_model_and_provider(
def _check_pkg(pkg: str) -> None:
"""Check if a package is installed."""
if not util.find_spec(pkg):
msg = (
f"Could not import {pkg} python package. "
f"Please install it with `pip install {pkg}`"
)
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
raise ImportError(msg)
def init_embeddings(
model: str,
*,
provider: Optional[str] = None,
provider: str | None = None,
**kwargs: Any,
) -> Union[Embeddings, Runnable[Any, list[float]]]:
"""Initialize an embeddings model from a model name and optional provider.
@@ -182,9 +177,7 @@ def init_embeddings(
"""
if not model:
providers = _SUPPORTED_PROVIDERS.keys()
msg = (
f"Must specify model name. Supported providers are: {', '.join(providers)}"
)
msg = f"Must specify model name. Supported providers are: {', '.join(providers)}"
raise ValueError(msg)
provider, model_name = _infer_model_and_provider(model, provider=provider)

View File

@@ -13,7 +13,7 @@ import hashlib
import json
import uuid
import warnings
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast
from typing import TYPE_CHECKING, Literal, Union, cast
from langchain_core.embeddings import Embeddings
from langchain_core.utils.iter import batch_iterate
@@ -21,7 +21,7 @@ from langchain_core.utils.iter import batch_iterate
from langchain.storage.encoder_backed import EncoderBackedStore
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from langchain_core.stores import BaseStore, ByteStore
@@ -147,8 +147,8 @@ class CacheBackedEmbeddings(Embeddings):
underlying_embeddings: Embeddings,
document_embedding_store: BaseStore[str, list[float]],
*,
batch_size: Optional[int] = None,
query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
batch_size: int | None = None,
query_embedding_store: BaseStore[str, list[float]] | None = None,
) -> None:
"""Initialize the embedder.
@@ -181,17 +181,15 @@ class CacheBackedEmbeddings(Embeddings):
vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
texts,
)
all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None
]
all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
missing_texts = [texts[i] for i in missing_indices]
missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
self.document_embedding_store.mset(
list(zip(missing_texts, missing_vectors)),
list(zip(missing_texts, missing_vectors, strict=False)),
)
for index, updated_vector in zip(missing_indices, missing_vectors):
for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
vectors[index] = updated_vector
return cast(
@@ -212,12 +210,8 @@ class CacheBackedEmbeddings(Embeddings):
Returns:
A list of embeddings for the given texts.
"""
vectors: list[
Union[list[float], None]
] = await self.document_embedding_store.amget(texts)
all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None
]
vectors: list[Union[list[float], None]] = await self.document_embedding_store.amget(texts)
all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
# batch_iterate supports None batch_size which returns all elements at once
# as a single batch.
@@ -227,9 +221,9 @@ class CacheBackedEmbeddings(Embeddings):
missing_texts,
)
await self.document_embedding_store.amset(
list(zip(missing_texts, missing_vectors)),
list(zip(missing_texts, missing_vectors, strict=False)),
)
for index, updated_vector in zip(missing_indices, missing_vectors):
for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
vectors[index] = updated_vector
return cast(
@@ -290,7 +284,7 @@ class CacheBackedEmbeddings(Embeddings):
document_embedding_cache: ByteStore,
*,
namespace: str = "",
batch_size: Optional[int] = None,
batch_size: int | None = None,
query_embedding_cache: Union[bool, ByteStore] = False,
key_encoder: Union[
Callable[[str], str],

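For context, `CacheBackedEmbeddings` wraps an embedding model with a key-value cache so repeated texts are not re-embedded. A rough usage sketch, assuming `langchain_core`'s in-memory byte store and deterministic fake embeddings as stand-ins; the import paths and `from_bytes_store` constructor reflect current `langchain`, not necessarily the v1 layout:

```python
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.stores import InMemoryByteStore

from langchain.embeddings import CacheBackedEmbeddings

underlying = DeterministicFakeEmbedding(size=8)  # stand-in for a real embedding model
store = InMemoryByteStore()

cached = CacheBackedEmbeddings.from_bytes_store(
    underlying,
    store,
    namespace="fake-8",  # namespace keeps caches for different models separate
    batch_size=2,
)

first = cached.embed_documents(["hello", "world"])   # computed and written to the store
second = cached.embed_documents(["hello", "world"])  # served from the cache
assert first == second
```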
View File

@@ -1,8 +1,6 @@
from collections.abc import AsyncIterator, Iterator, Sequence
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
)
@@ -62,37 +60,29 @@ class EncoderBackedStore(BaseStore[K, V]):
self.value_serializer = value_serializer
self.value_deserializer = value_deserializer
def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
def mget(self, keys: Sequence[K]) -> list[V | None]:
"""Get the values associated with the given keys."""
encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
values = self.store.mget(encoded_keys)
return [
self.value_deserializer(value) if value is not None else value
for value in values
]
return [self.value_deserializer(value) if value is not None else value for value in values]
async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
async def amget(self, keys: Sequence[K]) -> list[V | None]:
"""Get the values associated with the given keys."""
encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
values = await self.store.amget(encoded_keys)
return [
self.value_deserializer(value) if value is not None else value
for value in values
]
return [self.value_deserializer(value) if value is not None else value for value in values]
def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
"""Set the values for the given keys."""
encoded_pairs = [
(self.key_encoder(key), self.value_serializer(value))
for key, value in key_value_pairs
(self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
]
self.store.mset(encoded_pairs)
async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
"""Set the values for the given keys."""
encoded_pairs = [
(self.key_encoder(key), self.value_serializer(value))
for key, value in key_value_pairs
(self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
]
await self.store.amset(encoded_pairs)
@@ -109,7 +99,7 @@ class EncoderBackedStore(BaseStore[K, V]):
def yield_keys(
self,
*,
prefix: Optional[str] = None,
prefix: str | None = None,
) -> Union[Iterator[K], Iterator[str]]:
"""Get an iterator over keys that match the given prefix."""
# For the time being this does not return K, but str
@@ -119,7 +109,7 @@ class EncoderBackedStore(BaseStore[K, V]):
async def ayield_keys(
self,
*,
prefix: Optional[str] = None,
prefix: str | None = None,
) -> Union[AsyncIterator[K], AsyncIterator[str]]:
"""Get an iterator over keys that match the given prefix."""
# For the time being this does not return K, but str

View File

@@ -5,7 +5,7 @@ build-backend = "pdm.backend"
[project]
authors = []
license = { text = "MIT" }
requires-python = ">=3.9, <4.0"
requires-python = ">=3.10"
dependencies = [
"langchain-core<1.0.0,>=0.3.66",
"langchain-text-splitters<1.0.0,>=0.3.8",
@@ -46,23 +46,22 @@ repository = "https://github.com/langchain-ai/langchain"
[dependency-groups]
test = [
"pytest<9,>=8",
"pytest-cov<5.0.0,>=4.0.0",
"pytest-watcher<1.0.0,>=0.2.6",
"pytest-asyncio<1.0.0,>=0.23.2",
"pytest-socket<1.0.0,>=0.6.0",
"syrupy<5.0.0,>=4.0.2",
"pytest-xdist<4.0.0,>=3.6.1",
"blockbuster<1.6,>=1.5.18",
"pytest-cov>=4.0.0",
"pytest-watcher>=0.2.6",
"pytest-asyncio>=0.23.2",
"pytest-socket>=0.6.0",
"syrupy>=4.0.2",
"pytest-xdist>=3.6.1",
"blockbuster>=1.5.18",
"langchain-tests",
"langchain-core",
"langchain-text-splitters",
"langchain-openai",
"toml>=0.10.2",
]
codespell = ["codespell<3.0.0,>=2.2.0"]
lint = [
"ruff<0.13,>=0.12.2",
"mypy<1.16,>=1.15",
"ruff>=0.12.2",
"mypy>=1.15",
]
typing = [
"types-toml>=0.10.8.20240310",
@@ -70,11 +69,10 @@ typing = [
test_integration = [
"vcrpy>=7.0",
"urllib3<2; python_version < \"3.10\"",
"wrapt<2.0.0,>=1.15.0",
"python-dotenv<2.0.0,>=1.0.0",
"cassio<1.0.0,>=0.1.0",
"langchainhub<1.0.0,>=0.1.16",
"wrapt>=1.15.0",
"python-dotenv>=1.0.0",
"cassio>=0.1.0",
"langchainhub>=0.1.16",
"langchain-core",
"langchain-text-splitters",
]
@@ -86,8 +84,9 @@ langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }
[tool.ruff]
target-version = "py39"
target-version = "py310"
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
line-length = 100
[tool.mypy]
strict = "True"

View File

@@ -31,9 +31,7 @@ async def test_init_chat_model_chain() -> None:
chain = prompt | model_with_config
output = chain.invoke({"input": "bar"})
assert isinstance(output, AIMessage)
events = [
event async for event in chain.astream_events({"input": "bar"}, version="v2")
]
events = [event async for event in chain.astream_events({"input": "bar"}, version="v2")]
assert events

View File

@@ -1,5 +1,5 @@
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from unittest import mock
import pytest
@@ -38,7 +38,7 @@ def test_all_imports() -> None:
("mixtral-8x7b-32768", "groq"),
],
)
def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
def test_init_chat_model(model_name: str, model_provider: str | None) -> None:
llm1: BaseChatModel = init_chat_model(
model_name,
model_provider=model_provider,
@@ -222,7 +222,7 @@ def test_configurable_with_default() -> None:
config={"configurable": {"my_model_model": "claude-3-sonnet-20240229"}}
)
""" # noqa: E501
"""
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
for method in (
"invoke",

View File

@@ -53,9 +53,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
)
def pytest_collection_modifyitems(
config: pytest.Config, items: Sequence[pytest.Function]
) -> None:
def pytest_collection_modifyitems(config: pytest.Config, items: Sequence[pytest.Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.

View File

@@ -113,9 +113,7 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None
vectors = await cache_embeddings.aembed_documents(texts)
expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
assert vectors == expected_vectors
keys = [
key async for key in cache_embeddings.document_embedding_store.ayield_keys()
]
keys = [key async for key in cache_embeddings.document_embedding_store.ayield_keys()]
assert len(keys) == 4
# UUID is expected to be the same for the same text
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
@@ -128,10 +126,7 @@ async def test_aembed_documents_batch(
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
with contextlib.suppress(ValueError):
await cache_embeddings_batch.aembed_documents(texts)
keys = [
key
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
]
keys = [key async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()]
# only the first batch of three embeddings should exist
assert len(keys) == 3
# UUID is expected to be the same for the same text

View File

@@ -13,9 +13,7 @@ def test_import_all() -> None:
library_code = PKG_ROOT / "langchain"
for path in library_code.rglob("*.py"):
# Calculate the relative path to the module
module_name = (
path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
)
module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
if module_name.endswith("__init__"):
# Without init
module_name = module_name.rsplit(".", 1)[0]
@@ -39,9 +37,7 @@ def test_import_all_using_dir() -> None:
library_code = PKG_ROOT / "langchain"
for path in library_code.rglob("*.py"):
# Calculate the relative path to the module
module_name = (
path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
)
module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
if module_name.endswith("__init__"):
# Without init
module_name = module_name.rsplit(".", 1)[0]

libs/langchain_v1/uv.lock (generated, 2415 changed lines)

File diff suppressed because it is too large.