mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
core: Add ruff rules PYI (#29335)
See https://docs.astral.sh/ruff/rules/#flake8-pyi-pyi --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
d8e3b7667f
commit
6650b94627
@ -3,7 +3,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
@ -879,9 +881,6 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="BaseCallbackManager")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseCallbackManager(CallbackManagerMixin):
|
class BaseCallbackManager(CallbackManagerMixin):
|
||||||
"""Base callback manager for LangChain."""
|
"""Base callback manager for LangChain."""
|
||||||
|
|
||||||
@ -920,7 +919,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
self.metadata = metadata or {}
|
self.metadata = metadata or {}
|
||||||
self.inheritable_metadata = inheritable_metadata or {}
|
self.inheritable_metadata = inheritable_metadata or {}
|
||||||
|
|
||||||
def copy(self: T) -> T:
|
def copy(self) -> Self:
|
||||||
"""Copy the callback manager."""
|
"""Copy the callback manager."""
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
handlers=self.handlers.copy(),
|
handlers=self.handlers.copy(),
|
||||||
@ -932,7 +931,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
inheritable_metadata=self.inheritable_metadata.copy(),
|
inheritable_metadata=self.inheritable_metadata.copy(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def merge(self: T, other: BaseCallbackManager) -> T:
|
def merge(self, other: BaseCallbackManager) -> Self:
|
||||||
"""Merge the callback manager with another callback manager.
|
"""Merge the callback manager with another callback manager.
|
||||||
|
|
||||||
May be overwritten in subclasses. Primarily used internally
|
May be overwritten in subclasses. Primarily used internally
|
||||||
|
@ -22,6 +22,7 @@ from typing import (
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langsmith.run_helpers import get_tracing_context
|
from langsmith.run_helpers import get_tracing_context
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from langchain_core.callbacks.base import (
|
from langchain_core.callbacks.base import (
|
||||||
BaseCallbackHandler,
|
BaseCallbackHandler,
|
||||||
@ -444,9 +445,6 @@ async def ahandle_event(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
BRM = TypeVar("BRM", bound="BaseRunManager")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseRunManager(RunManagerMixin):
|
class BaseRunManager(RunManagerMixin):
|
||||||
"""Base class for run manager (a bound callback manager)."""
|
"""Base class for run manager (a bound callback manager)."""
|
||||||
|
|
||||||
@ -489,7 +487,7 @@ class BaseRunManager(RunManagerMixin):
|
|||||||
self.inheritable_metadata = inheritable_metadata or {}
|
self.inheritable_metadata = inheritable_metadata or {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_noop_manager(cls: type[BRM]) -> BRM:
|
def get_noop_manager(cls) -> Self:
|
||||||
"""Return a manager that doesn't perform any operations.
|
"""Return a manager that doesn't perform any operations.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -1258,7 +1258,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
||||||
],
|
],
|
||||||
*,
|
*,
|
||||||
tool_choice: Optional[Union[str, Literal["any"]]] = None,
|
tool_choice: Optional[Union[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
"""Bind tools to the model.
|
"""Bind tools to the model.
|
||||||
|
@ -22,7 +22,7 @@ from pydantic import (
|
|||||||
SkipValidation,
|
SkipValidation,
|
||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
from typing_extensions import override
|
from typing_extensions import Self, override
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.load import Serializable
|
from langchain_core.load import Serializable
|
||||||
@ -304,12 +304,12 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_template(
|
def from_template(
|
||||||
cls: type[MessagePromptTemplateT],
|
cls,
|
||||||
template: str,
|
template: str,
|
||||||
template_format: PromptTemplateFormat = "f-string",
|
template_format: PromptTemplateFormat = "f-string",
|
||||||
partial_variables: Optional[dict[str, Any]] = None,
|
partial_variables: Optional[dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MessagePromptTemplateT:
|
) -> Self:
|
||||||
"""Create a class from a string template.
|
"""Create a class from a string template.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -335,11 +335,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_template_file(
|
def from_template_file(
|
||||||
cls: type[MessagePromptTemplateT],
|
cls,
|
||||||
template_file: Union[str, Path],
|
template_file: Union[str, Path],
|
||||||
input_variables: list[str],
|
input_variables: list[str],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MessagePromptTemplateT:
|
) -> Self:
|
||||||
"""Create a class from a template file.
|
"""Create a class from a template file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -456,11 +456,6 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_StringImageMessagePromptTemplateT = TypeVar(
|
|
||||||
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _TextTemplateParam(TypedDict, total=False):
|
class _TextTemplateParam(TypedDict, total=False):
|
||||||
text: Union[str, dict]
|
text: Union[str, dict]
|
||||||
|
|
||||||
@ -483,13 +478,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_template(
|
def from_template(
|
||||||
cls: type[_StringImageMessagePromptTemplateT],
|
cls: type[Self],
|
||||||
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
||||||
template_format: PromptTemplateFormat = "f-string",
|
template_format: PromptTemplateFormat = "f-string",
|
||||||
*,
|
*,
|
||||||
partial_variables: Optional[dict[str, Any]] = None,
|
partial_variables: Optional[dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> _StringImageMessagePromptTemplateT:
|
) -> Self:
|
||||||
"""Create a class from a string template.
|
"""Create a class from a string template.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -576,11 +571,11 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_template_file(
|
def from_template_file(
|
||||||
cls: type[_StringImageMessagePromptTemplateT],
|
cls: type[Self],
|
||||||
template_file: Union[str, Path],
|
template_file: Union[str, Path],
|
||||||
input_variables: list[str],
|
input_variables: list[str],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> _StringImageMessagePromptTemplateT:
|
) -> Self:
|
||||||
"""Create a class from a template file.
|
"""Create a class from a template file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -4200,7 +4200,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if isinstance(other, RunnableGenerator):
|
if isinstance(other, RunnableGenerator):
|
||||||
if hasattr(self, "_transform") and hasattr(other, "_transform"):
|
if hasattr(self, "_transform") and hasattr(other, "_transform"):
|
||||||
return self._transform == other._transform
|
return self._transform == other._transform
|
||||||
@ -4582,7 +4582,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return graph
|
return graph
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if isinstance(other, RunnableLambda):
|
if isinstance(other, RunnableLambda):
|
||||||
if hasattr(self, "func") and hasattr(other, "func"):
|
if hasattr(self, "func") and hasattr(other, "func"):
|
||||||
return self.func == other.func
|
return self.func == other.func
|
||||||
@ -5880,22 +5880,24 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
|
|||||||
|
|
||||||
|
|
||||||
class _RunnableCallableSync(Protocol[Input, Output]):
|
class _RunnableCallableSync(Protocol[Input, Output]):
|
||||||
def __call__(self, __in: Input, *, config: RunnableConfig) -> Output: ...
|
def __call__(self, _in: Input, /, *, config: RunnableConfig) -> Output: ...
|
||||||
|
|
||||||
|
|
||||||
class _RunnableCallableAsync(Protocol[Input, Output]):
|
class _RunnableCallableAsync(Protocol[Input, Output]):
|
||||||
def __call__(self, __in: Input, *, config: RunnableConfig) -> Awaitable[Output]: ...
|
def __call__(
|
||||||
|
self, _in: Input, /, *, config: RunnableConfig
|
||||||
|
) -> Awaitable[Output]: ...
|
||||||
|
|
||||||
|
|
||||||
class _RunnableCallableIterator(Protocol[Input, Output]):
|
class _RunnableCallableIterator(Protocol[Input, Output]):
|
||||||
def __call__(
|
def __call__(
|
||||||
self, __in: Iterator[Input], *, config: RunnableConfig
|
self, _in: Iterator[Input], /, *, config: RunnableConfig
|
||||||
) -> Iterator[Output]: ...
|
) -> Iterator[Output]: ...
|
||||||
|
|
||||||
|
|
||||||
class _RunnableCallableAsyncIterator(Protocol[Input, Output]):
|
class _RunnableCallableAsyncIterator(Protocol[Input, Output]):
|
||||||
def __call__(
|
def __call__(
|
||||||
self, __in: AsyncIterator[Input], *, config: RunnableConfig
|
self, _in: AsyncIterator[Input], /, *, config: RunnableConfig
|
||||||
) -> AsyncIterator[Output]: ...
|
) -> AsyncIterator[Output]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@ -515,7 +515,7 @@ _T_contra = TypeVar("_T_contra", contravariant=True)
|
|||||||
class SupportsAdd(Protocol[_T_contra, _T_co]):
|
class SupportsAdd(Protocol[_T_contra, _T_co]):
|
||||||
"""Protocol for objects that support addition."""
|
"""Protocol for objects that support addition."""
|
||||||
|
|
||||||
def __add__(self, __x: _T_contra) -> _T_co:
|
def __add__(self, x: _T_contra, /) -> _T_co:
|
||||||
"""Add the object to another object."""
|
"""Add the object to another object."""
|
||||||
|
|
||||||
|
|
||||||
|
@ -88,7 +88,12 @@ class NoLock:
|
|||||||
async def __aenter__(self) -> None:
|
async def __aenter__(self) -> None:
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type[BaseException]],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> bool:
|
||||||
"""Exception not handled."""
|
"""Exception not handled."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -237,7 +242,12 @@ class Tee(Generic[T]):
|
|||||||
"""Return the tee instance."""
|
"""Return the tee instance."""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type[BaseException]],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> bool:
|
||||||
"""Close all child iterators."""
|
"""Close all child iterators."""
|
||||||
await self.aclose()
|
await self.aclose()
|
||||||
return False
|
return False
|
||||||
|
@ -4,6 +4,7 @@ from collections import deque
|
|||||||
from collections.abc import Generator, Iterable, Iterator
|
from collections.abc import Generator, Iterable, Iterator
|
||||||
from contextlib import AbstractContextManager
|
from contextlib import AbstractContextManager
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Generic,
|
Generic,
|
||||||
@ -24,7 +25,12 @@ class NoLock:
|
|||||||
def __enter__(self) -> None:
|
def __enter__(self) -> None:
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
|
|
||||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type[BaseException]],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> Literal[False]:
|
||||||
"""Exception not handled."""
|
"""Exception not handled."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -173,7 +179,12 @@ class Tee(Generic[T]):
|
|||||||
"""Return Tee instance."""
|
"""Return Tee instance."""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type[BaseException]],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> Literal[False]:
|
||||||
"""Close all child iterators."""
|
"""Close all child iterators."""
|
||||||
self.close()
|
self.close()
|
||||||
return False
|
return False
|
||||||
|
@ -377,12 +377,7 @@ if IS_PYDANTIC_V2:
|
|||||||
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
|
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
|
||||||
|
|
||||||
def get_fields(
|
def get_fields(
|
||||||
model: Union[
|
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
|
||||||
BaseModelV2,
|
|
||||||
BaseModelV1,
|
|
||||||
type[BaseModelV2],
|
|
||||||
type[BaseModelV1],
|
|
||||||
],
|
|
||||||
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
|
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
|
||||||
"""Get the field names of a Pydantic model."""
|
"""Get the field names of a Pydantic model."""
|
||||||
if hasattr(model, "model_fields"):
|
if hasattr(model, "model_fields"):
|
||||||
@ -491,19 +486,21 @@ def _create_root_model_cached(
|
|||||||
|
|
||||||
@lru_cache(maxsize=256)
|
@lru_cache(maxsize=256)
|
||||||
def _create_model_cached(
|
def _create_model_cached(
|
||||||
__model_name: str,
|
model_name: str,
|
||||||
|
/,
|
||||||
**field_definitions: Any,
|
**field_definitions: Any,
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
return _create_model_base(
|
return _create_model_base(
|
||||||
__model_name,
|
model_name,
|
||||||
__config__=_SchemaConfig,
|
__config__=_SchemaConfig,
|
||||||
**_remap_field_definitions(field_definitions),
|
**_remap_field_definitions(field_definitions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_model(
|
def create_model(
|
||||||
__model_name: str,
|
model_name: str,
|
||||||
__module_name: Optional[str] = None,
|
module_name: Optional[str] = None,
|
||||||
|
/,
|
||||||
**field_definitions: Any,
|
**field_definitions: Any,
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create a pydantic model with the given field definitions.
|
"""Create a pydantic model with the given field definitions.
|
||||||
@ -511,8 +508,8 @@ def create_model(
|
|||||||
Please use create_model_v2 instead of this function.
|
Please use create_model_v2 instead of this function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
__model_name: The name of the model.
|
model_name: The name of the model.
|
||||||
__module_name: The name of the module where the model is defined.
|
module_name: The name of the module where the model is defined.
|
||||||
This is used by Pydantic to resolve any forward references.
|
This is used by Pydantic to resolve any forward references.
|
||||||
**field_definitions: The field definitions for the model.
|
**field_definitions: The field definitions for the model.
|
||||||
|
|
||||||
@ -524,8 +521,8 @@ def create_model(
|
|||||||
kwargs["root"] = field_definitions.pop("__root__")
|
kwargs["root"] = field_definitions.pop("__root__")
|
||||||
|
|
||||||
return create_model_v2(
|
return create_model_v2(
|
||||||
__model_name,
|
model_name,
|
||||||
module_name=__module_name,
|
module_name=module_name,
|
||||||
field_definitions=field_definitions,
|
field_definitions=field_definitions,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
@ -36,6 +36,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import ConfigDict, Field, model_validator
|
from pydantic import ConfigDict, Field, model_validator
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
||||||
@ -818,11 +819,11 @@ class VectorStore(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_documents(
|
def from_documents(
|
||||||
cls: type[VST],
|
cls,
|
||||||
documents: list[Document],
|
documents: list[Document],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> VST:
|
) -> Self:
|
||||||
"""Return VectorStore initialized from documents and embeddings.
|
"""Return VectorStore initialized from documents and embeddings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -848,11 +849,11 @@ class VectorStore(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def afrom_documents(
|
async def afrom_documents(
|
||||||
cls: type[VST],
|
cls,
|
||||||
documents: list[Document],
|
documents: list[Document],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> VST:
|
) -> Self:
|
||||||
"""Async return VectorStore initialized from documents and embeddings.
|
"""Async return VectorStore initialized from documents and embeddings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -903,14 +904,14 @@ class VectorStore(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def afrom_texts(
|
async def afrom_texts(
|
||||||
cls: type[VST],
|
cls,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[list[dict]] = None,
|
metadatas: Optional[list[dict]] = None,
|
||||||
*,
|
*,
|
||||||
ids: Optional[list[str]] = None,
|
ids: Optional[list[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> VST:
|
) -> Self:
|
||||||
"""Async return VectorStore initialized from texts and embeddings.
|
"""Async return VectorStore initialized from texts and embeddings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -103,7 +103,6 @@ ignore = [
|
|||||||
"FBT002",
|
"FBT002",
|
||||||
"PGH003",
|
"PGH003",
|
||||||
"PLR",
|
"PLR",
|
||||||
"PYI",
|
|
||||||
"RUF",
|
"RUF",
|
||||||
"SLF",
|
"SLF",
|
||||||
]
|
]
|
||||||
|
@ -5294,22 +5294,22 @@ async def test_ainvoke_on_returned_runnable() -> None:
|
|||||||
be runthroughaasync path (issue #13407).
|
be runthroughaasync path (issue #13407).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def idchain_sync(__input: dict) -> bool:
|
def idchain_sync(_input: dict, /) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def idchain_async(__input: dict) -> bool:
|
async def idchain_async(_input: dict, /) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async)
|
idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async)
|
||||||
|
|
||||||
def func(__input: dict) -> Runnable:
|
def func(_input: dict, /) -> Runnable:
|
||||||
return idchain
|
return idchain
|
||||||
|
|
||||||
assert await RunnableLambda(func).ainvoke({})
|
assert await RunnableLambda(func).ainvoke({})
|
||||||
|
|
||||||
|
|
||||||
def test_invoke_stream_passthrough_assign_trace() -> None:
|
def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||||
def idchain_sync(__input: dict) -> bool:
|
def idchain_sync(_input: dict, /) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
||||||
@ -5329,7 +5329,7 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
||||||
def idchain_sync(__input: dict) -> bool:
|
def idchain_sync(_input: dict, /) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
||||||
|
@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
|||||||
class AnyStr(str):
|
class AnyStr(str):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
return isinstance(other, str)
|
return isinstance(other, str)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2369,7 +2369,7 @@ def test_tool_return_output_mixin() -> None:
|
|||||||
def __init__(self, x: int) -> None:
|
def __init__(self, x: int) -> None:
|
||||||
self.x = x
|
self.x = x
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
return isinstance(other, self.__class__) and self.x == other.x
|
return isinstance(other, self.__class__) and self.x == other.x
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
@ -994,12 +994,12 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
|
|||||||
)
|
)
|
||||||
def test_convert_union_type_py_39() -> None:
|
def test_convert_union_type_py_39() -> None:
|
||||||
@tool
|
@tool
|
||||||
def magic_function(input: int | float) -> str: # noqa: FA102
|
def magic_function(input: int | str) -> str: # noqa: FA102
|
||||||
"""Compute a magic function."""
|
"""Compute a magic function."""
|
||||||
|
|
||||||
result = convert_to_openai_function(magic_function)
|
result = convert_to_openai_function(magic_function)
|
||||||
assert result["parameters"]["properties"]["input"] == {
|
assert result["parameters"]["properties"]["input"] == {
|
||||||
"anyOf": [{"type": "integer"}, {"type": "number"}]
|
"anyOf": [{"type": "integer"}, {"type": "string"}]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user