Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-15 06:26:12 +00:00.
chore(langchain): drop Python 3.9 to prep for v1 (#32704)
Python 3.9 EOL is October 2025, so we're going to drop it for the v1 alpha release.
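Most of the code changes below are the mechanical typing cleanups that a Python 3.10 floor allows: PEP 604 unions in place of Optional/Union, TypeAlias imported from typing instead of typing_extensions, and Callable taken from collections.abc. A minimal sketch of those patterns, with illustrative names that are not from the repository:

# Illustrative only; names below are made up.
from collections.abc import Callable  # preferred over typing.Callable
from typing import TypeAlias  # importable from typing since Python 3.10

ModelName: TypeAlias = str
Handler: TypeAlias = Callable[[str], None]


def pick_model(name: ModelName | None = None) -> dict[str, str] | None:
    # PEP 604 unions (X | None) replace Optional[X] under a 3.10 floor.
    return {"model": name} if name is not None else None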
.github/scripts/check_diff.py (vendored): 2 changed lines
@@ -132,6 +132,8 @@ def _get_configs_for_single_dir(job: str, dir_: str) -> List[Dict[str, str]]:
 
     elif dir_ == "libs/langchain" and job == "extended-tests":
         py_versions = ["3.9", "3.13"]
+    elif dir_ == "libs/langchain_v1":
+        py_versions = ["3.10", "3.13"]
 
     elif dir_ == ".":
         # unable to install with 3.13 because tokenizers doesn't support 3.13 yet
@@ -32,9 +32,4 @@ def format_document_xml(doc: Document) -> str:
     if doc.metadata:
         metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()]
         metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>"
-    return (
-        f"<document>{id_str}"
-        f"<content>{doc.page_content}</content>"
-        f"{metadata_str}"
-        f"</document>"
-    )
+    return f"<document>{id_str}<content>{doc.page_content}</content>{metadata_str}</document>"
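The single-literal form above is equivalent to the removed multi-line version because adjacent f-string literals are concatenated at compile time. A quick check with made-up values:

id_str = "<id>1</id>"
page_content = "hello"
metadata_str = "<metadata>k: v</metadata>"

joined = (
    f"<document>{id_str}"
    f"<content>{page_content}</content>"
    f"{metadata_str}"
    f"</document>"
)
single = f"<document>{id_str}<content>{page_content}</content>{metadata_str}</document>"
assert joined == single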
@@ -12,10 +12,10 @@ particularly for summarization chains and other document processing workflows.
 from __future__ import annotations
 
 import inspect
-from typing import TYPE_CHECKING, Callable, Union
+from typing import TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
-    from collections.abc import Awaitable
+    from collections.abc import Awaitable, Callable
 
     from langchain_core.messages import MessageLikeRepresentation
     from langgraph.runtime import Runtime
@@ -92,9 +92,7 @@ async def aresolve_prompt(
         str,
         None,
         Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
-        Callable[
-            [StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]
-        ],
+        Callable[[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]],
     ],
     state: StateT,
     runtime: Runtime[ContextT],
@@ -2,11 +2,10 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, Union
+from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, Union
 
 from langgraph.graph._node import StateNode
 from pydantic import BaseModel
-from typing_extensions import TypeAlias
 
 if TYPE_CHECKING:
     from dataclasses import Field
@@ -7,10 +7,8 @@ from typing import (
     TYPE_CHECKING,
     Annotated,
     Any,
-    Callable,
     Generic,
     Literal,
-    Optional,
     Union,
     cast,
 )
@@ -26,6 +24,8 @@ from langchain._internal._utils import RunnableCallable
 from langchain.chat_models import init_chat_model
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.documents import Document
     from langchain_core.language_models.chat_models import BaseChatModel
 
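The Callable import lands under the TYPE_CHECKING guard because it is only needed for annotations; with postponed evaluation the name never has to exist at runtime. A small self-contained sketch of the pattern (hypothetical module, not repository code):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; no runtime import.
    from collections.abc import Callable


def apply(handler: Callable[[str], None], value: str) -> None:
    # Annotations stay unevaluated strings at runtime, so this still works.
    handler(value)


apply(print, "hello")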
@@ -154,7 +154,7 @@ class _MapReduceExtractor(Generic[ContextT]):
             StateNode,
         ] = "default_reducer",
         context_schema: type[ContextT] | None = None,
-        response_format: Optional[type[BaseModel]] = None,
+        response_format: type[BaseModel] | None = None,
     ) -> None:
         """Initialize the MapReduceExtractor.
 
@@ -190,9 +190,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         if isinstance(model, str):
             model = init_chat_model(model)
 
-        self.model = (
-            model.with_structured_output(response_format) if response_format else model
-        )
+        self.model = model.with_structured_output(response_format) if response_format else model
         self.map_prompt = map_prompt
         self.reduce_prompt = reduce_prompt
         self.reduce = reduce
@@ -342,9 +340,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         config: RunnableConfig,
     ) -> dict[str, list[ExtractionResult]]:
         prompt = await self._aget_map_prompt(state, runtime)
-        response = cast(
-            "AIMessage", await self.model.ainvoke(prompt, config=config)
-        )
+        response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
         result = response if self.response_format else response.text()
         extraction_result: ExtractionResult = {
             "indexes": state["indexes"],
@@ -375,9 +371,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         config: RunnableConfig,
     ) -> MapReduceNodeUpdate:
         prompt = await self._aget_reduce_prompt(state, runtime)
-        response = cast(
-            "AIMessage", await self.model.ainvoke(prompt, config=config)
-        )
+        response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
         result = response if self.response_format else response.text()
         return {"result": result}
 
@@ -411,7 +405,7 @@ class _MapReduceExtractor(Generic[ContextT]):
         # Add-conditional edges doesn't explicitly type Send
         builder.add_conditional_edges(
             "continue_to_map",
-            self.continue_to_map,  # type: ignore[arg-type]
+            self.continue_to_map,
             ["map_process"],
         )
 
@@ -422,7 +416,7 @@ class _MapReduceExtractor(Generic[ContextT]):
             builder.add_edge("map_process", "reduce_process")
             builder.add_edge("reduce_process", END)
         else:
-            reduce_node = cast("StateNode", self.reduce)
+            reduce_node = self.reduce
             # The type is ignored here. Requires parameterizing with generics.
             builder.add_node("reduce_process", reduce_node)  # type: ignore[arg-type]
             builder.add_edge("map_process", "reduce_process")
@@ -450,7 +444,7 @@ def create_map_reduce_chain(
         StateNode,
     ] = "default_reducer",
     context_schema: type[ContextT] | None = None,
-    response_format: Optional[type[BaseModel]] = None,
+    response_format: type[BaseModel] | None = None,
 ) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
     """Create a map-reduce document extraction chain.
 
@@ -5,9 +5,7 @@ from __future__ import annotations
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     Generic,
-    Optional,
     Union,
     cast,
 )
@@ -24,6 +22,8 @@ from langchain._internal._utils import RunnableCallable
 from langchain.chat_models import init_chat_model
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.language_models.chat_models import BaseChatModel
 
 # Used for type checking, but IDEs may not recognize it inside the cast.
@@ -149,7 +149,7 @@ class _Extractor(Generic[ContextT]):
             ],
         ] = None,
         context_schema: type[ContextT] | None = None,
-        response_format: Optional[type[BaseModel]] = None,
+        response_format: type[BaseModel] | None = None,
     ) -> None:
         """Initialize the Extractor.
 
@@ -173,9 +173,7 @@ class _Extractor(Generic[ContextT]):
         if isinstance(model, str):
             model = init_chat_model(model)
 
-        self.model = (
-            model.with_structured_output(response_format) if response_format else model
-        )
+        self.model = model.with_structured_output(response_format) if response_format else model
         self.initial_prompt = prompt
         self.refine_prompt = refine_prompt
         self.context_schema = context_schema
@@ -188,9 +186,7 @@ class _Extractor(Generic[ContextT]):
 
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_INIT_PROMPT
-            if self.response_format
-            else DEFAULT_INIT_PROMPT
+            DEFAULT_STRUCTURED_INIT_PROMPT if self.response_format else DEFAULT_INIT_PROMPT
         )
 
         return resolve_prompt(
@@ -209,9 +205,7 @@ class _Extractor(Generic[ContextT]):
 
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_INIT_PROMPT
-            if self.response_format
-            else DEFAULT_INIT_PROMPT
+            DEFAULT_STRUCTURED_INIT_PROMPT if self.response_format else DEFAULT_INIT_PROMPT
         )
 
         return await aresolve_prompt(
@@ -246,9 +240,7 @@ class _Extractor(Generic[ContextT]):
 
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_REFINE_PROMPT
-            if self.response_format
-            else DEFAULT_REFINE_PROMPT
+            DEFAULT_STRUCTURED_REFINE_PROMPT if self.response_format else DEFAULT_REFINE_PROMPT
         )
 
         return resolve_prompt(
@@ -283,9 +275,7 @@ class _Extractor(Generic[ContextT]):
 
         # Choose default prompt based on structured output format
         default_prompt = (
-            DEFAULT_STRUCTURED_REFINE_PROMPT
-            if self.response_format
-            else DEFAULT_REFINE_PROMPT
+            DEFAULT_STRUCTURED_REFINE_PROMPT if self.response_format else DEFAULT_REFINE_PROMPT
         )
 
         return await aresolve_prompt(
@@ -340,16 +330,12 @@ class _Extractor(Generic[ContextT]):
         if "result" not in state or state["result"] == "":
             # Initial processing
             prompt = await self._aget_initial_prompt(state, runtime)
-            response = cast(
-                "AIMessage", await self.model.ainvoke(prompt, config=config)
-            )
+            response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
             result = response if self.response_format else response.text()
             return {"result": result}
         # Refinement
         prompt = await self._aget_refine_prompt(state, runtime)
-        response = cast(
-            "AIMessage", await self.model.ainvoke(prompt, config=config)
-        )
+        response = cast("AIMessage", await self.model.ainvoke(prompt, config=config))
         result = response if self.response_format else response.text()
         return {"result": result}
 
@@ -388,7 +374,7 @@ def create_stuff_documents_chain(
         Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
     ] = None,
     context_schema: type[ContextT] | None = None,
-    response_format: Optional[type[BaseModel]] = None,
+    response_format: type[BaseModel] | None = None,
 ) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
     """Create a stuff documents chain for processing documents.
 
@@ -5,9 +5,8 @@ from importlib import util
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     Literal,
-    Optional,
+    TypeAlias,
     Union,
     cast,
     overload,
@@ -16,10 +15,10 @@ from typing import (
 from langchain_core.language_models import BaseChatModel, LanguageModelInput
 from langchain_core.messages import AnyMessage, BaseMessage
 from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
-from typing_extensions import TypeAlias, override
+from typing_extensions import override
 
 if TYPE_CHECKING:
-    from collections.abc import AsyncIterator, Iterator, Sequence
+    from collections.abc import AsyncIterator, Callable, Iterator, Sequence
 
     from langchain_core.runnables.schema import StreamEvent
     from langchain_core.tools import BaseTool
@@ -31,9 +30,9 @@ if TYPE_CHECKING:
 def init_chat_model(
     model: str,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     configurable_fields: Literal[None] = None,
-    config_prefix: Optional[str] = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> BaseChatModel: ...
 
@@ -42,20 +41,20 @@ def init_chat_model(
 def init_chat_model(
     model: Literal[None] = None,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     configurable_fields: Literal[None] = None,
-    config_prefix: Optional[str] = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> _ConfigurableModel: ...
 
 
 @overload
 def init_chat_model(
-    model: Optional[str] = None,
+    model: str | None = None,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
-    config_prefix: Optional[str] = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> _ConfigurableModel: ...
 
@@ -64,13 +63,11 @@ def init_chat_model(
 # name to the supported list in the docstring below. Do *not* change the order of the
 # existing providers.
 def init_chat_model(
-    model: Optional[str] = None,
+    model: str | None = None,
     *,
-    model_provider: Optional[str] = None,
-    configurable_fields: Optional[
-        Union[Literal["any"], list[str], tuple[str, ...]]
-    ] = None,
-    config_prefix: Optional[str] = None,
+    model_provider: str | None = None,
+    configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] | None = None,
+    config_prefix: str | None = None,
     **kwargs: Any,
 ) -> Union[BaseChatModel, _ConfigurableModel]:
     """Initialize a ChatModel from the model name and provider.
@@ -326,7 +323,7 @@ def init_chat_model(
 def _init_chat_model_helper(
     model: str,
     *,
-    model_provider: Optional[str] = None,
+    model_provider: str | None = None,
     **kwargs: Any,
 ) -> BaseChatModel:
     model, model_provider = _parse_model(model, model_provider)
@@ -446,9 +443,7 @@ def _init_chat_model_helper(
 
         return ChatPerplexity(model=model, **kwargs)
     supported = ", ".join(_SUPPORTED_PROVIDERS)
-    msg = (
-        f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
-    )
+    msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
     raise ValueError(msg)
 
 
@@ -476,7 +471,7 @@ _SUPPORTED_PROVIDERS = {
 }
 
 
-def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
+def _attempt_infer_model_provider(model_name: str) -> str | None:
     if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
         return "openai"
     if model_name.startswith("claude"):
@@ -500,31 +495,24 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
     return None
 
 
-def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
-    if (
-        not model_provider
-        and ":" in model
-        and model.split(":")[0] in _SUPPORTED_PROVIDERS
-    ):
+def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
+    if not model_provider and ":" in model and model.split(":")[0] in _SUPPORTED_PROVIDERS:
         model_provider = model.split(":")[0]
         model = ":".join(model.split(":")[1:])
     model_provider = model_provider or _attempt_infer_model_provider(model)
     if not model_provider:
         msg = (
-            f"Unable to infer model provider for {model=}, please specify "
-            f"model_provider directly."
+            f"Unable to infer model provider for {model=}, please specify model_provider directly."
         )
         raise ValueError(msg)
     model_provider = model_provider.replace("-", "_").lower()
     return model, model_provider
 
 
-def _check_pkg(pkg: str, *, pkg_kebab: Optional[str] = None) -> None:
+def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
     if not util.find_spec(pkg):
         pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
-        msg = (
-            f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
-        )
+        msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
         raise ImportError(msg)
 
 
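For context, the _parse_model change above keeps the same behaviour: when no provider is passed, it is inferred from a "provider:model" prefix. A stripped-down sketch of that convention (simplified, not the actual helper, with a made-up provider set):

SUPPORTED = {"openai", "anthropic", "groq"}  # illustrative subset


def split_provider(model: str, provider: str | None = None) -> tuple[str | None, str]:
    # "openai:gpt-4o" -> ("openai", "gpt-4o"); bare names pass through unchanged.
    if not provider and ":" in model and model.split(":")[0] in SUPPORTED:
        provider, _, model = model.partition(":")
    return provider, model


assert split_provider("openai:gpt-4o") == ("openai", "gpt-4o")
assert split_provider("claude-3-sonnet") == (None, "claude-3-sonnet")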
@@ -539,16 +527,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def __init__(
         self,
         *,
-        default_config: Optional[dict] = None,
+        default_config: dict | None = None,
         configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
         config_prefix: str = "",
         queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
     ) -> None:
         self._default_config: dict = default_config or {}
         self._configurable_fields: Union[Literal["any"], list[str]] = (
-            configurable_fields
-            if configurable_fields == "any"
-            else list(configurable_fields)
+            configurable_fields if configurable_fields == "any" else list(configurable_fields)
         )
         self._config_prefix = (
             config_prefix + "_"
@@ -589,14 +575,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             msg += "."
         raise AttributeError(msg)
 
-    def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
+    def _model(self, config: RunnableConfig | None = None) -> Runnable:
         params = {**self._default_config, **self._model_params(config)}
         model = _init_chat_model_helper(**params)
         for name, args, kwargs in self._queued_declarative_operations:
             model = getattr(model, name)(*args, **kwargs)
         return model
 
-    def _model_params(self, config: Optional[RunnableConfig]) -> dict:
+    def _model_params(self, config: RunnableConfig | None) -> dict:
         config = ensure_config(config)
         model_params = {
             _remove_prefix(k, self._config_prefix): v
@@ -604,14 +590,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             if k.startswith(self._config_prefix)
         }
         if self._configurable_fields != "any":
-            model_params = {
-                k: v for k, v in model_params.items() if k in self._configurable_fields
-            }
+            model_params = {k: v for k, v in model_params.items() if k in self._configurable_fields}
         return model_params
 
     def with_config(
         self,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         **kwargs: Any,
     ) -> _ConfigurableModel:
         """Bind config to a Runnable, returning a new Runnable."""
@@ -662,7 +646,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def invoke(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         **kwargs: Any,
     ) -> Any:
         return self._model(config).invoke(input, config=config, **kwargs)
@@ -671,7 +655,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def ainvoke(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         **kwargs: Any,
     ) -> Any:
         return await self._model(config).ainvoke(input, config=config, **kwargs)
@@ -680,8 +664,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def stream(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> Iterator[Any]:
         yield from self._model(config).stream(input, config=config, **kwargs)
 
@@ -689,8 +673,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def astream(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> AsyncIterator[Any]:
         async for x in self._model(config).astream(input, config=config, **kwargs):
             yield x
@@ -698,10 +682,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch(
         self,
         inputs: list[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
-        **kwargs: Optional[Any],
+        **kwargs: Any | None,
     ) -> list[Any]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
@@ -726,10 +710,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch(
        self,
         inputs: list[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
-        **kwargs: Optional[Any],
+        **kwargs: Any | None,
     ) -> list[Any]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
@@ -754,7 +738,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -783,7 +767,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
+        config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -817,8 +801,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def transform(
         self,
         input: Iterator[LanguageModelInput],
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> Iterator[Any]:
         yield from self._model(config).transform(input, config=config, **kwargs)
 
@@ -826,8 +810,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def atransform(
         self,
         input: AsyncIterator[LanguageModelInput],
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        config: RunnableConfig | None = None,
+        **kwargs: Any | None,
     ) -> AsyncIterator[Any]:
         async for x in self._model(config).atransform(input, config=config, **kwargs):
             yield x
@@ -836,16 +820,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def astream_log(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         diff: Literal[True] = True,
         with_streamed_output_list: bool = True,
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[RunLogPatch]: ...
 
@@ -853,16 +837,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def astream_log(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         diff: Literal[False],
         with_streamed_output_list: bool = True,
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[RunLog]: ...
 
@@ -870,16 +854,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def astream_log(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         diff: bool = True,
         with_streamed_output_list: bool = True,
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
         async for x in self._model(config).astream_log(  # type: ignore[call-overload, misc]
@@ -901,15 +885,15 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def astream_events(
         self,
         input: Any,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
         version: Literal["v1", "v2"] = "v2",
-        include_names: Optional[Sequence[str]] = None,
-        include_types: Optional[Sequence[str]] = None,
-        include_tags: Optional[Sequence[str]] = None,
-        exclude_names: Optional[Sequence[str]] = None,
-        exclude_types: Optional[Sequence[str]] = None,
-        exclude_tags: Optional[Sequence[str]] = None,
+        include_names: Sequence[str] | None = None,
+        include_types: Sequence[str] | None = None,
+        include_tags: Sequence[str] | None = None,
+        exclude_names: Sequence[str] | None = None,
+        exclude_types: Sequence[str] | None = None,
+        exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[StreamEvent]:
         async for x in self._model(config).astream_events(
@@ -1,6 +1,6 @@
 import functools
 from importlib import util
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.runnables import Runnable
@@ -19,9 +19,7 @@ _SUPPORTED_PROVIDERS = {
 
 def _get_provider_list() -> str:
     """Get formatted list of providers and their packages."""
-    return "\n".join(
-        f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
-    )
+    return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items())
 
 
 def _parse_model_string(model_name: str) -> tuple[str, str]:
@@ -82,7 +80,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
 def _infer_model_and_provider(
     model: str,
     *,
-    provider: Optional[str] = None,
+    provider: str | None = None,
 ) -> tuple[str, str]:
     if not model.strip():
         msg = "Model name cannot be empty"
@@ -117,17 +115,14 @@ def _infer_model_and_provider(
 def _check_pkg(pkg: str) -> None:
     """Check if a package is installed."""
     if not util.find_spec(pkg):
-        msg = (
-            f"Could not import {pkg} python package. "
-            f"Please install it with `pip install {pkg}`"
-        )
+        msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
         raise ImportError(msg)
 
 
 def init_embeddings(
     model: str,
     *,
-    provider: Optional[str] = None,
+    provider: str | None = None,
     **kwargs: Any,
 ) -> Union[Embeddings, Runnable[Any, list[float]]]:
     """Initialize an embeddings model from a model name and optional provider.
@@ -182,9 +177,7 @@ def init_embeddings(
     """
     if not model:
         providers = _SUPPORTED_PROVIDERS.keys()
-        msg = (
-            f"Must specify model name. Supported providers are: {', '.join(providers)}"
-        )
+        msg = f"Must specify model name. Supported providers are: {', '.join(providers)}"
        raise ValueError(msg)
 
     provider, model_name = _infer_model_and_provider(model, provider=provider)
@@ -13,7 +13,7 @@ import hashlib
 import json
 import uuid
 import warnings
-from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast
+from typing import TYPE_CHECKING, Literal, Union, cast
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils.iter import batch_iterate
@@ -21,7 +21,7 @@ from langchain_core.utils.iter import batch_iterate
 from langchain.storage.encoder_backed import EncoderBackedStore
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable, Sequence
 
     from langchain_core.stores import BaseStore, ByteStore
 
@@ -147,8 +147,8 @@ class CacheBackedEmbeddings(Embeddings):
         underlying_embeddings: Embeddings,
         document_embedding_store: BaseStore[str, list[float]],
         *,
-        batch_size: Optional[int] = None,
-        query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
+        batch_size: int | None = None,
+        query_embedding_store: BaseStore[str, list[float]] | None = None,
     ) -> None:
         """Initialize the embedder.
 
@@ -181,17 +181,15 @@ class CacheBackedEmbeddings(Embeddings):
         vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
             texts,
         )
-        all_missing_indices: list[int] = [
-            i for i, vector in enumerate(vectors) if vector is None
-        ]
+        all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
 
         for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
             missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
             self.document_embedding_store.mset(
-                list(zip(missing_texts, missing_vectors)),
+                list(zip(missing_texts, missing_vectors, strict=False)),
             )
-            for index, updated_vector in zip(missing_indices, missing_vectors):
+            for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
                 vectors[index] = updated_vector
 
         return cast(
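The strict= keyword used above was added to zip() in Python 3.10; strict=False preserves the old behaviour of stopping at the shortest iterable, while strict=True raises on a length mismatch. A quick illustration with toy data:

texts = ["a", "b", "c"]
vectors = [[1.0], [2.0]]

# strict=False: stops after the shorter input, as zip() always did.
assert list(zip(texts, vectors, strict=False)) == [("a", [1.0]), ("b", [2.0])]

# strict=True raises ValueError because the lengths differ.
try:
    list(zip(texts, vectors, strict=True))
except ValueError:
    pass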
@@ -212,12 +210,8 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             A list of embeddings for the given texts.
         """
-        vectors: list[
-            Union[list[float], None]
-        ] = await self.document_embedding_store.amget(texts)
-        all_missing_indices: list[int] = [
-            i for i, vector in enumerate(vectors) if vector is None
-        ]
+        vectors: list[Union[list[float], None]] = await self.document_embedding_store.amget(texts)
+        all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
 
         # batch_iterate supports None batch_size which returns all elements at once
         # as a single batch.
@@ -227,9 +221,9 @@ class CacheBackedEmbeddings(Embeddings):
                 missing_texts,
             )
             await self.document_embedding_store.amset(
-                list(zip(missing_texts, missing_vectors)),
+                list(zip(missing_texts, missing_vectors, strict=False)),
            )
-            for index, updated_vector in zip(missing_indices, missing_vectors):
+            for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
                 vectors[index] = updated_vector
 
         return cast(
@@ -290,7 +284,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_cache: ByteStore,
         *,
         namespace: str = "",
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
         query_embedding_cache: Union[bool, ByteStore] = False,
         key_encoder: Union[
             Callable[[str], str],
@@ -1,8 +1,6 @@
-from collections.abc import AsyncIterator, Iterator, Sequence
+from collections.abc import AsyncIterator, Callable, Iterator, Sequence
 from typing import (
     Any,
-    Callable,
-    Optional,
     TypeVar,
     Union,
 )
@@ -62,37 +60,29 @@ class EncoderBackedStore(BaseStore[K, V]):
         self.value_serializer = value_serializer
         self.value_deserializer = value_deserializer
 
-    def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
+    def mget(self, keys: Sequence[K]) -> list[V | None]:
         """Get the values associated with the given keys."""
         encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
         values = self.store.mget(encoded_keys)
-        return [
-            self.value_deserializer(value) if value is not None else value
-            for value in values
-        ]
+        return [self.value_deserializer(value) if value is not None else value for value in values]
 
-    async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
+    async def amget(self, keys: Sequence[K]) -> list[V | None]:
         """Get the values associated with the given keys."""
         encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
         values = await self.store.amget(encoded_keys)
-        return [
-            self.value_deserializer(value) if value is not None else value
-            for value in values
-        ]
+        return [self.value_deserializer(value) if value is not None else value for value in values]
 
     def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
         """Set the values for the given keys."""
         encoded_pairs = [
-            (self.key_encoder(key), self.value_serializer(value))
-            for key, value in key_value_pairs
+            (self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
         ]
         self.store.mset(encoded_pairs)
 
     async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
         """Set the values for the given keys."""
         encoded_pairs = [
-            (self.key_encoder(key), self.value_serializer(value))
-            for key, value in key_value_pairs
+            (self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
         ]
         await self.store.amset(encoded_pairs)
 
@@ -109,7 +99,7 @@ class EncoderBackedStore(BaseStore[K, V]):
     def yield_keys(
         self,
         *,
-        prefix: Optional[str] = None,
+        prefix: str | None = None,
     ) -> Union[Iterator[K], Iterator[str]]:
         """Get an iterator over keys that match the given prefix."""
         # For the time being this does not return K, but str
@@ -119,7 +109,7 @@ class EncoderBackedStore(BaseStore[K, V]):
     async def ayield_keys(
         self,
         *,
-        prefix: Optional[str] = None,
+        prefix: str | None = None,
     ) -> Union[AsyncIterator[K], AsyncIterator[str]]:
         """Get an iterator over keys that match the given prefix."""
         # For the time being this does not return K, but str
@@ -5,7 +5,7 @@ build-backend = "pdm.backend"
 [project]
 authors = []
 license = { text = "MIT" }
-requires-python = ">=3.9, <4.0"
+requires-python = ">=3.10"
 dependencies = [
     "langchain-core<1.0.0,>=0.3.66",
     "langchain-text-splitters<1.0.0,>=0.3.8",
@@ -46,23 +46,22 @@ repository = "https://github.com/langchain-ai/langchain"
 [dependency-groups]
 test = [
     "pytest<9,>=8",
-    "pytest-cov<5.0.0,>=4.0.0",
-    "pytest-watcher<1.0.0,>=0.2.6",
-    "pytest-asyncio<1.0.0,>=0.23.2",
-    "pytest-socket<1.0.0,>=0.6.0",
-    "syrupy<5.0.0,>=4.0.2",
-    "pytest-xdist<4.0.0,>=3.6.1",
-    "blockbuster<1.6,>=1.5.18",
+    "pytest-cov>=4.0.0",
+    "pytest-watcher>=0.2.6",
+    "pytest-asyncio>=0.23.2",
+    "pytest-socket>=0.6.0",
+    "syrupy>=4.0.2",
+    "pytest-xdist>=3.6.1",
+    "blockbuster>=1.5.18",
     "langchain-tests",
-    "langchain-core",
     "langchain-text-splitters",
     "langchain-openai",
     "toml>=0.10.2",
 ]
 codespell = ["codespell<3.0.0,>=2.2.0"]
 lint = [
-    "ruff<0.13,>=0.12.2",
-    "mypy<1.16,>=1.15",
+    "ruff>=0.12.2",
+    "mypy>=1.15",
 ]
 typing = [
     "types-toml>=0.10.8.20240310",
@@ -70,11 +69,10 @@ typing = [
 
 test_integration = [
     "vcrpy>=7.0",
-    "urllib3<2; python_version < \"3.10\"",
-    "wrapt<2.0.0,>=1.15.0",
-    "python-dotenv<2.0.0,>=1.0.0",
-    "cassio<1.0.0,>=0.1.0",
-    "langchainhub<1.0.0,>=0.1.16",
+    "wrapt>=1.15.0",
+    "python-dotenv>=1.0.0",
+    "cassio>=0.1.0",
+    "langchainhub>=0.1.16",
     "langchain-core",
     "langchain-text-splitters",
 ]
@@ -86,8 +84,9 @@ langchain-text-splitters = { path = "../text-splitters", editable = true }
 langchain-openai = { path = "../partners/openai", editable = true }
 
 [tool.ruff]
-target-version = "py39"
+target-version = "py310"
 exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
+line-length = 100
 
 [tool.mypy]
 strict = "True"
@@ -31,9 +31,7 @@ async def test_init_chat_model_chain() -> None:
     chain = prompt | model_with_config
     output = chain.invoke({"input": "bar"})
     assert isinstance(output, AIMessage)
-    events = [
-        event async for event in chain.astream_events({"input": "bar"}, version="v2")
-    ]
+    events = [event async for event in chain.astream_events({"input": "bar"}, version="v2")]
     assert events
 
 
@@ -1,5 +1,5 @@
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 from unittest import mock
 
 import pytest
@@ -38,7 +38,7 @@ def test_all_imports() -> None:
         ("mixtral-8x7b-32768", "groq"),
     ],
 )
-def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
+def test_init_chat_model(model_name: str, model_provider: str | None) -> None:
     llm1: BaseChatModel = init_chat_model(
         model_name,
         model_provider=model_provider,
@@ -222,7 +222,7 @@ def test_configurable_with_default() -> None:
         config={"configurable": {"my_model_model": "claude-3-sonnet-20240229"}}
     )
 
-    """  # noqa: E501
+    """
     model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
     for method in (
         "invoke",
@@ -53,9 +53,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
     )
 
 
-def pytest_collection_modifyitems(
-    config: pytest.Config, items: Sequence[pytest.Function]
-) -> None:
+def pytest_collection_modifyitems(config: pytest.Config, items: Sequence[pytest.Function]) -> None:
     """Add implementations for handling custom markers.
 
     At the moment, this adds support for a custom `requires` marker.
@@ -113,9 +113,7 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None
     vectors = await cache_embeddings.aembed_documents(texts)
     expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
     assert vectors == expected_vectors
-    keys = [
-        key async for key in cache_embeddings.document_embedding_store.ayield_keys()
-    ]
+    keys = [key async for key in cache_embeddings.document_embedding_store.ayield_keys()]
     assert len(keys) == 4
     # UUID is expected to be the same for the same text
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
@@ -128,10 +126,7 @@ async def test_aembed_documents_batch(
     texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
     with contextlib.suppress(ValueError):
         await cache_embeddings_batch.aembed_documents(texts)
-    keys = [
-        key
-        async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
-    ]
+    keys = [key async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()]
    # only the first batch of three embeddings should exist
     assert len(keys) == 3
     # UUID is expected to be the same for the same text
@@ -13,9 +13,7 @@ def test_import_all() -> None:
     library_code = PKG_ROOT / "langchain"
     for path in library_code.rglob("*.py"):
         # Calculate the relative path to the module
-        module_name = (
-            path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
-        )
+        module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
         if module_name.endswith("__init__"):
             # Without init
             module_name = module_name.rsplit(".", 1)[0]
@@ -39,9 +37,7 @@ def test_import_all_using_dir() -> None:
     library_code = PKG_ROOT / "langchain"
     for path in library_code.rglob("*.py"):
         # Calculate the relative path to the module
-        module_name = (
-            path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
-        )
+        module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
         if module_name.endswith("__init__"):
             # Without init
             module_name = module_name.rsplit(".", 1)[0]
libs/langchain_v1/uv.lock (generated): 2415 changed lines; the file diff is suppressed because it is too large.