Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-14 14:05:37 +00:00)
chore(langchain): drop Python 3.9 to prep for v1 (#32704)
Python 3.9 reaches end of life in October 2025, so we're dropping it for the v1 alpha release.
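Dropping 3.9 lets the package standardize on Python 3.10+ idioms, which is what most of the diff below does mechanically: PEP 604 `X | None` unions replace `Optional[X]`, `Callable` is imported from `collections.abc` instead of `typing`, and `zip()` calls pass `strict=` explicitly. A minimal sketch of the target style (the names below are illustrative only, not code from this diff):

# Illustrative sketch of the Python 3.10+ idioms this commit adopts; the names
# here (apply_or_default, etc.) are hypothetical and not taken from the repo.
from collections.abc import Callable  # Callable now comes from collections.abc, not typing


def apply_or_default(
    value: str | None,                              # PEP 604 union replaces Optional[str]
    transform: Callable[[str], str] | None = None,  # and Optional[Callable[...]]
) -> str:
    """Return transform(value), or a fallback when value is missing."""
    if value is None:
        return "default"
    return transform(value) if transform else value


if __name__ == "__main__":
    print(apply_or_default("hello", str.upper))         # -> HELLO
    # zip() gained a strict= keyword in 3.10; the diff passes it explicitly.
    print(list(zip(["a", "b"], [1, 2], strict=False)))  # -> [('a', 1), ('b', 2)]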
@@ -32,9 +32,4 @@ def format_document_xml(doc: Document) -> str:
     if doc.metadata:
         metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()]
         metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>"
-    return (
-        f"<document>{id_str}"
-        f"<content>{doc.page_content}</content>"
-        f"{metadata_str}"
-        f"</document>"
-    )
+    return f"<document>{id_str}<content>{doc.page_content}</content>{metadata_str}</document>"
@@ -12,10 +12,10 @@ particularly for summarization chains and other document processing workflows.
 from __future__ import annotations
 
 import inspect
-from typing import TYPE_CHECKING, Callable, Union
+from typing import TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
-    from collections.abc import Awaitable
+    from collections.abc import Awaitable, Callable
 
     from langchain_core.messages import MessageLikeRepresentation
     from langgraph.runtime import Runtime
@@ -92,9 +92,7 @@ async def aresolve_prompt(
         str,
         None,
         Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
-        Callable[
-            [StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]
-        ],
+        Callable[[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]],
     ],
     state: StateT,
     runtime: Runtime[ContextT],
@@ -2,11 +2,10 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, Union
+from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, Union
 
 from langgraph.graph._node import StateNode
 from pydantic import BaseModel
-from typing_extensions import TypeAlias
 
 if TYPE_CHECKING:
     from dataclasses import Field
@@ -7,10 +7,8 @@ from typing import (
     TYPE_CHECKING,
     Annotated,
     Any,
-    Callable,
     Generic,
     Literal,
-    Optional,
     Union,
     cast,
 )
@@ -26,6 +24,8 @@ from langchain._internal._utils import RunnableCallable
 from langchain.chat_models import init_chat_model
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.documents import Document
     from langchain_core.language_models.chat_models import BaseChatModel
@@ -154,7 +154,7 @@ class _MapReduceExtractor(Generic[ContextT]):
             StateNode,
         ] = "default_reducer",
         context_schema: type[ContextT] | None = None,
-        response_format: Optional[type[BaseModel]] = None,
+        response_format: type[BaseModel] | None = None,
     ) -> None:
         """Initialize the MapReduceExtractor.
@@ -190,9 +190,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         if isinstance(model, str):
             model = init_chat_model(model)
 
-        self.model = (
-            model.with_structured_output(response_format) if response_format else model
-        )
+        self.model = model.with_structured_output(response_format) if response_format else model
         self.map_prompt = map_prompt
         self.reduce_prompt = reduce_prompt
         self.reduce = reduce
@@ -342,9 +340,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         config: RunnableConfig,
     ) -> dict[str, list[ExtractionResult]]:
         prompt = await self._aget_map_prompt(state, runtime)
-        response = cast(
-            "AIMessage", await self.model.ainvoke(prompt, config=config)
-        )
+        response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
         result = response if self.response_format else response.text()
         extraction_result: ExtractionResult = {
             "indexes": state["indexes"],
@@ -375,9 +371,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         config: RunnableConfig,
     ) -> MapReduceNodeUpdate:
         prompt = await self._aget_reduce_prompt(state, runtime)
-        response = cast(
-            "AIMessage", await self.model.ainvoke(prompt, config=config)
-        )
+        response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
         result = response if self.response_format else response.text()
         return {"result": result}
 
@@ -411,7 +405,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         # Add-conditional edges doesn't explicitly type Send
         builder.add_conditional_edges(
             "continue_to_map",
-            self.continue_to_map,  # type: ignore[arg-type]
+            self.continue_to_map,
             ["map_process"],
         )
 
@@ -422,7 +416,7 @@ class _MapReduceExtractor(Generic[ContextT]):
             builder.add_edge("map_process", "reduce_process")
             builder.add_edge("reduce_process", END)
         else:
-            reduce_node = cast("StateNode", self.reduce)
+            reduce_node = self.reduce
             # The type is ignored here. Requires parameterizing with generics.
             builder.add_node("reduce_process", reduce_node)  # type: ignore[arg-type]
             builder.add_edge("map_process", "reduce_process")
@@ -450,7 +444,7 @@ def create_map_reduce_chain(
         StateNode,
     ] = "default_reducer",
     context_schema: type[ContextT] | None = None,
-    response_format: Optional[type[BaseModel]] = None,
+    response_format: type[BaseModel] | None = None,
 ) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
     """Create a map-reduce document extraction chain.
@@ -5,9 +5,7 @@ from __future__ import annotations
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     Generic,
-    Optional,
     Union,
     cast,
 )
@@ -24,6 +22,8 @@ from langchain._internal._utils import RunnableCallable
 from langchain.chat_models import init_chat_model
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.language_models.chat_models import BaseChatModel
 
 # Used for type checking, but IDEs may not recognize it inside the cast.
@@ -149,7 +149,7 @@ class _Extractor(Generic[ContextT]):
             ],
         ] = None,
         context_schema: type[ContextT] | None = None,
-        response_format: Optional[type[BaseModel]] = None,
+        response_format: type[BaseModel] | None = None,
     ) -> None:
         """Initialize the Extractor.
@@ -173,9 +173,7 @@ class _Extractor(Generic[ContextT]):
         if isinstance(model, str):
             model = init_chat_model(model)
 
-        self.model = (
-            model.with_structured_output(response_format) if response_format else model
-        )
+        self.model = model.with_structured_output(response_format) if response_format else model
         self.initial_prompt = prompt
         self.refine_prompt = refine_prompt
         self.context_schema = context_schema
@@ -188,9 +186,7 @@ class _Extractor(Generic[ContextT]):
 
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_INIT_PROMPT
-            if self.response_format
-            else DEFAULT_INIT_PROMPT
+            DEFAULT_STRUCTURED_INIT_PROMPT if self.response_format else DEFAULT_INIT_PROMPT
         )
 
         return resolve_prompt(
@@ -209,9 +205,7 @@ class _Extractor(Generic[ContextT]):
 
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_INIT_PROMPT
-            if self.response_format
-            else DEFAULT_INIT_PROMPT
+            DEFAULT_STRUCTURED_INIT_PROMPT if self.response_format else DEFAULT_INIT_PROMPT
         )
 
         return await aresolve_prompt(
@@ -246,9 +240,7 @@ class _Extractor(Generic[ContextT]):
 
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_REFINE_PROMPT
-            if self.response_format
-            else DEFAULT_REFINE_PROMPT
+            DEFAULT_STRUCTURED_REFINE_PROMPT if self.response_format else DEFAULT_REFINE_PROMPT
         )
 
         return resolve_prompt(
@@ -283,9 +275,7 @@ class _Extractor(Generic[ContextT]):
 
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_REFINE_PROMPT
-            if self.response_format
-            else DEFAULT_REFINE_PROMPT
+            DEFAULT_STRUCTURED_REFINE_PROMPT if self.response_format else DEFAULT_REFINE_PROMPT
        )
 
         return await aresolve_prompt(
@@ -340,16 +330,12 @@ class _Extractor(Generic[ContextT]):
         if "result" not in state or state["result"] == "":
             # Initial processing
             prompt = await self._aget_initial_prompt(state, runtime)
-            response = cast(
-                "AIMessage", await self.model.ainvoke(prompt, config=config)
-            )
+            response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
             result = response if self.response_format else response.text()
             return {"result": result}
         # Refinement
         prompt = await self._aget_refine_prompt(state, runtime)
-        response = cast(
-            "AIMessage", await self.model.ainvoke(prompt, config=config)
-        )
+        response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
         result = response if self.response_format else response.text()
         return {"result": result}
@@ -388,7 +374,7 @@ def create_stuff_documents_chain(
         Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
     ] = None,
     context_schema: type[ContextT] | None = None,
-    response_format: Optional[type[BaseModel]] = None,
+    response_format: type[BaseModel] | None = None,
 ) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
     """Create a stuff documents chain for processing documents.
@@ -5,9 +5,8 @@ from importlib import util
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     Literal,
-    Optional,
+    TypeAlias,
     Union,
     cast,
     overload,
@@ -16,10 +15,10 @@ from typing import (
 from langchain_core.language_models import BaseChatModel, LanguageModelInput
 from langchain_core.messages import AnyMessage, BaseMessage
 from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
-from typing_extensions import TypeAlias, override
+from typing_extensions import override
 
 if TYPE_CHECKING:
-    from collections.abc import AsyncIterator, Iterator, Sequence
+    from collections.abc import AsyncIterator, Callable, Iterator, Sequence
 
     from langchain_core.runnables.schema import StreamEvent
     from langchain_core.tools import BaseTool
@@ -31,9 +30,9 @@ if TYPE_CHECKING:
 def init_chat_model(
     model: str,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     configurable_fields: Literal[None] = None,
-    config_prefix: Optional[str] = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> BaseChatModel: ...
@@ -42,20 +41,20 @@ def init_chat_model(
 def init_chat_model(
     model: Literal[None] = None,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     configurable_fields: Literal[None] = None,
-    config_prefix: Optional[str] = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> _ConfigurableModel: ...
 
 
 @overload
 def init_chat_model(
-    model: Optional[str] = None,
+    model: str | None = None,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
-    config_prefix: Optional[str] = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> _ConfigurableModel: ...
 
@@ -64,13 +63,11 @@ def init_chat_model(
 # name to the supported list in the docstring below. Do *not* change the order of the
 # existing providers.
 def init_chat_model(
-    model: Optional[str] = None,
+    model: str | None = None,
     *,
-    model_provider: Optional[str] = None,
-    configurable_fields: Optional[
-        Union[Literal["any"], list[str], tuple[str, ...]]
-    ] = None,
-    config_prefix: Optional[str] = None,
+    model_provider: str | None = None,
+    configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] | None = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> Union[BaseChatModel, _ConfigurableModel]:
     """Initialize a ChatModel from the model name and provider.
@@ -326,7 +323,7 @@ def init_chat_model(
 def _init_chat_model_helper(
     model: str,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     **kwargs: Any,
 ) -> BaseChatModel:
     model, model_provider = _parse_model(model, model_provider)
@@ -446,9 +443,7 @@ def _init_chat_model_helper(
 
         return ChatPerplexity(model=model, **kwargs)
     supported = ", ".join(_SUPPORTED_PROVIDERS)
-    msg = (
-        f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
-    )
+    msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
     raise ValueError(msg)
 
@@ -476,7 +471,7 @@ _SUPPORTED_PROVIDERS = {
 }
 
 
-def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
+def _attempt_infer_model_provider(model_name: str) -> str | None:
     if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
         return "openai"
     if model_name.startswith("claude"):
@@ -500,31 +495,24 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
     return None
 
 
-def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
-    if (
-        not model_provider
-        and ":" in model
-        and model.split(":")[0] in _SUPPORTED_PROVIDERS
-    ):
+def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
+    if not model_provider and ":" in model and model.split(":")[0] in _SUPPORTED_PROVIDERS:
         model_provider = model.split(":")[0]
         model = ":".join(model.split(":")[1:])
     model_provider = model_provider or _attempt_infer_model_provider(model)
     if not model_provider:
         msg = (
-            f"Unable to infer model provider for {model=}, please specify "
-            f"model_provider directly."
+            f"Unable to infer model provider for {model=}, please specify model_provider directly."
         )
         raise ValueError(msg)
     model_provider = model_provider.replace("-", "_").lower()
     return model, model_provider
 
 
-def _check_pkg(pkg: str, *, pkg_kebab: Optional[str] = None) -> None:
+def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
     if not util.find_spec(pkg):
         pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
-        msg = (
-            f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
-        )
+        msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
         raise ImportError(msg)
@@ -539,16 +527,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def __init__(
         self,
         *,
-        default_config: Optional[dict] = None,
+        default_config: dict | None = None,
         configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
         config_prefix: str = "",
         queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
     ) -> None:
         self._default_config: dict = default_config or {}
         self._configurable_fields: Union[Literal["any"], list[str]] = (
-            configurable_fields
-            if configurable_fields == "any"
-            else list(configurable_fields)
+            configurable_fields if configurable_fields == "any" else list(configurable_fields)
         )
         self._config_prefix = (
             config_prefix + "_"
@@ -589,14 +575,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             msg += "."
         raise AttributeError(msg)
 
-    def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
+    def _model(self, config: RunnableConfig | None = None) -> Runnable:
         params = {**self._default_config, **self._model_params(config)}
         model = _init_chat_model_helper(**params)
         for name, args, kwargs in self._queued_declarative_operations:
             model = getattr(model, name)(*args, **kwargs)
         return model
 
-    def _model_params(self, config: Optional[RunnableConfig]) -> dict:
+    def _model_params(self, config: RunnableConfig | None) -> dict:
         config = ensure_config(config)
         model_params = {
             _remove_prefix(k, self._config_prefix): v
@@ -604,14 +590,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             if k.startswith(self._config_prefix)
         }
         if self._configurable_fields != "any":
-            model_params = {
-                k: v for k, v in model_params.items() if k in self._configurable_fields
-            }
+            model_params = {k: v for k, v in model_params.items() if k in self._configurable_fields}
         return model_params
 
     def with_config(
         self,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         **kwargs: Any,
     ) -> _ConfigurableModel:
         """Bind config to a Runnable, returning a new Runnable."""
@@ -662,7 +646,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def invoke(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         **kwargs: Any,
     ) -> Any:
         return self._model(config).invoke(input, config=config, **kwargs)
@@ -671,7 +655,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def ainvoke(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         **kwargs: Any,
     ) -> Any:
         return await self._model(config).ainvoke(input, config=config, **kwargs)
@@ -680,8 +664,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def stream(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> Iterator[Any]:
         yield from self._model(config).stream(input, config=config, **kwargs)
 
@@ -689,8 +673,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def astream(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> AsyncIterator[Any]:
         async for x in self._model(config).astream(input, config=config, **kwargs):
             yield x
@@ -698,10 +682,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch(
         self,
         inputs: list[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
-        **kwargs: Optional[Any],
+        **kwargs: Any | None,
     ) -> list[Any]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
@@ -726,10 +710,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch(
         self,
         inputs: list[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
-        **kwargs: Optional[Any],
+        **kwargs: Any | None,
     ) -> list[Any]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
@@ -754,7 +738,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -783,7 +767,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -817,8 +801,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def transform(
         self,
         input: Iterator[LanguageModelInput],
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> Iterator[Any]:
         yield from self._model(config).transform(input, config=config, **kwargs)
 
@@ -826,8 +810,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def atransform(
         self,
         input: AsyncIterator[LanguageModelInput],
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> AsyncIterator[Any]:
         async for x in self._model(config).atransform(input, config=config, **kwargs):
             yield x
@@ -836,16 +820,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def astream_log(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         diff: Literal[True] = True,
         with_streamed_output_list: bool = True,
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[RunLogPatch]: ...
@@ -853,16 +837,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def astream_log(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         diff: Literal[False],
         with_streamed_output_list: bool = True,
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[RunLog]: ...
@@ -870,16 +854,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def astream_log(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         diff: bool = True,
         with_streamed_output_list: bool = True,
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
         async for x in self._model(config).astream_log(  # type: ignore[call-overload, misc]
@@ -901,15 +885,15 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def astream_events(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         version: Literal["v1", "v2"] = "v2",
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[StreamEvent]:
         async for x in self._model(config).astream_events(
@@ -1,6 +1,6 @@
 import functools
 from importlib import util
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.runnables import Runnable
@@ -19,9 +19,7 @@ _SUPPORTED_PROVIDERS = {
 
 def _get_provider_list() -> str:
     """Get formatted list of providers and their packages."""
-    return "\n".join(
-        f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
-    )
+    return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items())
 
 
 def _parse_model_string(model_name: str) -> tuple[str, str]:
@@ -82,7 +80,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
 def _infer_model_and_provider(
     model: str,
     *,
-    provider: Optional[str] = None,
+    provider: str | None = None,
 ) -> tuple[str, str]:
     if not model.strip():
         msg = "Model name cannot be empty"
@@ -117,17 +115,14 @@
 def _check_pkg(pkg: str) -> None:
     """Check if a package is installed."""
     if not util.find_spec(pkg):
-        msg = (
-            f"Could not import {pkg} python package. "
-            f"Please install it with `pip install {pkg}`"
-        )
+        msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
         raise ImportError(msg)
 
 
 def init_embeddings(
     model: str,
     *,
-    provider: Optional[str] = None,
+    provider: str | None = None,
     **kwargs: Any,
 ) -> Union[Embeddings, Runnable[Any, list[float]]]:
     """Initialize an embeddings model from a model name and optional provider.
@@ -182,9 +177,7 @@ def init_embeddings(
     """
     if not model:
         providers = _SUPPORTED_PROVIDERS.keys()
-        msg = (
-            f"Must specify model name. Supported providers are: {', '.join(providers)}"
-        )
+        msg = f"Must specify model name. Supported providers are: {', '.join(providers)}"
         raise ValueError(msg)
 
     provider, model_name = _infer_model_and_provider(model, provider=provider)
@@ -13,7 +13,7 @@ import hashlib
 import json
 import uuid
 import warnings
-from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast
+from typing import TYPE_CHECKING, Literal, Union, cast
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils.iter import batch_iterate
@@ -21,7 +21,7 @@ from langchain_core.utils.iter import batch_iterate
 from langchain.storage.encoder_backed import EncoderBackedStore
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable, Sequence
 
     from langchain_core.stores import BaseStore, ByteStore
 
@@ -147,8 +147,8 @@ class CacheBackedEmbeddings(Embeddings):
         underlying_embeddings: Embeddings,
         document_embedding_store: BaseStore[str, list[float]],
         *,
-        batch_size: Optional[int] = None,
-        query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
+        batch_size: int | None = None,
+        query_embedding_store: BaseStore[str, list[float]] | None = None,
     ) -> None:
         """Initialize the embedder.
@@ -181,17 +181,15 @@ class CacheBackedEmbeddings(Embeddings):
         vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
             texts,
         )
-        all_missing_indices: list[int] = [
-            i for i, vector in enumerate(vectors) if vector is None
-        ]
+        all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
 
         for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
             missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
             self.document_embedding_store.mset(
-                list(zip(missing_texts, missing_vectors)),
+                list(zip(missing_texts, missing_vectors, strict=False)),
             )
-            for index, updated_vector in zip(missing_indices, missing_vectors):
+            for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
                 vectors[index] = updated_vector
 
         return cast(
@@ -212,12 +210,8 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             A list of embeddings for the given texts.
         """
-        vectors: list[
-            Union[list[float], None]
-        ] = await self.document_embedding_store.amget(texts)
-        all_missing_indices: list[int] = [
-            i for i, vector in enumerate(vectors) if vector is None
-        ]
+        vectors: list[Union[list[float], None]] = await self.document_embedding_store.amget(texts)
+        all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
 
         # batch_iterate supports None batch_size which returns all elements at once
         # as a single batch.
@@ -227,9 +221,9 @@ class CacheBackedEmbeddings(Embeddings):
                 missing_texts,
             )
             await self.document_embedding_store.amset(
-                list(zip(missing_texts, missing_vectors)),
+                list(zip(missing_texts, missing_vectors, strict=False)),
            )
-            for index, updated_vector in zip(missing_indices, missing_vectors):
+            for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
                 vectors[index] = updated_vector
 
         return cast(
@@ -290,7 +284,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_cache: ByteStore,
         *,
         namespace: str = "",
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
         query_embedding_cache: Union[bool, ByteStore] = False,
         key_encoder: Union[
             Callable[[str], str],
@@ -1,8 +1,6 @@
-from collections.abc import AsyncIterator, Iterator, Sequence
+from collections.abc import AsyncIterator, Callable, Iterator, Sequence
 from typing import (
     Any,
-    Callable,
-    Optional,
     TypeVar,
     Union,
 )
@@ -62,37 +60,29 @@ class EncoderBackedStore(BaseStore[K, V]):
         self.value_serializer = value_serializer
         self.value_deserializer = value_deserializer
 
-    def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
+    def mget(self, keys: Sequence[K]) -> list[V | None]:
         """Get the values associated with the given keys."""
         encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
         values = self.store.mget(encoded_keys)
-        return [
-            self.value_deserializer(value) if value is not None else value
-            for value in values
-        ]
+        return [self.value_deserializer(value) if value is not None else value for value in values]
 
-    async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
+    async def amget(self, keys: Sequence[K]) -> list[V | None]:
         """Get the values associated with the given keys."""
         encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
         values = await self.store.amget(encoded_keys)
-        return [
-            self.value_deserializer(value) if value is not None else value
-            for value in values
-        ]
+        return [self.value_deserializer(value) if value is not None else value for value in values]
 
     def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
         """Set the values for the given keys."""
         encoded_pairs = [
-            (self.key_encoder(key), self.value_serializer(value))
-            for key, value in key_value_pairs
+            (self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
         ]
         self.store.mset(encoded_pairs)
 
     async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
         """Set the values for the given keys."""
         encoded_pairs = [
-            (self.key_encoder(key), self.value_serializer(value))
-            for key, value in key_value_pairs
+            (self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
         ]
         await self.store.amset(encoded_pairs)
@@ -109,7 +99,7 @@ class EncoderBackedStore(BaseStore[K, V]):
     def yield_keys(
         self,
         *,
-        prefix: Optional[str] = None,
+        prefix: str | None = None,
     ) -> Union[Iterator[K], Iterator[str]]:
         """Get an iterator over keys that match the given prefix."""
         # For the time being this does not return K, but str
@@ -119,7 +109,7 @@ class EncoderBackedStore(BaseStore[K, V]):
     async def ayield_keys(
         self,
         *,
-        prefix: Optional[str] = None,
+        prefix: str | None = None,
     ) -> Union[AsyncIterator[K], AsyncIterator[str]]:
         """Get an iterator over keys that match the given prefix."""
         # For the time being this does not return K, but str
@@ -5,7 +5,7 @@ build-backend = "pdm.backend"
 [project]
 authors = []
 license = { text = "MIT" }
-requires-python = ">=3.9, <4.0"
+requires-python = ">=3.10"
 dependencies = [
     "langchain-core<1.0.0,>=0.3.66",
     "langchain-text-splitters<1.0.0,>=0.3.8",
@@ -46,23 +46,22 @@ repository = "https://github.com/langchain-ai/langchain"
 [dependency-groups]
 test = [
     "pytest<9,>=8",
-    "pytest-cov<5.0.0,>=4.0.0",
-    "pytest-watcher<1.0.0,>=0.2.6",
-    "pytest-asyncio<1.0.0,>=0.23.2",
-    "pytest-socket<1.0.0,>=0.6.0",
-    "syrupy<5.0.0,>=4.0.2",
-    "pytest-xdist<4.0.0,>=3.6.1",
-    "blockbuster<1.6,>=1.5.18",
+    "pytest-cov>=4.0.0",
+    "pytest-watcher>=0.2.6",
+    "pytest-asyncio>=0.23.2",
+    "pytest-socket>=0.6.0",
+    "syrupy>=4.0.2",
+    "pytest-xdist>=3.6.1",
+    "blockbuster>=1.5.18",
     "langchain-tests",
     "langchain-core",
     "langchain-text-splitters",
     "langchain-openai",
     "toml>=0.10.2",
 ]
 codespell = ["codespell<3.0.0,>=2.2.0"]
 lint = [
-    "ruff<0.13,>=0.12.2",
-    "mypy<1.16,>=1.15",
+    "ruff>=0.12.2",
+    "mypy>=1.15",
 ]
 typing = [
     "types-toml>=0.10.8.20240310",
@@ -70,11 +69,10 @@ typing = [
 
 test_integration = [
     "vcrpy>=7.0",
-    "urllib3<2; python_version < \"3.10\"",
-    "wrapt<2.0.0,>=1.15.0",
-    "python-dotenv<2.0.0,>=1.0.0",
-    "cassio<1.0.0,>=0.1.0",
-    "langchainhub<1.0.0,>=0.1.16",
+    "wrapt>=1.15.0",
+    "python-dotenv>=1.0.0",
+    "cassio>=0.1.0",
+    "langchainhub>=0.1.16",
     "langchain-core",
     "langchain-text-splitters",
 ]
@@ -86,8 +84,9 @@ langchain-text-splitters = { path = "../text-splitters", editable = true }
 langchain-openai = { path = "../partners/openai", editable = true }
 
 [tool.ruff]
-target-version = "py39"
+target-version = "py310"
 exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
+line-length = 100
 
 [tool.mypy]
 strict = "True"
@@ -31,9 +31,7 @@ async def test_init_chat_model_chain() -> None:
     chain = prompt | model_with_config
     output = chain.invoke({"input": "bar"})
     assert isinstance(output, AIMessage)
-    events = [
-        event async for event in chain.astream_events({"input": "bar"}, version="v2")
-    ]
+    events = [event async for event in chain.astream_events({"input": "bar"}, version="v2")]
     assert events
 
@@ -1,5 +1,5 @@
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 from unittest import mock
 
 import pytest
@@ -38,7 +38,7 @@ def test_all_imports() -> None:
         ("mixtral-8x7b-32768", "groq"),
     ],
 )
-def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
+def test_init_chat_model(model_name: str, model_provider: str | None) -> None:
     llm1: BaseChatModel = init_chat_model(
         model_name,
         model_provider=model_provider,
@@ -222,7 +222,7 @@ def test_configurable_with_default() -> None:
         config={"configurable": {"my_model_model": "claude-3-sonnet-20240229"}}
     )
 
-    """  # noqa: E501
+    """
     model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
     for method in (
         "invoke",
@@ -53,9 +53,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
     )
 
 
-def pytest_collection_modifyitems(
-    config: pytest.Config, items: Sequence[pytest.Function]
-) -> None:
+def pytest_collection_modifyitems(config: pytest.Config, items: Sequence[pytest.Function]) -> None:
    """Add implementations for handling custom markers.
 
     At the moment, this adds support for a custom `requires` marker.
@@ -113,9 +113,7 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     vectors = await cache_embeddings.aembed_documents(texts)
     expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
     assert vectors == expected_vectors
-    keys = [
-        key async for key in cache_embeddings.document_embedding_store.ayield_keys()
-    ]
+    keys = [key async for key in cache_embeddings.document_embedding_store.ayield_keys()]
     assert len(keys) == 4
     # UUID is expected to be the same for the same text
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
@@ -128,10 +126,7 @@ async def test_aembed_documents_batch(
     texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
     with contextlib.suppress(ValueError):
         await cache_embeddings_batch.aembed_documents(texts)
-    keys = [
-        key
-        async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
-    ]
+    keys = [key async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()]
     # only the first batch of three embeddings should exist
     assert len(keys) == 3
     # UUID is expected to be the same for the same text
@@ -13,9 +13,7 @@ def test_import_all() -> None:
     library_code = PKG_ROOT / "langchain"
     for path in library_code.rglob("*.py"):
         # Calculate the relative path to the module
-        module_name = (
-            path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
-        )
+        module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
         if module_name.endswith("__init__"):
             # Without init
             module_name = module_name.rsplit(".", 1)[0]
@@ -39,9 +37,7 @@ def test_import_all_using_dir() -> None:
     library_code = PKG_ROOT / "langchain"
     for path in library_code.rglob("*.py"):
         # Calculate the relative path to the module
-        module_name = (
-            path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
-        )
+        module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
         if module_name.endswith("__init__"):
             # Without init
             module_name = module_name.rsplit(".", 1)[0]
libs/langchain_v1/uv.lock | 2415 lines changed (generated file; diff suppressed because it is too large)