mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 18:24:10 +00:00
core: Add ruff rules ARG (#30732)
See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg
This commit is contained in:
parent
66758599a9
commit
98f0016fc2
@ -124,7 +124,7 @@ def beta(
|
|||||||
_name = _name or obj.__qualname__
|
_name = _name or obj.__qualname__
|
||||||
old_doc = obj.__doc__
|
old_doc = obj.__doc__
|
||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||||
"""Finalize the annotation of a class."""
|
"""Finalize the annotation of a class."""
|
||||||
# Can't set new_doc on some extension objects.
|
# Can't set new_doc on some extension objects.
|
||||||
with contextlib.suppress(AttributeError):
|
with contextlib.suppress(AttributeError):
|
||||||
@ -190,7 +190,7 @@ def beta(
|
|||||||
if _name == "<lambda>":
|
if _name == "<lambda>":
|
||||||
_name = set_name
|
_name = set_name
|
||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: # noqa: ARG001
|
||||||
"""Finalize the property."""
|
"""Finalize the property."""
|
||||||
return _BetaProperty(
|
return _BetaProperty(
|
||||||
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
|
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
|
||||||
|
@ -204,7 +204,7 @@ def deprecated(
|
|||||||
_name = _name or obj.__qualname__
|
_name = _name or obj.__qualname__
|
||||||
old_doc = obj.__doc__
|
old_doc = obj.__doc__
|
||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||||
"""Finalize the deprecation of a class."""
|
"""Finalize the deprecation of a class."""
|
||||||
# Can't set new_doc on some extension objects.
|
# Can't set new_doc on some extension objects.
|
||||||
with contextlib.suppress(AttributeError):
|
with contextlib.suppress(AttributeError):
|
||||||
@ -234,7 +234,7 @@ def deprecated(
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
old_doc = obj.description
|
old_doc = obj.description
|
||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||||
return cast(
|
return cast(
|
||||||
"T",
|
"T",
|
||||||
FieldInfoV1(
|
FieldInfoV1(
|
||||||
@ -255,7 +255,7 @@ def deprecated(
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
old_doc = obj.description
|
old_doc = obj.description
|
||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||||
return cast(
|
return cast(
|
||||||
"T",
|
"T",
|
||||||
FieldInfoV2(
|
FieldInfoV2(
|
||||||
@ -315,7 +315,7 @@ def deprecated(
|
|||||||
if _name == "<lambda>":
|
if _name == "<lambda>":
|
||||||
_name = set_name
|
_name = set_name
|
||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||||
"""Finalize the property."""
|
"""Finalize the property."""
|
||||||
return cast(
|
return cast(
|
||||||
"T",
|
"T",
|
||||||
|
@ -27,6 +27,8 @@ from abc import ABC, abstractmethod
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.outputs import Generation
|
from langchain_core.outputs import Generation
|
||||||
from langchain_core.runnables import run_in_executor
|
from langchain_core.runnables import run_in_executor
|
||||||
|
|
||||||
@ -194,6 +196,7 @@ class InMemoryCache(BaseCache):
|
|||||||
del self._cache[next(iter(self._cache))]
|
del self._cache[next(iter(self._cache))]
|
||||||
self._cache[(prompt, llm_string)] = return_val
|
self._cache[(prompt, llm_string)] = return_val
|
||||||
|
|
||||||
|
@override
|
||||||
def clear(self, **kwargs: Any) -> None:
|
def clear(self, **kwargs: Any) -> None:
|
||||||
"""Clear cache."""
|
"""Clear cache."""
|
||||||
self._cache = {}
|
self._cache = {}
|
||||||
@ -227,6 +230,7 @@ class InMemoryCache(BaseCache):
|
|||||||
"""
|
"""
|
||||||
self.update(prompt, llm_string, return_val)
|
self.update(prompt, llm_string, return_val)
|
||||||
|
|
||||||
|
@override
|
||||||
async def aclear(self, **kwargs: Any) -> None:
|
async def aclear(self, **kwargs: Any) -> None:
|
||||||
"""Async clear cache."""
|
"""Async clear cache."""
|
||||||
self.clear()
|
self.clear()
|
||||||
|
@ -5,6 +5,8 @@ from __future__ import annotations
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Optional, TextIO, cast
|
from typing import TYPE_CHECKING, Any, Optional, TextIO, cast
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
from langchain_core.utils.input import print_text
|
from langchain_core.utils.input import print_text
|
||||||
|
|
||||||
@ -38,6 +40,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Destructor to cleanup when done."""
|
"""Destructor to cleanup when done."""
|
||||||
self.file.close()
|
self.file.close()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -61,6 +64,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
|||||||
file=self.file,
|
file=self.file,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||||
"""Print out that we finished a chain.
|
"""Print out that we finished a chain.
|
||||||
|
|
||||||
@ -70,6 +74,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file)
|
print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_agent_action(
|
def on_agent_action(
|
||||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -83,6 +88,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
print_text(action.log, color=color or self.color, file=self.file)
|
print_text(action.log, color=color or self.color, file=self.file)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: str,
|
||||||
@ -109,6 +115,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
|||||||
if llm_prefix is not None:
|
if llm_prefix is not None:
|
||||||
print_text(f"\n{llm_prefix}", file=self.file)
|
print_text(f"\n{llm_prefix}", file=self.file)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_text(
|
def on_text(
|
||||||
self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any
|
self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -123,6 +130,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
print_text(text, color=color or self.color, end=end, file=self.file)
|
print_text(text, color=color or self.color, end=end, file=self.file)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_agent_finish(
|
def on_agent_finish(
|
||||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -22,7 +22,7 @@ from typing import (
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langsmith.run_helpers import get_tracing_context
|
from langsmith.run_helpers import get_tracing_context
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self, override
|
||||||
|
|
||||||
from langchain_core.callbacks.base import (
|
from langchain_core.callbacks.base import (
|
||||||
BaseCallbackHandler,
|
BaseCallbackHandler,
|
||||||
@ -1401,6 +1401,7 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
inheritable_metadata=self.inheritable_metadata,
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: Optional[dict[str, Any]],
|
serialized: Optional[dict[str, Any]],
|
||||||
@ -1456,6 +1457,7 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
inheritable_metadata=self.inheritable_metadata,
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retriever_start(
|
def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
serialized: Optional[dict[str, Any]],
|
serialized: Optional[dict[str, Any]],
|
||||||
@ -1927,6 +1929,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
inheritable_metadata=self.inheritable_metadata,
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_tool_start(
|
async def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: Optional[dict[str, Any]],
|
serialized: Optional[dict[str, Any]],
|
||||||
@ -2017,6 +2020,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
metadata=self.metadata,
|
metadata=self.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_retriever_start(
|
async def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
serialized: Optional[dict[str, Any]],
|
serialized: Optional[dict[str, Any]],
|
||||||
|
@ -4,6 +4,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||||
from langchain_core.utils import print_text
|
from langchain_core.utils import print_text
|
||||||
|
|
||||||
@ -22,6 +24,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
self.color = color
|
self.color = color
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -41,6 +44,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
name = "<unknown>"
|
name = "<unknown>"
|
||||||
print(f"\n\n\033[1m> Entering new {name} chain...\033[0m") # noqa: T201
|
print(f"\n\n\033[1m> Entering new {name} chain...\033[0m") # noqa: T201
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||||
"""Print out that we finished a chain.
|
"""Print out that we finished a chain.
|
||||||
|
|
||||||
@ -50,6 +54,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
print("\n\033[1m> Finished chain.\033[0m") # noqa: T201
|
print("\n\033[1m> Finished chain.\033[0m") # noqa: T201
|
||||||
|
|
||||||
|
@override
|
||||||
def on_agent_action(
|
def on_agent_action(
|
||||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -62,6 +67,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
print_text(action.log, color=color or self.color)
|
print_text(action.log, color=color or self.color)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: Any,
|
output: Any,
|
||||||
@ -87,6 +93,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
if llm_prefix is not None:
|
if llm_prefix is not None:
|
||||||
print_text(f"\n{llm_prefix}")
|
print_text(f"\n{llm_prefix}")
|
||||||
|
|
||||||
|
@override
|
||||||
def on_text(
|
def on_text(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
@ -104,6 +111,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
print_text(text, color=color or self.color, end=end)
|
print_text(text, color=color or self.color, end=end)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_agent_finish(
|
def on_agent_finish(
|
||||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -5,6 +5,8 @@ from __future__ import annotations
|
|||||||
import sys
|
import sys
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -41,6 +43,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
**kwargs (Any): Additional keyword arguments.
|
**kwargs (Any): Additional keyword arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||||
"""Run on new LLM token. Only available when streaming is enabled.
|
"""Run on new LLM token. Only available when streaming is enabled.
|
||||||
|
|
||||||
|
@ -58,6 +58,7 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return str(self.usage_metadata)
|
return str(self.usage_metadata)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
"""Collect token usage."""
|
"""Collect token usage."""
|
||||||
# Check for usage_metadata (langchain-core >= 0.2.2)
|
# Check for usage_metadata (langchain-core >= 0.2.2)
|
||||||
|
@ -151,7 +151,7 @@ def _get_source_id_assigner(
|
|||||||
) -> Callable[[Document], Union[str, None]]:
|
) -> Callable[[Document], Union[str, None]]:
|
||||||
"""Get the source id from the document."""
|
"""Get the source id from the document."""
|
||||||
if source_id_key is None:
|
if source_id_key is None:
|
||||||
return lambda doc: None
|
return lambda _doc: None
|
||||||
if isinstance(source_id_key, str):
|
if isinstance(source_id_key, str):
|
||||||
return lambda doc: doc.metadata[source_id_key]
|
return lambda doc: doc.metadata[source_id_key]
|
||||||
if callable(source_id_key):
|
if callable(source_id_key):
|
||||||
|
@ -6,6 +6,7 @@ from collections.abc import Sequence
|
|||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core._api import beta
|
from langchain_core._api import beta
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||||
@ -29,6 +30,7 @@ class InMemoryDocumentIndex(DocumentIndex):
|
|||||||
store: dict[str, Document] = Field(default_factory=dict)
|
store: dict[str, Document] = Field(default_factory=dict)
|
||||||
top_k: int = 4
|
top_k: int = 4
|
||||||
|
|
||||||
|
@override
|
||||||
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
|
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
|
||||||
"""Upsert items into the index."""
|
"""Upsert items into the index."""
|
||||||
ok_ids = []
|
ok_ids = []
|
||||||
@ -47,6 +49,7 @@ class InMemoryDocumentIndex(DocumentIndex):
|
|||||||
|
|
||||||
return UpsertResponse(succeeded=ok_ids, failed=[])
|
return UpsertResponse(succeeded=ok_ids, failed=[])
|
||||||
|
|
||||||
|
@override
|
||||||
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
|
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
|
||||||
"""Delete by ID."""
|
"""Delete by ID."""
|
||||||
if ids is None:
|
if ids is None:
|
||||||
@ -64,10 +67,12 @@ class InMemoryDocumentIndex(DocumentIndex):
|
|||||||
succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[]
|
succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
|
def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
|
||||||
"""Get by ids."""
|
"""Get by ids."""
|
||||||
return [self.store[id_] for id_ in ids if id_ in self.store]
|
return [self.store[id_] for id_ in ids if id_ in self.store]
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
|
@ -576,7 +576,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
|
|
||||||
# --- Custom methods ---
|
# --- Custom methods ---
|
||||||
|
|
||||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _get_invocation_params(
|
def _get_invocation_params(
|
||||||
@ -1246,6 +1246,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
"""Return type of chat model."""
|
"""Return type of chat model."""
|
||||||
|
|
||||||
|
@override
|
||||||
def dict(self, **kwargs: Any) -> dict:
|
def dict(self, **kwargs: Any) -> dict:
|
||||||
"""Return a dictionary of the LLM."""
|
"""Return a dictionary of the LLM."""
|
||||||
starter_dict = dict(self._identifying_params)
|
starter_dict = dict(self._identifying_params)
|
||||||
|
@ -171,6 +171,7 @@ class FakeListChatModel(SimpleChatModel):
|
|||||||
class FakeChatModel(SimpleChatModel):
|
class FakeChatModel(SimpleChatModel):
|
||||||
"""Fake Chat Model wrapper for testing purposes."""
|
"""Fake Chat Model wrapper for testing purposes."""
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -180,6 +181,7 @@ class FakeChatModel(SimpleChatModel):
|
|||||||
) -> str:
|
) -> str:
|
||||||
return "fake response"
|
return "fake response"
|
||||||
|
|
||||||
|
@override
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -224,6 +226,7 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
into message chunks.
|
into message chunks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -346,6 +349,7 @@ class ParrotFakeChatModel(BaseChatModel):
|
|||||||
* Chat model should be usable in both sync and async tests
|
* Chat model should be usable in both sync and async tests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
|
@ -1399,6 +1399,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
|
|
||||||
|
@override
|
||||||
def dict(self, **kwargs: Any) -> dict:
|
def dict(self, **kwargs: Any) -> dict:
|
||||||
"""Return a dictionary of the LLM."""
|
"""Return a dictionary of the LLM."""
|
||||||
starter_dict = dict(self._identifying_params)
|
starter_dict = dict(self._identifying_params)
|
||||||
|
@ -59,7 +59,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
Structured output.
|
Structured output.
|
||||||
"""
|
"""
|
||||||
return await run_in_executor(None, self.parse_result, result)
|
return await run_in_executor(None, self.parse_result, result, partial=partial)
|
||||||
|
|
||||||
|
|
||||||
class BaseGenerationOutputParser(
|
class BaseGenerationOutputParser(
|
||||||
@ -231,6 +231,7 @@ class BaseOutputParser(
|
|||||||
run_type="parser",
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
||||||
"""Parse a list of candidate model Generations into a specific format.
|
"""Parse a list of candidate model Generations into a specific format.
|
||||||
|
|
||||||
@ -290,7 +291,11 @@ class BaseOutputParser(
|
|||||||
return await run_in_executor(None, self.parse, text)
|
return await run_in_executor(None, self.parse, text)
|
||||||
|
|
||||||
# TODO: rename 'completion' -> 'text'.
|
# TODO: rename 'completion' -> 'text'.
|
||||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
def parse_with_prompt(
|
||||||
|
self,
|
||||||
|
completion: str,
|
||||||
|
prompt: PromptValue, # noqa: ARG002
|
||||||
|
) -> Any:
|
||||||
"""Parse the output of an LLM call with the input prompt for context.
|
"""Parse the output of an LLM call with the input prompt for context.
|
||||||
|
|
||||||
The prompt is largely provided in the event the OutputParser wants
|
The prompt is largely provided in the event the OutputParser wants
|
||||||
|
@ -7,6 +7,7 @@ from typing import Any, Optional, Union
|
|||||||
|
|
||||||
import jsonpatch # type: ignore[import]
|
import jsonpatch # type: ignore[import]
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers import (
|
from langchain_core.output_parsers import (
|
||||||
@ -23,6 +24,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|||||||
args_only: bool = True
|
args_only: bool = True
|
||||||
"""Whether to only return the arguments to the function call."""
|
"""Whether to only return the arguments to the function call."""
|
||||||
|
|
||||||
|
@override
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
@ -251,6 +253,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@override
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
@ -287,6 +290,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
|||||||
attr_name: str
|
attr_name: str
|
||||||
"""The name of the attribute to return."""
|
"""The name of the attribute to return."""
|
||||||
|
|
||||||
|
@override
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
|
@ -9,6 +9,8 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||||
from langchain_core.output_parsers.base import BaseOutputParser, T
|
from langchain_core.output_parsers.base import BaseOutputParser, T
|
||||||
from langchain_core.outputs import (
|
from langchain_core.outputs import (
|
||||||
@ -48,6 +50,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
None, self.parse_result, [Generation(text=chunk)]
|
None, self.parse_result, [Generation(text=chunk)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def transform(
|
def transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Union[str, BaseMessage]],
|
input: Iterator[Union[str, BaseMessage]],
|
||||||
@ -68,6 +71,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
input, self._transform, config, run_type="parser"
|
input, self._transform, config, run_type="parser"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Union[str, BaseMessage]],
|
input: AsyncIterator[Union[str, BaseMessage]],
|
||||||
|
@ -127,6 +127,7 @@ class BasePromptTemplate(
|
|||||||
"""Return the output type of the prompt."""
|
"""Return the output type of the prompt."""
|
||||||
return Union[StringPromptValue, ChatPromptValueConcrete]
|
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||||
|
|
||||||
|
@override
|
||||||
def get_input_schema(
|
def get_input_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
|
@ -199,7 +199,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
|
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
|
def _get_ls_params(self, **_kwargs: Any) -> LangSmithRetrieverParams:
|
||||||
"""Get standard params for tracing."""
|
"""Get standard params for tracing."""
|
||||||
default_retriever_name = self.get_name()
|
default_retriever_name = self.get_name()
|
||||||
if default_retriever_name.startswith("Retriever"):
|
if default_retriever_name.startswith("Retriever"):
|
||||||
|
@ -326,7 +326,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return self.get_input_schema()
|
return self.get_input_schema()
|
||||||
|
|
||||||
def get_input_schema(
|
def get_input_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self,
|
||||||
|
config: Optional[RunnableConfig] = None, # noqa: ARG002
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Get a pydantic model that can be used to validate input to the Runnable.
|
"""Get a pydantic model that can be used to validate input to the Runnable.
|
||||||
|
|
||||||
@ -398,7 +399,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return self.get_output_schema()
|
return self.get_output_schema()
|
||||||
|
|
||||||
def get_output_schema(
|
def get_output_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self,
|
||||||
|
config: Optional[RunnableConfig] = None, # noqa: ARG002
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Get a pydantic model that can be used to validate output to the Runnable.
|
"""Get a pydantic model that can be used to validate output to the Runnable.
|
||||||
|
|
||||||
@ -4751,11 +4753,6 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
return cast("Output", output)
|
return cast("Output", output)
|
||||||
|
|
||||||
def _config(
|
|
||||||
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
|
|
||||||
) -> RunnableConfig:
|
|
||||||
return ensure_config(config)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
@ -4780,7 +4777,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return self._call_with_config(
|
return self._call_with_config(
|
||||||
self._invoke,
|
self._invoke,
|
||||||
input,
|
input,
|
||||||
self._config(config, self.func),
|
ensure_config(config),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
msg = "Cannot invoke a coroutine function synchronously.Use `ainvoke` instead."
|
msg = "Cannot invoke a coroutine function synchronously.Use `ainvoke` instead."
|
||||||
@ -4803,11 +4800,10 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
Returns:
|
Returns:
|
||||||
The output of this Runnable.
|
The output of this Runnable.
|
||||||
"""
|
"""
|
||||||
the_func = self.afunc if hasattr(self, "afunc") else self.func
|
|
||||||
return await self._acall_with_config(
|
return await self._acall_with_config(
|
||||||
self._ainvoke,
|
self._ainvoke,
|
||||||
input,
|
input,
|
||||||
self._config(config, the_func),
|
ensure_config(config),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -4884,7 +4880,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
yield from self._transform_stream_with_config(
|
yield from self._transform_stream_with_config(
|
||||||
input,
|
input,
|
||||||
self._transform,
|
self._transform,
|
||||||
self._config(config, self.func),
|
ensure_config(config),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -5012,7 +5008,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
async for output in self._atransform_stream_with_config(
|
async for output in self._atransform_stream_with_config(
|
||||||
input,
|
input,
|
||||||
self._atransform,
|
self._atransform,
|
||||||
self._config(config, self.afunc if hasattr(self, "afunc") else self.func),
|
ensure_config(config),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
yield output
|
yield output
|
||||||
|
@ -400,6 +400,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
def OutputType(self) -> type[Output]:
|
def OutputType(self) -> type[Output]:
|
||||||
return self._history_chain.OutputType
|
return self._history_chain.OutputType
|
||||||
|
|
||||||
|
@override
|
||||||
def get_output_schema(
|
def get_output_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
@ -432,12 +433,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
module_name=self.__class__.__module__,
|
module_name=self.__class__.__module__,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _is_not_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _is_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _get_input_messages(
|
def _get_input_messages(
|
||||||
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||||
) -> list[BaseMessage]:
|
) -> list[BaseMessage]:
|
||||||
|
@ -133,6 +133,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
self._on_llm_start(llm_run)
|
self._on_llm_start(llm_run)
|
||||||
return llm_run
|
return llm_run
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_new_token(
|
def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
@ -161,11 +162,11 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
chunk=chunk,
|
chunk=chunk,
|
||||||
parent_run_id=parent_run_id,
|
parent_run_id=parent_run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self._on_llm_new_token(llm_run, token, chunk)
|
self._on_llm_new_token(llm_run, token, chunk)
|
||||||
return llm_run
|
return llm_run
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retry(
|
def on_retry(
|
||||||
self,
|
self,
|
||||||
retry_state: RetryCallState,
|
retry_state: RetryCallState,
|
||||||
@ -188,6 +189,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
|
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||||
"""End a trace for an LLM run.
|
"""End a trace for an LLM run.
|
||||||
|
|
||||||
@ -235,6 +237,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
self._on_llm_error(llm_run)
|
self._on_llm_error(llm_run)
|
||||||
return llm_run
|
return llm_run
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
self,
|
self,
|
||||||
serialized: dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
@ -279,6 +282,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
self._on_chain_start(chain_run)
|
self._on_chain_start(chain_run)
|
||||||
return chain_run
|
return chain_run
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_end(
|
def on_chain_end(
|
||||||
self,
|
self,
|
||||||
outputs: dict[str, Any],
|
outputs: dict[str, Any],
|
||||||
@ -302,12 +306,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self._end_trace(chain_run)
|
self._end_trace(chain_run)
|
||||||
self._on_chain_end(chain_run)
|
self._on_chain_end(chain_run)
|
||||||
return chain_run
|
return chain_run
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
error: BaseException,
|
error: BaseException,
|
||||||
@ -331,7 +335,6 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
error=error,
|
error=error,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self._end_trace(chain_run)
|
self._end_trace(chain_run)
|
||||||
self._on_chain_error(chain_run)
|
self._on_chain_error(chain_run)
|
||||||
@ -381,6 +384,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
self._on_tool_start(tool_run)
|
self._on_tool_start(tool_run)
|
||||||
return tool_run
|
return tool_run
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
|
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||||
"""End a trace for a tool run.
|
"""End a trace for a tool run.
|
||||||
|
|
||||||
@ -395,12 +399,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
tool_run = self._complete_tool_run(
|
tool_run = self._complete_tool_run(
|
||||||
output=output,
|
output=output,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self._end_trace(tool_run)
|
self._end_trace(tool_run)
|
||||||
self._on_tool_end(tool_run)
|
self._on_tool_end(tool_run)
|
||||||
return tool_run
|
return tool_run
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
self,
|
self,
|
||||||
error: BaseException,
|
error: BaseException,
|
||||||
@ -467,6 +471,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
self._on_retriever_start(retrieval_run)
|
self._on_retriever_start(retrieval_run)
|
||||||
return retrieval_run
|
return retrieval_run
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retriever_error(
|
def on_retriever_error(
|
||||||
self,
|
self,
|
||||||
error: BaseException,
|
error: BaseException,
|
||||||
@ -487,12 +492,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
retrieval_run = self._errored_retrieval_run(
|
retrieval_run = self._errored_retrieval_run(
|
||||||
error=error,
|
error=error,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self._end_trace(retrieval_run)
|
self._end_trace(retrieval_run)
|
||||||
self._on_retriever_error(retrieval_run)
|
self._on_retriever_error(retrieval_run)
|
||||||
return retrieval_run
|
return retrieval_run
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retriever_end(
|
def on_retriever_end(
|
||||||
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
||||||
) -> Run:
|
) -> Run:
|
||||||
@ -509,7 +514,6 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
|||||||
retrieval_run = self._complete_retrieval_run(
|
retrieval_run = self._complete_retrieval_run(
|
||||||
documents=documents,
|
documents=documents,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self._end_trace(retrieval_run)
|
self._end_trace(retrieval_run)
|
||||||
self._on_retriever_end(retrieval_run)
|
self._on_retriever_end(retrieval_run)
|
||||||
@ -623,7 +627,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
chunk=chunk,
|
chunk=chunk,
|
||||||
parent_run_id=parent_run_id,
|
parent_run_id=parent_run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
await self._on_llm_new_token(llm_run, token, chunk)
|
await self._on_llm_new_token(llm_run, token, chunk)
|
||||||
|
|
||||||
@ -715,7 +718,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
|||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
tasks = [self._end_trace(chain_run), self._on_chain_end(chain_run)]
|
tasks = [self._end_trace(chain_run), self._on_chain_end(chain_run)]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
@ -733,7 +735,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
|||||||
error=error,
|
error=error,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
tasks = [self._end_trace(chain_run), self._on_chain_error(chain_run)]
|
tasks = [self._end_trace(chain_run), self._on_chain_error(chain_run)]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
@ -776,7 +777,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
|||||||
tool_run = self._complete_tool_run(
|
tool_run = self._complete_tool_run(
|
||||||
output=output,
|
output=output,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
tasks = [self._end_trace(tool_run), self._on_tool_end(tool_run)]
|
tasks = [self._end_trace(tool_run), self._on_tool_end(tool_run)]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
@ -839,7 +839,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
|||||||
retrieval_run = self._errored_retrieval_run(
|
retrieval_run = self._errored_retrieval_run(
|
||||||
error=error,
|
error=error,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
tasks = [
|
tasks = [
|
||||||
self._end_trace(retrieval_run),
|
self._end_trace(retrieval_run),
|
||||||
@ -860,7 +859,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
|||||||
retrieval_run = self._complete_retrieval_run(
|
retrieval_run = self._complete_retrieval_run(
|
||||||
documents=documents,
|
documents=documents,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
tasks = [self._end_trace(retrieval_run), self._on_retriever_end(retrieval_run)]
|
tasks = [self._end_trace(retrieval_run), self._on_retriever_end(retrieval_run)]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
@ -41,7 +41,7 @@ run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVa
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def tracing_enabled(
|
def tracing_enabled(
|
||||||
session_name: str = "default",
|
session_name: str = "default", # noqa: ARG001
|
||||||
) -> Generator[TracerSessionV1, None, None]:
|
) -> Generator[TracerSessionV1, None, None]:
|
||||||
"""Throw an error because this has been replaced by tracing_v2_enabled."""
|
"""Throw an error because this has been replaced by tracing_v2_enabled."""
|
||||||
msg = (
|
msg = (
|
||||||
|
@ -231,8 +231,7 @@ class _TracerCore(ABC):
|
|||||||
token: str,
|
token: str,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None, # noqa: ARG002
|
||||||
**kwargs: Any,
|
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Append token event to LLM run and return the run."""
|
"""Append token event to LLM run and return the run."""
|
||||||
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
||||||
@ -252,7 +251,6 @@ class _TracerCore(ABC):
|
|||||||
self,
|
self,
|
||||||
retry_state: RetryCallState,
|
retry_state: RetryCallState,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Run:
|
) -> Run:
|
||||||
llm_run = self._get_run(run_id)
|
llm_run = self._get_run(run_id)
|
||||||
retry_d: dict[str, Any] = {
|
retry_d: dict[str, Any] = {
|
||||||
@ -369,7 +367,6 @@ class _TracerCore(ABC):
|
|||||||
outputs: dict[str, Any],
|
outputs: dict[str, Any],
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
inputs: Optional[dict[str, Any]] = None,
|
inputs: Optional[dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Update a chain run with outputs and end time."""
|
"""Update a chain run with outputs and end time."""
|
||||||
chain_run = self._get_run(run_id)
|
chain_run = self._get_run(run_id)
|
||||||
@ -385,7 +382,6 @@ class _TracerCore(ABC):
|
|||||||
error: BaseException,
|
error: BaseException,
|
||||||
inputs: Optional[dict[str, Any]],
|
inputs: Optional[dict[str, Any]],
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Run:
|
) -> Run:
|
||||||
chain_run = self._get_run(run_id)
|
chain_run = self._get_run(run_id)
|
||||||
chain_run.error = self._get_stacktrace(error)
|
chain_run.error = self._get_stacktrace(error)
|
||||||
@ -439,7 +435,6 @@ class _TracerCore(ABC):
|
|||||||
self,
|
self,
|
||||||
output: dict[str, Any],
|
output: dict[str, Any],
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Update a tool run with outputs and end time."""
|
"""Update a tool run with outputs and end time."""
|
||||||
tool_run = self._get_run(run_id, run_type="tool")
|
tool_run = self._get_run(run_id, run_type="tool")
|
||||||
@ -452,7 +447,6 @@ class _TracerCore(ABC):
|
|||||||
self,
|
self,
|
||||||
error: BaseException,
|
error: BaseException,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Update a tool run with error and end time."""
|
"""Update a tool run with error and end time."""
|
||||||
tool_run = self._get_run(run_id, run_type="tool")
|
tool_run = self._get_run(run_id, run_type="tool")
|
||||||
@ -494,7 +488,6 @@ class _TracerCore(ABC):
|
|||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Update a retrieval run with outputs and end time."""
|
"""Update a retrieval run with outputs and end time."""
|
||||||
retrieval_run = self._get_run(run_id, run_type="retriever")
|
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||||
@ -507,7 +500,6 @@ class _TracerCore(ABC):
|
|||||||
self,
|
self,
|
||||||
error: BaseException,
|
error: BaseException,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Run:
|
) -> Run:
|
||||||
retrieval_run = self._get_run(run_id, run_type="retriever")
|
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||||
retrieval_run.error = self._get_stacktrace(error)
|
retrieval_run.error = self._get_stacktrace(error)
|
||||||
@ -523,75 +515,75 @@ class _TracerCore(ABC):
|
|||||||
"""Copy the tracer."""
|
"""Copy the tracer."""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""End a trace for a run."""
|
"""End a trace for a run."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process a run upon creation."""
|
"""Process a run upon creation."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process a run upon update."""
|
"""Process a run upon update."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the LLM Run upon start."""
|
"""Process the LLM Run upon start."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_llm_new_token(
|
def _on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
run: Run,
|
run: Run, # noqa: ARG002
|
||||||
token: str,
|
token: str, # noqa: ARG002
|
||||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002
|
||||||
) -> Union[None, Coroutine[Any, Any, None]]:
|
) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process new LLM token."""
|
"""Process new LLM token."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the LLM Run."""
|
"""Process the LLM Run."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the LLM Run upon error."""
|
"""Process the LLM Run upon error."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Chain Run upon start."""
|
"""Process the Chain Run upon start."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Chain Run."""
|
"""Process the Chain Run."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Chain Run upon error."""
|
"""Process the Chain Run upon error."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Tool Run upon start."""
|
"""Process the Tool Run upon start."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Tool Run."""
|
"""Process the Tool Run."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Tool Run upon error."""
|
"""Process the Tool Run upon error."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Chat Model Run upon start."""
|
"""Process the Chat Model Run upon start."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Retriever Run upon start."""
|
"""Process the Retriever Run upon start."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Retriever Run."""
|
"""Process the Retriever Run."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||||
"""Process the Retriever Run upon error."""
|
"""Process the Retriever Run upon error."""
|
||||||
return None
|
return None
|
||||||
|
@ -15,7 +15,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict, override
|
||||||
|
|
||||||
from langchain_core.callbacks.base import AsyncCallbackHandler
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
||||||
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
||||||
@ -293,6 +293,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
self.run_map[run_id] = info
|
self.run_map[run_id] = info
|
||||||
self.parent_map[run_id] = parent_run_id
|
self.parent_map[run_id] = parent_run_id
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_chat_model_start(
|
async def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
serialized: dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
@ -334,6 +335,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
run_type,
|
run_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_start(
|
async def on_llm_start(
|
||||||
self,
|
self,
|
||||||
serialized: dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
@ -377,6 +379,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
run_type,
|
run_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_custom_event(
|
async def on_custom_event(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
@ -399,6 +402,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
)
|
)
|
||||||
self._send(event, name)
|
self._send(event, name)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_new_token(
|
async def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
@ -450,6 +454,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
run_info["run_type"],
|
run_info["run_type"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_end(
|
async def on_llm_end(
|
||||||
self, response: LLMResult, *, run_id: UUID, **kwargs: Any
|
self, response: LLMResult, *, run_id: UUID, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -552,6 +557,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
run_type_,
|
run_type_,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_chain_end(
|
async def on_chain_end(
|
||||||
self,
|
self,
|
||||||
outputs: dict[str, Any],
|
outputs: dict[str, Any],
|
||||||
@ -586,6 +592,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
run_type,
|
run_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_tool_start(
|
async def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
@ -627,6 +634,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
"tool",
|
"tool",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
"""End a trace for a tool run."""
|
"""End a trace for a tool run."""
|
||||||
run_info = self.run_map.pop(run_id)
|
run_info = self.run_map.pop(run_id)
|
||||||
@ -654,6 +662,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
"tool",
|
"tool",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_retriever_start(
|
async def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
serialized: dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
@ -697,6 +706,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
run_type,
|
run_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_retriever_end(
|
async def on_retriever_end(
|
||||||
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -19,6 +19,7 @@ from tenacity import (
|
|||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
wait_exponential_jitter,
|
wait_exponential_jitter,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.env import get_runtime_environment
|
from langchain_core.env import get_runtime_environment
|
||||||
from langchain_core.load import dumpd
|
from langchain_core.load import dumpd
|
||||||
@ -252,13 +253,13 @@ class LangChainTracer(BaseTracer):
|
|||||||
run.reference_example_id = self.example_id
|
run.reference_example_id = self.example_id
|
||||||
self._persist_run_single(run)
|
self._persist_run_single(run)
|
||||||
|
|
||||||
|
@override
|
||||||
def _llm_run_with_token_event(
|
def _llm_run_with_token_event(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Append token event to LLM run and return the run."""
|
"""Append token event to LLM run and return the run."""
|
||||||
return super()._llm_run_with_token_event(
|
return super()._llm_run_with_token_event(
|
||||||
@ -267,7 +268,6 @@ class LangChainTracer(BaseTracer):
|
|||||||
run_id,
|
run_id,
|
||||||
chunk=None,
|
chunk=None,
|
||||||
parent_run_id=parent_run_id,
|
parent_run_id=parent_run_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _on_chat_model_start(self, run: Run) -> None:
|
def _on_chat_model_start(self, run: Run) -> None:
|
||||||
|
@ -6,7 +6,7 @@ Please use LangChainTracer instead.
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def get_headers(*args: Any, **kwargs: Any) -> Any:
|
def get_headers(*args: Any, **kwargs: Any) -> Any: # noqa: ARG001
|
||||||
"""Throw an error because this has been replaced by get_headers."""
|
"""Throw an error because this has been replaced by get_headers."""
|
||||||
msg = (
|
msg = (
|
||||||
"get_headers for LangChainTracerV1 is no longer supported. "
|
"get_headers for LangChainTracerV1 is no longer supported. "
|
||||||
@ -15,7 +15,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any:
|
|||||||
raise RuntimeError(msg)
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
|
||||||
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802
|
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802,ARG001
|
||||||
"""Throw an error because this has been replaced by LangChainTracer."""
|
"""Throw an error because this has been replaced by LangChainTracer."""
|
||||||
msg = (
|
msg = (
|
||||||
"LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."
|
"LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."
|
||||||
|
@ -16,8 +16,8 @@ from langchain_core._api.deprecation import deprecated
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
def try_load_from_hub(
|
def try_load_from_hub(
|
||||||
*args: Any,
|
*args: Any, # noqa: ARG001
|
||||||
**kwargs: Any,
|
**kwargs: Any, # noqa: ARG001
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""[DEPRECATED] Try to load from the old Hub."""
|
"""[DEPRECATED] Try to load from the old Hub."""
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -65,7 +65,11 @@ def grab_literal(template: str, l_del: str) -> tuple[str, str]:
|
|||||||
return (literal, template)
|
return (literal, template)
|
||||||
|
|
||||||
|
|
||||||
def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
|
def l_sa_check(
|
||||||
|
template: str, # noqa: ARG001
|
||||||
|
literal: str,
|
||||||
|
is_standalone: bool,
|
||||||
|
) -> bool:
|
||||||
"""Do a preliminary check to see if a tag could be a standalone.
|
"""Do a preliminary check to see if a tag could be a standalone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -38,6 +38,7 @@ from pydantic.json_schema import (
|
|||||||
JsonSchemaMode,
|
JsonSchemaMode,
|
||||||
JsonSchemaValue,
|
JsonSchemaValue,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
@ -233,6 +234,7 @@ class _IgnoreUnserializable(GenerateJsonSchema):
|
|||||||
https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
|
https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@override
|
||||||
def handle_invalid_for_json_schema(
|
def handle_invalid_for_json_schema(
|
||||||
self, schema: core_schema.CoreSchema, error_info: str
|
self, schema: core_schema.CoreSchema, error_info: str
|
||||||
) -> JsonSchemaValue:
|
) -> JsonSchemaValue:
|
||||||
|
@ -13,6 +13,7 @@ from typing import Any, Callable, Optional, Union, overload
|
|||||||
from packaging.version import parse
|
from packaging.version import parse
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
from requests import HTTPError, Response
|
from requests import HTTPError, Response
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
is_pydantic_v1_subclass,
|
is_pydantic_v1_subclass,
|
||||||
@ -91,6 +92,7 @@ def mock_now(dt_value: datetime.datetime) -> Iterator[type]:
|
|||||||
"""Mock datetime.datetime.now() with a fixed datetime."""
|
"""Mock datetime.datetime.now() with a fixed datetime."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def now(cls, tz: Union[datetime.tzinfo, None] = None) -> "MockDateTime":
|
def now(cls, tz: Union[datetime.tzinfo, None] = None) -> "MockDateTime":
|
||||||
# Create a copy of dt_value.
|
# Create a copy of dt_value.
|
||||||
return MockDateTime(
|
return MockDateTime(
|
||||||
|
@ -36,7 +36,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import ConfigDict, Field, model_validator
|
from pydantic import ConfigDict, Field, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self, override
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
||||||
@ -1070,6 +1070,7 @@ class VectorStoreRetriever(BaseRetriever):
|
|||||||
|
|
||||||
return ls_params
|
return ls_params
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
@ -1090,6 +1091,7 @@ class VectorStoreRetriever(BaseRetriever):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
@override
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -285,7 +285,7 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
since="0.2.29",
|
since="0.2.29",
|
||||||
removal="1.0",
|
removal="1.0",
|
||||||
)
|
)
|
||||||
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
|
def upsert(self, items: Sequence[Document], /, **_kwargs: Any) -> UpsertResponse:
|
||||||
"""[DEPRECATED] Upsert documents into the store.
|
"""[DEPRECATED] Upsert documents into the store.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -319,7 +319,7 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
removal="1.0",
|
removal="1.0",
|
||||||
)
|
)
|
||||||
async def aupsert(
|
async def aupsert(
|
||||||
self, items: Sequence[Document], /, **kwargs: Any
|
self, items: Sequence[Document], /, **_kwargs: Any
|
||||||
) -> UpsertResponse:
|
) -> UpsertResponse:
|
||||||
"""[DEPRECATED] Upsert documents into the store.
|
"""[DEPRECATED] Upsert documents into the store.
|
||||||
|
|
||||||
@ -364,7 +364,6 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
embedding: list[float],
|
embedding: list[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
filter: Optional[Callable[[Document], bool]] = None,
|
filter: Optional[Callable[[Document], bool]] = None,
|
||||||
**kwargs: Any,
|
|
||||||
) -> list[tuple[Document, float, list[float]]]:
|
) -> list[tuple[Document, float, list[float]]]:
|
||||||
# get all docs with fixed order in list
|
# get all docs with fixed order in list
|
||||||
docs = list(self.store.values())
|
docs = list(self.store.values())
|
||||||
@ -404,7 +403,7 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
embedding: list[float],
|
embedding: list[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
filter: Optional[Callable[[Document], bool]] = None,
|
filter: Optional[Callable[[Document], bool]] = None,
|
||||||
**kwargs: Any,
|
**_kwargs: Any,
|
||||||
) -> list[tuple[Document, float]]:
|
) -> list[tuple[Document, float]]:
|
||||||
"""Search for the most similar documents to the given embedding.
|
"""Search for the most similar documents to the given embedding.
|
||||||
|
|
||||||
@ -419,7 +418,7 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
return [
|
return [
|
||||||
(doc, similarity)
|
(doc, similarity)
|
||||||
for doc, similarity, _ in self._similarity_search_with_score_by_vector(
|
for doc, similarity, _ in self._similarity_search_with_score_by_vector(
|
||||||
embedding=embedding, k=k, filter=filter, **kwargs
|
embedding=embedding, k=k, filter=filter
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -490,12 +489,14 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
k: int = 4,
|
k: int = 4,
|
||||||
fetch_k: int = 20,
|
fetch_k: int = 20,
|
||||||
lambda_mult: float = 0.5,
|
lambda_mult: float = 0.5,
|
||||||
|
*,
|
||||||
|
filter: Optional[Callable[[Document], bool]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
prefetch_hits = self._similarity_search_with_score_by_vector(
|
prefetch_hits = self._similarity_search_with_score_by_vector(
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
k=fetch_k,
|
k=fetch_k,
|
||||||
**kwargs,
|
filter=filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -98,7 +98,6 @@ ignore = [
|
|||||||
# TODO rules
|
# TODO rules
|
||||||
"A",
|
"A",
|
||||||
"ANN401",
|
"ANN401",
|
||||||
"ARG",
|
|
||||||
"BLE",
|
"BLE",
|
||||||
"ERA",
|
"ERA",
|
||||||
"FBT001",
|
"FBT001",
|
||||||
@ -132,5 +131,6 @@ classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_ini
|
|||||||
"tests/unit_tests/prompts/test_chat.py" = [ "E501",]
|
"tests/unit_tests/prompts/test_chat.py" = [ "E501",]
|
||||||
"tests/unit_tests/runnables/test_runnable.py" = [ "E501",]
|
"tests/unit_tests/runnables/test_runnable.py" = [ "E501",]
|
||||||
"tests/unit_tests/runnables/test_graph.py" = [ "E501",]
|
"tests/unit_tests/runnables/test_graph.py" = [ "E501",]
|
||||||
|
"tests/unit_tests/test_tools.py" = [ "ARG",]
|
||||||
"tests/**" = [ "D", "S",]
|
"tests/**" = [ "D", "S",]
|
||||||
"scripts/**" = [ "INP", "S",]
|
"scripts/**" = [ "INP", "S",]
|
||||||
|
@ -10,6 +10,8 @@ from contextlib import asynccontextmanager
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackHandler,
|
AsyncCallbackHandler,
|
||||||
AsyncCallbackManager,
|
AsyncCallbackManager,
|
||||||
@ -45,6 +47,7 @@ async def test_inline_handlers_share_parent_context() -> None:
|
|||||||
"""Initialize the handler."""
|
"""Initialize the handler."""
|
||||||
self.run_inline = run_inline
|
self.run_inline = run_inline
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
|
async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
|
||||||
"""Update the callstack with the name of the callback."""
|
"""Update the callstack with the name of the callback."""
|
||||||
some_var.set("on_llm_start")
|
some_var.set("on_llm_start")
|
||||||
|
@ -74,6 +74,7 @@ async def test_async_custom_event_implicit_config() -> None:
|
|||||||
# a decorator for async functions
|
# a decorator for async functions
|
||||||
@RunnableLambda # type: ignore[arg-type]
|
@RunnableLambda # type: ignore[arg-type]
|
||||||
async def foo(x: int, config: RunnableConfig) -> int:
|
async def foo(x: int, config: RunnableConfig) -> int:
|
||||||
|
assert "callbacks" in config
|
||||||
await adispatch_custom_event("event1", {"x": x})
|
await adispatch_custom_event("event1", {"x": x})
|
||||||
await adispatch_custom_event("event2", {"x": x})
|
await adispatch_custom_event("event2", {"x": x})
|
||||||
return x
|
return x
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader
|
from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
@ -15,6 +16,7 @@ def test_base_blob_parser() -> None:
|
|||||||
class MyParser(BaseBlobParser):
|
class MyParser(BaseBlobParser):
|
||||||
"""A simple parser that returns a single document."""
|
"""A simple parser that returns a single document."""
|
||||||
|
|
||||||
|
@override
|
||||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||||
"""Lazy parsing interface."""
|
"""Lazy parsing interface."""
|
||||||
yield Document(
|
yield Document(
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings, FakeEmbeddings
|
from langchain_core.embeddings import Embeddings, FakeEmbeddings
|
||||||
from langchain_core.example_selectors import (
|
from langchain_core.example_selectors import (
|
||||||
@ -21,6 +23,7 @@ class DummyVectorStore(VectorStore):
|
|||||||
def embeddings(self) -> Optional[Embeddings]:
|
def embeddings(self) -> Optional[Embeddings]:
|
||||||
return self._embeddings
|
return self._embeddings
|
||||||
|
|
||||||
|
@override
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -32,6 +35,7 @@ class DummyVectorStore(VectorStore):
|
|||||||
self.metadatas.extend(metadatas)
|
self.metadatas.extend(metadatas)
|
||||||
return ["dummy_id"]
|
return ["dummy_id"]
|
||||||
|
|
||||||
|
@override
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self, query: str, k: int = 4, **kwargs: Any
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
@ -41,6 +45,7 @@ class DummyVectorStore(VectorStore):
|
|||||||
)
|
)
|
||||||
] * k
|
] * k
|
||||||
|
|
||||||
|
@override
|
||||||
def max_marginal_relevance_search(
|
def max_marginal_relevance_search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -5,6 +5,7 @@ from typing import Any, Optional, Union
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
@ -138,6 +139,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
"""Whether to ignore retriever callbacks."""
|
"""Whether to ignore retriever callbacks."""
|
||||||
return self.ignore_retriever_
|
return self.ignore_retriever_
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -145,6 +147,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_llm_start_common()
|
self.on_llm_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_new_token(
|
def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -152,6 +155,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_llm_new_token_common()
|
self.on_llm_new_token_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_end(
|
def on_llm_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -159,6 +163,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_llm_end_common()
|
self.on_llm_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -166,6 +171,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_llm_error_common(*args, **kwargs)
|
self.on_llm_error_common(*args, **kwargs)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retry(
|
def on_retry(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -173,6 +179,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_retry_common()
|
self.on_retry_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -180,6 +187,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_chain_start_common()
|
self.on_chain_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_end(
|
def on_chain_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -187,6 +195,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_chain_end_common()
|
self.on_chain_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -194,6 +203,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_chain_error_common()
|
self.on_chain_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -201,6 +211,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_tool_start_common()
|
self.on_tool_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -208,6 +219,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_tool_end_common()
|
self.on_tool_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -215,6 +227,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_tool_error_common()
|
self.on_tool_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_agent_action(
|
def on_agent_action(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -222,6 +235,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_agent_action_common()
|
self.on_agent_action_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_agent_finish(
|
def on_agent_finish(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -229,6 +243,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_agent_finish_common()
|
self.on_agent_finish_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_text(
|
def on_text(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -236,6 +251,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_text_common()
|
self.on_text_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retriever_start(
|
def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -243,6 +259,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_retriever_start_common()
|
self.on_retriever_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retriever_end(
|
def on_retriever_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -250,6 +267,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_retriever_end_common()
|
self.on_retriever_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retriever_error(
|
def on_retriever_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -263,6 +281,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
|
|
||||||
|
|
||||||
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
||||||
|
@override
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
serialized: dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
@ -294,6 +313,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
"""Whether to ignore agent callbacks."""
|
"""Whether to ignore agent callbacks."""
|
||||||
return self.ignore_agent_
|
return self.ignore_agent_
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_retry(
|
async def on_retry(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -301,6 +321,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_retry_common()
|
self.on_retry_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_start(
|
async def on_llm_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -308,6 +329,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_llm_start_common()
|
self.on_llm_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_new_token(
|
async def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -315,6 +337,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_llm_new_token_common()
|
self.on_llm_new_token_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_end(
|
async def on_llm_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -322,6 +345,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_llm_end_common()
|
self.on_llm_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_error(
|
async def on_llm_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -329,6 +353,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_llm_error_common(*args, **kwargs)
|
self.on_llm_error_common(*args, **kwargs)
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_chain_start(
|
async def on_chain_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -336,6 +361,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_chain_start_common()
|
self.on_chain_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_chain_end(
|
async def on_chain_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -343,6 +369,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_chain_end_common()
|
self.on_chain_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_chain_error(
|
async def on_chain_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -350,6 +377,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_chain_error_common()
|
self.on_chain_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_tool_start(
|
async def on_tool_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -357,6 +385,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_tool_start_common()
|
self.on_tool_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_tool_end(
|
async def on_tool_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -364,6 +393,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_tool_end_common()
|
self.on_tool_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_tool_error(
|
async def on_tool_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -371,6 +401,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_tool_error_common()
|
self.on_tool_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_agent_action(
|
async def on_agent_action(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -378,6 +409,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_agent_action_common()
|
self.on_agent_action_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_agent_finish(
|
async def on_agent_finish(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -385,6 +417,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_agent_finish_common()
|
self.on_agent_finish_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_text(
|
async def on_text(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
|
@ -4,6 +4,8 @@ from itertools import cycle
|
|||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks.base import AsyncCallbackHandler
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
||||||
from langchain_core.language_models import (
|
from langchain_core.language_models import (
|
||||||
FakeListChatModel,
|
FakeListChatModel,
|
||||||
@ -171,6 +173,7 @@ async def test_callback_handlers() -> None:
|
|||||||
# Required to implement since this is an abstract method
|
# Required to implement since this is an abstract method
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_new_token(
|
async def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
|
@ -5,6 +5,7 @@ from collections.abc import AsyncIterator, Iterator
|
|||||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models import BaseChatModel, FakeListChatModel
|
from langchain_core.language_models import BaseChatModel, FakeListChatModel
|
||||||
@ -138,6 +139,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
|
|||||||
"""Test astream uses appropriate implementation."""
|
"""Test astream uses appropriate implementation."""
|
||||||
|
|
||||||
class ModelWithGenerate(BaseChatModel):
|
class ModelWithGenerate(BaseChatModel):
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -176,6 +178,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
|
|||||||
"""Top Level call."""
|
"""Top Level call."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -221,6 +224,7 @@ async def test_astream_implementation_uses_astream() -> None:
|
|||||||
"""Top Level call."""
|
"""Top Level call."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
async def _astream( # type: ignore
|
async def _astream( # type: ignore
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -286,6 +290,7 @@ async def test_async_pass_run_id() -> None:
|
|||||||
|
|
||||||
|
|
||||||
class NoStreamingModel(BaseChatModel):
|
class NoStreamingModel(BaseChatModel):
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -301,6 +306,7 @@ class NoStreamingModel(BaseChatModel):
|
|||||||
|
|
||||||
|
|
||||||
class StreamingModel(NoStreamingModel):
|
class StreamingModel(NoStreamingModel):
|
||||||
|
@override
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
||||||
from langchain_core.globals import set_llm_cache
|
from langchain_core.globals import set_llm_cache
|
||||||
@ -30,6 +31,7 @@ class InMemoryCache(BaseCache):
|
|||||||
"""Update cache based on prompt and llm_string."""
|
"""Update cache based on prompt and llm_string."""
|
||||||
self._cache[(prompt, llm_string)] = return_val
|
self._cache[(prompt, llm_string)] = return_val
|
||||||
|
|
||||||
|
@override
|
||||||
def clear(self, **kwargs: Any) -> None:
|
def clear(self, **kwargs: Any) -> None:
|
||||||
"""Clear cache."""
|
"""Clear cache."""
|
||||||
self._cache = {}
|
self._cache = {}
|
||||||
|
@ -2,6 +2,7 @@ from collections.abc import AsyncIterator, Iterator
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@ -106,6 +107,7 @@ async def test_error_callback() -> None:
|
|||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "failing-llm"
|
return "failing-llm"
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -136,6 +138,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
|
|||||||
"""Test astream uses appropriate implementation."""
|
"""Test astream uses appropriate implementation."""
|
||||||
|
|
||||||
class ModelWithGenerate(BaseLLM):
|
class ModelWithGenerate(BaseLLM):
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
@ -172,6 +175,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
|
|||||||
"""Top Level call."""
|
"""Top Level call."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -209,6 +213,7 @@ async def test_astream_implementation_uses_astream() -> None:
|
|||||||
"""Top Level call."""
|
"""Top Level call."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
async def _astream(
|
async def _astream(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
||||||
from langchain_core.globals import set_llm_cache
|
from langchain_core.globals import set_llm_cache
|
||||||
from langchain_core.language_models import FakeListLLM
|
from langchain_core.language_models import FakeListLLM
|
||||||
@ -20,6 +22,7 @@ class InMemoryCache(BaseCache):
|
|||||||
"""Update cache based on prompt and llm_string."""
|
"""Update cache based on prompt and llm_string."""
|
||||||
self._cache[(prompt, llm_string)] = return_val
|
self._cache[(prompt, llm_string)] = return_val
|
||||||
|
|
||||||
|
@override
|
||||||
def clear(self, **kwargs: Any) -> None:
|
def clear(self, **kwargs: Any) -> None:
|
||||||
"""Clear cache."""
|
"""Clear cache."""
|
||||||
self._cache = {}
|
self._cache = {}
|
||||||
@ -74,6 +77,7 @@ class InMemoryCacheBad(BaseCache):
|
|||||||
msg = "This code should not be triggered"
|
msg = "This code should not be triggered"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
@override
|
||||||
def clear(self, **kwargs: Any) -> None:
|
def clear(self, **kwargs: Any) -> None:
|
||||||
"""Clear cache."""
|
"""Clear cache."""
|
||||||
self._cache = {}
|
self._cache = {}
|
||||||
|
@ -6,6 +6,7 @@ from collections.abc import Sequence
|
|||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
@ -660,6 +661,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
|
|||||||
|
|
||||||
|
|
||||||
class FakeTokenCountingModel(FakeChatModel):
|
class FakeTokenCountingModel(FakeChatModel):
|
||||||
|
@override
|
||||||
def get_num_tokens_from_messages(
|
def get_num_tokens_from_messages(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Module to test base parser implementations."""
|
"""Module to test base parser implementations."""
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.language_models import GenericFakeChatModel
|
from langchain_core.language_models import GenericFakeChatModel
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
@ -16,6 +18,7 @@ def test_base_generation_parser() -> None:
|
|||||||
class StrInvertCase(BaseGenerationOutputParser[str]):
|
class StrInvertCase(BaseGenerationOutputParser[str]):
|
||||||
"""An example parser that inverts the case of the characters in the message."""
|
"""An example parser that inverts the case of the characters in the message."""
|
||||||
|
|
||||||
|
@override
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self, result: list[Generation], *, partial: bool = False
|
self, result: list[Generation], *, partial: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -59,6 +62,7 @@ def test_base_transform_output_parser() -> None:
|
|||||||
"""Parse a single string into a specific format."""
|
"""Parse a single string into a specific format."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self, result: list[Generation], *, partial: bool = False
|
self, result: list[Generation], *, partial: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -5,6 +5,7 @@ from collections.abc import Sequence
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.example_selectors import BaseExampleSelector
|
from langchain_core.example_selectors import BaseExampleSelector
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
@ -383,6 +384,7 @@ class AsIsSelector(BaseExampleSelector):
|
|||||||
def add_example(self, example: dict[str, str]) -> Any:
|
def add_example(self, example: dict[str, str]) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||||
return list(self.examples)
|
return list(self.examples)
|
||||||
|
|
||||||
@ -481,6 +483,7 @@ class AsyncAsIsSelector(BaseExampleSelector):
|
|||||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||||
return list(self.examples)
|
return list(self.examples)
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
|
|||||||
|
|
||||||
|
|
||||||
def _fake_runnable(
|
def _fake_runnable(
|
||||||
input: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_: Any
|
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
|
||||||
) -> Union[BaseModel, dict]:
|
) -> Union[BaseModel, dict]:
|
||||||
if isclass(schema) and is_basemodel_subclass(schema):
|
if isclass(schema) and is_basemodel_subclass(schema):
|
||||||
return schema(name="yo", value=value)
|
return schema(name="yo", value=value)
|
||||||
|
@ -2,7 +2,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ConfigDict, Field, model_validator
|
from pydantic import ConfigDict, Field, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self, override
|
||||||
|
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
ConfigurableField,
|
ConfigurableField,
|
||||||
@ -32,6 +32,7 @@ class MyRunnable(RunnableSerializable[str, str]):
|
|||||||
self._my_hidden_property = self.my_property
|
self._my_hidden_property = self.my_property
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -41,12 +42,15 @@ class MyRunnable(RunnableSerializable[str, str]):
|
|||||||
return self.my_property
|
return self.my_property
|
||||||
|
|
||||||
def my_custom_function_w_config(
|
def my_custom_function_w_config(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self,
|
||||||
|
config: Optional[RunnableConfig] = None, # noqa: ARG002
|
||||||
) -> str:
|
) -> str:
|
||||||
return self.my_property
|
return self.my_property
|
||||||
|
|
||||||
def my_custom_function_w_kw_config(
|
def my_custom_function_w_kw_config(
|
||||||
self, *, config: Optional[RunnableConfig] = None
|
self,
|
||||||
|
*,
|
||||||
|
config: Optional[RunnableConfig] = None, # noqa: ARG002
|
||||||
) -> str:
|
) -> str:
|
||||||
return self.my_property
|
return self.my_property
|
||||||
|
|
||||||
@ -54,6 +58,7 @@ class MyRunnable(RunnableSerializable[str, str]):
|
|||||||
class MyOtherRunnable(RunnableSerializable[str, str]):
|
class MyOtherRunnable(RunnableSerializable[str, str]):
|
||||||
my_other_property: str
|
my_other_property: str
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -62,7 +67,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
|
|||||||
def my_other_custom_function(self) -> str:
|
def my_other_custom_function(self) -> str:
|
||||||
return self.my_other_property
|
return self.my_other_property
|
||||||
|
|
||||||
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str:
|
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str: # noqa: ARG002
|
||||||
return self.my_other_property
|
return self.my_other_property
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ def seq_naive_rag() -> Runnable:
|
|||||||
"What's your name?",
|
"What's your name?",
|
||||||
]
|
]
|
||||||
|
|
||||||
retriever = RunnableLambda(lambda x: context)
|
retriever = RunnableLambda(lambda _: context)
|
||||||
prompt = PromptTemplate.from_template("{context} {question}")
|
prompt = PromptTemplate.from_template("{context} {question}")
|
||||||
llm = FakeListLLM(responses=["hello"])
|
llm = FakeListLLM(responses=["hello"])
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ def seq_naive_rag_alt() -> Runnable:
|
|||||||
"What's your name?",
|
"What's your name?",
|
||||||
]
|
]
|
||||||
|
|
||||||
retriever = RunnableLambda(lambda x: context)
|
retriever = RunnableLambda(lambda _: context)
|
||||||
prompt = PromptTemplate.from_template("{context} {question}")
|
prompt = PromptTemplate.from_template("{context} {question}")
|
||||||
llm = FakeListLLM(responses=["hello"])
|
llm = FakeListLLM(responses=["hello"])
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ def seq_naive_rag_scoped() -> Runnable:
|
|||||||
"What's your name?",
|
"What's your name?",
|
||||||
]
|
]
|
||||||
|
|
||||||
retriever = RunnableLambda(lambda x: context)
|
retriever = RunnableLambda(lambda _: context)
|
||||||
prompt = PromptTemplate.from_template("{context} {question}")
|
prompt = PromptTemplate.from_template("{context} {question}")
|
||||||
llm = FakeListLLM(responses=["hello"])
|
llm = FakeListLLM(responses=["hello"])
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ from typing import (
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models import (
|
from langchain_core.language_models import (
|
||||||
@ -60,7 +61,7 @@ def chain() -> Runnable:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _raise_error(inputs: dict) -> str:
|
def _raise_error(_: dict) -> str:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
@ -259,17 +260,17 @@ async def test_abatch() -> None:
|
|||||||
_assert_potential_error(actual, expected)
|
_assert_potential_error(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
def _generate(input: Iterator) -> Iterator[str]:
|
def _generate(_: Iterator) -> Iterator[str]:
|
||||||
yield from "foo bar"
|
yield from "foo bar"
|
||||||
|
|
||||||
|
|
||||||
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
|
def _generate_immediate_error(_: Iterator) -> Iterator[str]:
|
||||||
msg = "immmediate error"
|
msg = "immmediate error"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
yield ""
|
yield ""
|
||||||
|
|
||||||
|
|
||||||
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
|
def _generate_delayed_error(_: Iterator) -> Iterator[str]:
|
||||||
yield ""
|
yield ""
|
||||||
msg = "delayed error"
|
msg = "delayed error"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -288,18 +289,18 @@ def test_fallbacks_stream() -> None:
|
|||||||
list(runnable.stream({}))
|
list(runnable.stream({}))
|
||||||
|
|
||||||
|
|
||||||
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
|
async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]:
|
||||||
for c in "foo bar":
|
for c in "foo bar":
|
||||||
yield c
|
yield c
|
||||||
|
|
||||||
|
|
||||||
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
|
async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]:
|
||||||
msg = "immmediate error"
|
msg = "immmediate error"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
yield ""
|
yield ""
|
||||||
|
|
||||||
|
|
||||||
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
|
async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]:
|
||||||
yield ""
|
yield ""
|
||||||
msg = "delayed error"
|
msg = "delayed error"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -323,6 +324,7 @@ async def test_fallbacks_astream() -> None:
|
|||||||
class FakeStructuredOutputModel(BaseChatModel):
|
class FakeStructuredOutputModel(BaseChatModel):
|
||||||
foo: int
|
foo: int
|
||||||
|
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -333,6 +335,7 @@ class FakeStructuredOutputModel(BaseChatModel):
|
|||||||
"""Top Level call."""
|
"""Top Level call."""
|
||||||
return ChatResult(generations=[])
|
return ChatResult(generations=[])
|
||||||
|
|
||||||
|
@override
|
||||||
def bind_tools(
|
def bind_tools(
|
||||||
self,
|
self,
|
||||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||||
@ -340,10 +343,11 @@ class FakeStructuredOutputModel(BaseChatModel):
|
|||||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
return self.bind(tools=tools)
|
return self.bind(tools=tools)
|
||||||
|
|
||||||
|
@override
|
||||||
def with_structured_output(
|
def with_structured_output(
|
||||||
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
|
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
|
||||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||||
return RunnableLambda(lambda x: {"foo": self.foo})
|
return RunnableLambda(lambda _: {"foo": self.foo})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
@ -353,6 +357,7 @@ class FakeStructuredOutputModel(BaseChatModel):
|
|||||||
class FakeModel(BaseChatModel):
|
class FakeModel(BaseChatModel):
|
||||||
bar: int
|
bar: int
|
||||||
|
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -363,6 +368,7 @@ class FakeModel(BaseChatModel):
|
|||||||
"""Top Level call."""
|
"""Top Level call."""
|
||||||
return ChatResult(generations=[])
|
return ChatResult(generations=[])
|
||||||
|
|
||||||
|
@override
|
||||||
def bind_tools(
|
def bind_tools(
|
||||||
self,
|
self,
|
||||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||||
|
@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Union
|
|||||||
import pytest
|
import pytest
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
@ -39,7 +40,7 @@ def _get_get_session_history(
|
|||||||
chat_history_store = store if store is not None else {}
|
chat_history_store = store if store is not None else {}
|
||||||
|
|
||||||
def get_session_history(
|
def get_session_history(
|
||||||
session_id: str, **kwargs: Any
|
session_id: str, **_kwargs: Any
|
||||||
) -> InMemoryChatMessageHistory:
|
) -> InMemoryChatMessageHistory:
|
||||||
if session_id not in chat_history_store:
|
if session_id not in chat_history_store:
|
||||||
chat_history_store[session_id] = InMemoryChatMessageHistory()
|
chat_history_store[session_id] = InMemoryChatMessageHistory()
|
||||||
@ -253,6 +254,7 @@ async def test_output_message_async() -> None:
|
|||||||
class LengthChatModel(BaseChatModel):
|
class LengthChatModel(BaseChatModel):
|
||||||
"""A fake chat model that returns the length of the messages passed in."""
|
"""A fake chat model that returns the length of the messages passed in."""
|
||||||
|
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -856,7 +858,7 @@ def test_get_output_messages_no_value_error() -> None:
|
|||||||
|
|
||||||
def test_get_output_messages_with_value_error() -> None:
|
def test_get_output_messages_with_value_error() -> None:
|
||||||
illegal_bool_message = False
|
illegal_bool_message = False
|
||||||
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message)
|
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
|
||||||
store: dict = {}
|
store: dict = {}
|
||||||
get_session_history = _get_get_session_history(store=store)
|
get_session_history = _get_get_session_history(store=store)
|
||||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||||
@ -874,7 +876,7 @@ def test_get_output_messages_with_value_error() -> None:
|
|||||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||||
|
|
||||||
illegal_int_message = 123
|
illegal_int_message = 123
|
||||||
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
|
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
|
||||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
|
@ -21,10 +21,11 @@ from packaging import version
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict, override
|
||||||
|
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
Callbacks,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
|
CallbackManagerForRetrieverRun,
|
||||||
atrace_as_chain_group,
|
atrace_as_chain_group,
|
||||||
trace_as_chain_group,
|
trace_as_chain_group,
|
||||||
)
|
)
|
||||||
@ -184,6 +185,7 @@ class FakeTracer(BaseTracer):
|
|||||||
|
|
||||||
|
|
||||||
class FakeRunnable(Runnable[str, int]):
|
class FakeRunnable(Runnable[str, int]):
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: str,
|
input: str,
|
||||||
@ -196,6 +198,7 @@ class FakeRunnable(Runnable[str, int]):
|
|||||||
class FakeRunnableSerializable(RunnableSerializable[str, int]):
|
class FakeRunnableSerializable(RunnableSerializable[str, int]):
|
||||||
hello: str = ""
|
hello: str = ""
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: str,
|
input: str,
|
||||||
@ -206,25 +209,15 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]):
|
|||||||
|
|
||||||
|
|
||||||
class FakeRetriever(BaseRetriever):
|
class FakeRetriever(BaseRetriever):
|
||||||
|
@override
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
callbacks: Callbacks = None,
|
|
||||||
tags: Optional[list[str]] = None,
|
|
||||||
metadata: Optional[dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||||
|
|
||||||
|
@override
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
callbacks: Callbacks = None,
|
|
||||||
tags: Optional[list[str]] = None,
|
|
||||||
metadata: Optional[dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||||
|
|
||||||
@ -506,7 +499,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
|
|
||||||
foo_ = RunnableLambda(foo)
|
foo_ = RunnableLambda(foo)
|
||||||
|
|
||||||
assert foo_.assign(bar=lambda x: "foo").get_output_schema().model_json_schema() == {
|
assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {
|
||||||
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
|
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
|
||||||
"required": ["root", "bar"],
|
"required": ["root", "bar"],
|
||||||
"title": "RunnableAssignOutput",
|
"title": "RunnableAssignOutput",
|
||||||
@ -1782,10 +1775,10 @@ def test_with_listener_propagation(mocker: MockerFixture) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
|
@pytest.mark.usefixtures("deterministic_uuids")
|
||||||
def test_prompt_with_chat_model(
|
def test_prompt_with_chat_model(
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
deterministic_uuids: MockerFixture,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = (
|
prompt = (
|
||||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
@ -1888,10 +1881,10 @@ def test_prompt_with_chat_model(
|
|||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
|
@pytest.mark.usefixtures("deterministic_uuids")
|
||||||
async def test_prompt_with_chat_model_async(
|
async def test_prompt_with_chat_model_async(
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
deterministic_uuids: MockerFixture,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = (
|
prompt = (
|
||||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
@ -2519,7 +2512,7 @@ async def test_stream_log_retriever() -> None:
|
|||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
async def test_stream_log_lists() -> None:
|
async def test_stream_log_lists() -> None:
|
||||||
async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
|
async def list_producer(_: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
yield AddableDict(alist=[str(i)])
|
yield AddableDict(alist=[str(i)])
|
||||||
|
|
||||||
@ -2631,10 +2624,10 @@ async def test_prompt_with_llm_and_async_lambda(
|
|||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
|
@pytest.mark.usefixtures("deterministic_uuids")
|
||||||
def test_prompt_with_chat_model_and_parser(
|
def test_prompt_with_chat_model_and_parser(
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
deterministic_uuids: MockerFixture,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = (
|
prompt = (
|
||||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
@ -2672,10 +2665,9 @@ def test_prompt_with_chat_model_and_parser(
|
|||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
|
@pytest.mark.usefixtures("deterministic_uuids")
|
||||||
def test_combining_sequences(
|
def test_combining_sequences(
|
||||||
mocker: MockerFixture,
|
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
deterministic_uuids: MockerFixture,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = (
|
prompt = (
|
||||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
@ -3513,7 +3505,7 @@ def test_bind_bind() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_bind_with_lambda() -> None:
|
def test_bind_with_lambda() -> None:
|
||||||
def my_function(*args: Any, **kwargs: Any) -> int:
|
def my_function(_: Any, **kwargs: Any) -> int:
|
||||||
return 3 + kwargs.get("n", 0)
|
return 3 + kwargs.get("n", 0)
|
||||||
|
|
||||||
runnable = RunnableLambda(my_function).bind(n=1)
|
runnable = RunnableLambda(my_function).bind(n=1)
|
||||||
@ -3523,7 +3515,7 @@ def test_bind_with_lambda() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_bind_with_lambda_async() -> None:
|
async def test_bind_with_lambda_async() -> None:
|
||||||
def my_function(*args: Any, **kwargs: Any) -> int:
|
def my_function(_: Any, **kwargs: Any) -> int:
|
||||||
return 3 + kwargs.get("n", 0)
|
return 3 + kwargs.get("n", 0)
|
||||||
|
|
||||||
runnable = RunnableLambda(my_function).bind(n=1)
|
runnable = RunnableLambda(my_function).bind(n=1)
|
||||||
@ -3858,7 +3850,7 @@ def test_each(snapshot: SnapshotAssertion) -> None:
|
|||||||
def test_recursive_lambda() -> None:
|
def test_recursive_lambda() -> None:
|
||||||
def _simple_recursion(x: int) -> Union[int, Runnable]:
|
def _simple_recursion(x: int) -> Union[int, Runnable]:
|
||||||
if x < 10:
|
if x < 10:
|
||||||
return RunnableLambda(lambda *args: _simple_recursion(x + 1))
|
return RunnableLambda(lambda *_: _simple_recursion(x + 1))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
runnable = RunnableLambda(_simple_recursion)
|
runnable = RunnableLambda(_simple_recursion)
|
||||||
@ -4008,7 +4000,7 @@ def test_runnable_lambda_stream() -> None:
|
|||||||
# sleep to better simulate a real stream
|
# sleep to better simulate a real stream
|
||||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
output = list(RunnableLambda(lambda x: llm).stream(""))
|
output = list(RunnableLambda(lambda _: llm).stream(""))
|
||||||
assert output == list(llm_res)
|
assert output == list(llm_res)
|
||||||
|
|
||||||
|
|
||||||
@ -4021,7 +4013,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
|
|||||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
config: RunnableConfig = {"callbacks": [tracer]}
|
config: RunnableConfig = {"callbacks": [tracer]}
|
||||||
|
|
||||||
assert list(RunnableLambda(lambda x: llm).stream("", config=config)) == list(
|
assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
|
||||||
llm_res
|
llm_res
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -4029,7 +4021,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
|
|||||||
assert tracer.runs[0].error is None
|
assert tracer.runs[0].error is None
|
||||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||||
|
|
||||||
def raise_value_error(x: int) -> int:
|
def raise_value_error(_: int) -> int:
|
||||||
"""Raise a value error."""
|
"""Raise a value error."""
|
||||||
msg = "x is too large"
|
msg = "x is too large"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -4076,7 +4068,7 @@ async def test_runnable_lambda_astream() -> None:
|
|||||||
_
|
_
|
||||||
async for _ in RunnableLambda(
|
async for _ in RunnableLambda(
|
||||||
func=id,
|
func=id,
|
||||||
afunc=awrapper(lambda x: llm),
|
afunc=awrapper(lambda _: llm),
|
||||||
).astream("")
|
).astream("")
|
||||||
]
|
]
|
||||||
assert output == list(llm_res)
|
assert output == list(llm_res)
|
||||||
@ -4084,7 +4076,7 @@ async def test_runnable_lambda_astream() -> None:
|
|||||||
output = [
|
output = [
|
||||||
chunk
|
chunk
|
||||||
async for chunk in cast(
|
async for chunk in cast(
|
||||||
"AsyncIterator[str]", RunnableLambda(lambda x: llm).astream("")
|
"AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
assert output == list(llm_res)
|
assert output == list(llm_res)
|
||||||
@ -4100,14 +4092,14 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
|
|||||||
config: RunnableConfig = {"callbacks": [tracer]}
|
config: RunnableConfig = {"callbacks": [tracer]}
|
||||||
|
|
||||||
assert [
|
assert [
|
||||||
_ async for _ in RunnableLambda(lambda x: llm).astream("", config=config)
|
_ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
|
||||||
] == list(llm_res)
|
] == list(llm_res)
|
||||||
|
|
||||||
assert len(tracer.runs) == 1
|
assert len(tracer.runs) == 1
|
||||||
assert tracer.runs[0].error is None
|
assert tracer.runs[0].error is None
|
||||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||||
|
|
||||||
def raise_value_error(x: int) -> int:
|
def raise_value_error(_: int) -> int:
|
||||||
"""Raise a value error."""
|
"""Raise a value error."""
|
||||||
msg = "x is too large"
|
msg = "x is too large"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -4487,7 +4479,7 @@ def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
|
|||||||
|
|
||||||
def test_runnable_branch_invoke() -> None:
|
def test_runnable_branch_invoke() -> None:
|
||||||
# Test with single branch
|
# Test with single branch
|
||||||
def raise_value_error(x: int) -> int:
|
def raise_value_error(_: int) -> int:
|
||||||
"""Raise a value error."""
|
"""Raise a value error."""
|
||||||
msg = "x is too large"
|
msg = "x is too large"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -4552,7 +4544,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
|
|||||||
"""Verify that callbacks are correctly used in invoke."""
|
"""Verify that callbacks are correctly used in invoke."""
|
||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
|
|
||||||
def raise_value_error(x: int) -> int:
|
def raise_value_error(_: int) -> int:
|
||||||
"""Raise a value error."""
|
"""Raise a value error."""
|
||||||
msg = "x is too large"
|
msg = "x is too large"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -4580,7 +4572,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
|
|||||||
"""Verify that callbacks are invoked correctly in ainvoke."""
|
"""Verify that callbacks are invoked correctly in ainvoke."""
|
||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
|
|
||||||
async def raise_value_error(x: int) -> int:
|
async def raise_value_error(_: int) -> int:
|
||||||
"""Raise a value error."""
|
"""Raise a value error."""
|
||||||
msg = "x is too large"
|
msg = "x is too large"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -4755,13 +4747,13 @@ def test_representation_of_runnables() -> None:
|
|||||||
runnable = RunnableLambda(lambda x: x * 2)
|
runnable = RunnableLambda(lambda x: x * 2)
|
||||||
assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
|
assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
|
||||||
|
|
||||||
def f(x: int) -> int:
|
def f(_: int) -> int:
|
||||||
"""Return 2."""
|
"""Return 2."""
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)"
|
assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)"
|
||||||
|
|
||||||
async def af(x: int) -> int:
|
async def af(_: int) -> int:
|
||||||
"""Return 2."""
|
"""Return 2."""
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
@ -4814,7 +4806,7 @@ async def test_tool_from_runnable() -> None:
|
|||||||
def test_runnable_gen() -> None:
|
def test_runnable_gen() -> None:
|
||||||
"""Test that a generator can be used as a runnable."""
|
"""Test that a generator can be used as a runnable."""
|
||||||
|
|
||||||
def gen(input: Iterator[Any]) -> Iterator[int]:
|
def gen(_: Iterator[Any]) -> Iterator[int]:
|
||||||
yield 1
|
yield 1
|
||||||
yield 2
|
yield 2
|
||||||
yield 3
|
yield 3
|
||||||
@ -4835,7 +4827,7 @@ def test_runnable_gen() -> None:
|
|||||||
async def test_runnable_gen_async() -> None:
|
async def test_runnable_gen_async() -> None:
|
||||||
"""Test that a generator can be used as a runnable."""
|
"""Test that a generator can be used as a runnable."""
|
||||||
|
|
||||||
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||||
yield 1
|
yield 1
|
||||||
yield 2
|
yield 2
|
||||||
yield 3
|
yield 3
|
||||||
@ -4847,7 +4839,7 @@ async def test_runnable_gen_async() -> None:
|
|||||||
assert await arunnable.abatch([None, None]) == [6, 6]
|
assert await arunnable.abatch([None, None]) == [6, 6]
|
||||||
|
|
||||||
class AsyncGen:
|
class AsyncGen:
|
||||||
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
async def __call__(self, _: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||||
yield 1
|
yield 1
|
||||||
yield 2
|
yield 2
|
||||||
yield 3
|
yield 3
|
||||||
@ -4870,7 +4862,7 @@ def test_runnable_gen_context_config() -> None:
|
|||||||
"""
|
"""
|
||||||
fake = RunnableLambda(len)
|
fake = RunnableLambda(len)
|
||||||
|
|
||||||
def gen(input: Iterator[Any]) -> Iterator[int]:
|
def gen(_: Iterator[Any]) -> Iterator[int]:
|
||||||
yield fake.invoke("a")
|
yield fake.invoke("a")
|
||||||
yield fake.invoke("aa")
|
yield fake.invoke("aa")
|
||||||
yield fake.invoke("aaa")
|
yield fake.invoke("aaa")
|
||||||
@ -4944,7 +4936,7 @@ async def test_runnable_gen_context_config_async() -> None:
|
|||||||
|
|
||||||
fake = RunnableLambda(len)
|
fake = RunnableLambda(len)
|
||||||
|
|
||||||
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||||
yield await fake.ainvoke("a")
|
yield await fake.ainvoke("a")
|
||||||
yield await fake.ainvoke("aa")
|
yield await fake.ainvoke("aa")
|
||||||
yield await fake.ainvoke("aaa")
|
yield await fake.ainvoke("aaa")
|
||||||
@ -5441,6 +5433,7 @@ def test_default_transform_with_dicts() -> None:
|
|||||||
"""Test that default transform works with dicts."""
|
"""Test that default transform works with dicts."""
|
||||||
|
|
||||||
class CustomRunnable(RunnableSerializable[Input, Output]):
|
class CustomRunnable(RunnableSerializable[Input, Output]):
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
@ -5462,6 +5455,7 @@ async def test_default_atransform_with_dicts() -> None:
|
|||||||
"""Test that default transform works with dicts."""
|
"""Test that default transform works with dicts."""
|
||||||
|
|
||||||
class CustomRunnable(RunnableSerializable[Input, Output]):
|
class CustomRunnable(RunnableSerializable[Input, Output]):
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
@ -5581,6 +5575,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
|
|||||||
on_chain_end_triggered = False
|
on_chain_end_triggered = False
|
||||||
|
|
||||||
class MyHandler(BaseCallbackHandler):
|
class MyHandler(BaseCallbackHandler):
|
||||||
|
@override
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
error: BaseException,
|
error: BaseException,
|
||||||
@ -5594,6 +5589,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
|
|||||||
nonlocal on_chain_error_triggered
|
nonlocal on_chain_error_triggered
|
||||||
on_chain_error_triggered = True
|
on_chain_error_triggered = True
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_end(
|
def on_chain_end(
|
||||||
self,
|
self,
|
||||||
outputs: dict[str, Any],
|
outputs: dict[str, Any],
|
||||||
|
@ -8,6 +8,7 @@ from typing import Any, cast
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
@ -82,12 +83,12 @@ def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -
|
|||||||
async def test_event_stream_with_simple_function_tool() -> None:
|
async def test_event_stream_with_simple_function_tool() -> None:
|
||||||
"""Test the event stream with a function and tool."""
|
"""Test the event stream with a function and tool."""
|
||||||
|
|
||||||
def foo(x: int) -> dict:
|
def foo(_: int) -> dict:
|
||||||
"""Foo."""
|
"""Foo."""
|
||||||
return {"x": 5}
|
return {"x": 5}
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_docs(x: int) -> list[Document]:
|
def get_docs(x: int) -> list[Document]: # noqa: ARG001
|
||||||
"""Hello Doc."""
|
"""Hello Doc."""
|
||||||
return [Document(page_content="hello")]
|
return [Document(page_content="hello")]
|
||||||
|
|
||||||
@ -434,7 +435,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_event_stream_with_lambdas_from_lambda() -> None:
|
async def test_event_stream_with_lambdas_from_lambda() -> None:
|
||||||
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
|
as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
|
||||||
{"run_name": "my_lambda"}
|
{"run_name": "my_lambda"}
|
||||||
)
|
)
|
||||||
events = await _collect_events(
|
events = await _collect_events(
|
||||||
@ -1021,7 +1022,7 @@ async def test_event_streaming_with_tools() -> None:
|
|||||||
return "hello"
|
return "hello"
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def with_callbacks(callbacks: Callbacks) -> str:
|
def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
|
||||||
"""A tool that does nothing."""
|
"""A tool that does nothing."""
|
||||||
return "world"
|
return "world"
|
||||||
|
|
||||||
@ -1031,7 +1032,7 @@ async def test_event_streaming_with_tools() -> None:
|
|||||||
return {"x": x, "y": y}
|
return {"x": x, "y": y}
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
|
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
|
||||||
"""A tool that does nothing."""
|
"""A tool that does nothing."""
|
||||||
return {"x": x, "y": y}
|
return {"x": x, "y": y}
|
||||||
|
|
||||||
@ -1180,6 +1181,7 @@ async def test_event_streaming_with_tools() -> None:
|
|||||||
class HardCodedRetriever(BaseRetriever):
|
class HardCodedRetriever(BaseRetriever):
|
||||||
documents: list[Document]
|
documents: list[Document]
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
@ -1592,10 +1594,10 @@ async def test_chain_ordering() -> None:
|
|||||||
async def test_event_stream_with_retry() -> None:
|
async def test_event_stream_with_retry() -> None:
|
||||||
"""Test the event stream with a tool."""
|
"""Test the event stream with a tool."""
|
||||||
|
|
||||||
def success(inputs: str) -> str:
|
def success(_: str) -> str:
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
def fail(inputs: str) -> None:
|
def fail(_: str) -> None:
|
||||||
"""Simple func."""
|
"""Simple func."""
|
||||||
msg = "fail"
|
msg = "fail"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
@ -17,6 +17,7 @@ from typing import (
|
|||||||
import pytest
|
import pytest
|
||||||
from blockbuster import BlockBuster
|
from blockbuster import BlockBuster
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
@ -97,12 +98,12 @@ async def _collect_events(
|
|||||||
async def test_event_stream_with_simple_function_tool() -> None:
|
async def test_event_stream_with_simple_function_tool() -> None:
|
||||||
"""Test the event stream with a function and tool."""
|
"""Test the event stream with a function and tool."""
|
||||||
|
|
||||||
def foo(x: int) -> dict:
|
def foo(x: int) -> dict: # noqa: ARG001
|
||||||
"""Foo."""
|
"""Foo."""
|
||||||
return {"x": 5}
|
return {"x": 5}
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_docs(x: int) -> list[Document]:
|
def get_docs(x: int) -> list[Document]: # noqa: ARG001
|
||||||
"""Hello Doc."""
|
"""Hello Doc."""
|
||||||
return [Document(page_content="hello")]
|
return [Document(page_content="hello")]
|
||||||
|
|
||||||
@ -465,7 +466,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_event_stream_with_lambdas_from_lambda() -> None:
|
async def test_event_stream_with_lambdas_from_lambda() -> None:
|
||||||
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
|
as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
|
||||||
{"run_name": "my_lambda"}
|
{"run_name": "my_lambda"}
|
||||||
)
|
)
|
||||||
events = await _collect_events(
|
events = await _collect_events(
|
||||||
@ -1043,7 +1044,7 @@ async def test_event_streaming_with_tools() -> None:
|
|||||||
return "hello"
|
return "hello"
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def with_callbacks(callbacks: Callbacks) -> str:
|
def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
|
||||||
"""A tool that does nothing."""
|
"""A tool that does nothing."""
|
||||||
return "world"
|
return "world"
|
||||||
|
|
||||||
@ -1053,7 +1054,7 @@ async def test_event_streaming_with_tools() -> None:
|
|||||||
return {"x": x, "y": y}
|
return {"x": x, "y": y}
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
|
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
|
||||||
"""A tool that does nothing."""
|
"""A tool that does nothing."""
|
||||||
return {"x": x, "y": y}
|
return {"x": x, "y": y}
|
||||||
|
|
||||||
@ -1165,6 +1166,7 @@ async def test_event_streaming_with_tools() -> None:
|
|||||||
class HardCodedRetriever(BaseRetriever):
|
class HardCodedRetriever(BaseRetriever):
|
||||||
documents: list[Document]
|
documents: list[Document]
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
@ -1553,10 +1555,10 @@ async def test_chain_ordering() -> None:
|
|||||||
async def test_event_stream_with_retry() -> None:
|
async def test_event_stream_with_retry() -> None:
|
||||||
"""Test the event stream with a tool."""
|
"""Test the event stream with a tool."""
|
||||||
|
|
||||||
def success(inputs: str) -> str:
|
def success(_: str) -> str:
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
def fail(inputs: str) -> None:
|
def fail(_: str) -> None:
|
||||||
"""Simple func."""
|
"""Simple func."""
|
||||||
msg = "fail"
|
msg = "fail"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -2069,6 +2071,7 @@ class StreamingRunnable(Runnable[Input, Output]):
|
|||||||
"""Initialize the runnable."""
|
"""Initialize the runnable."""
|
||||||
self.iterable = iterable
|
self.iterable = iterable
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
@ -2084,6 +2087,7 @@ class StreamingRunnable(Runnable[Input, Output]):
|
|||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
async def astream(
|
async def astream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
@ -2323,7 +2327,7 @@ async def test_bad_parent_ids() -> None:
|
|||||||
async def test_runnable_generator() -> None:
|
async def test_runnable_generator() -> None:
|
||||||
"""Test async events from sync lambda."""
|
"""Test async events from sync lambda."""
|
||||||
|
|
||||||
async def generator(inputs: AsyncIterator[str]) -> AsyncIterator[str]:
|
async def generator(_: AsyncIterator[str]) -> AsyncIterator[str]:
|
||||||
yield "1"
|
yield "1"
|
||||||
yield "2"
|
yield "2"
|
||||||
|
|
||||||
|
@ -375,7 +375,7 @@ class TestRunnableSequenceParallelTraceNesting:
|
|||||||
def test_sync(
|
def test_sync(
|
||||||
self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int]
|
self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int]
|
||||||
) -> None:
|
) -> None:
|
||||||
def other_thing(a: int) -> Generator[int, None, None]: # type: ignore
|
def other_thing(_: int) -> Generator[int, None, None]: # type: ignore
|
||||||
yield 1
|
yield 1
|
||||||
|
|
||||||
parent = self._create_parent(other_thing)
|
parent = self._create_parent(other_thing)
|
||||||
@ -407,7 +407,7 @@ class TestRunnableSequenceParallelTraceNesting:
|
|||||||
[RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int]
|
[RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int]
|
||||||
],
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
async def other_thing(a: int) -> AsyncGenerator[int, None]:
|
async def other_thing(_: int) -> AsyncGenerator[int, None]:
|
||||||
yield 1
|
yield 1
|
||||||
|
|
||||||
parent = self._create_parent(other_thing)
|
parent = self._create_parent(other_thing)
|
||||||
|
@ -2301,7 +2301,7 @@ def test_injected_arg_with_complex_type() -> None:
|
|||||||
self.value = "bar"
|
self.value = "bar"
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
|
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: # noqa: ARG001
|
||||||
"""Tool that has an injected tool arg."""
|
"""Tool that has an injected tool arg."""
|
||||||
return foo.value
|
return foo.value
|
||||||
|
|
||||||
@ -2477,7 +2477,7 @@ def test_simple_tool_args_schema_dict() -> None:
|
|||||||
|
|
||||||
def test_empty_string_tool_call_id() -> None:
|
def test_empty_string_tool_call_id() -> None:
|
||||||
@tool
|
@tool
|
||||||
def foo(x: int) -> str:
|
def foo(x: int) -> str: # noqa: ARG001
|
||||||
"""Foo."""
|
"""Foo."""
|
||||||
return "hi"
|
return "hi"
|
||||||
|
|
||||||
@ -2489,7 +2489,7 @@ def test_empty_string_tool_call_id() -> None:
|
|||||||
def test_tool_decorator_description() -> None:
|
def test_tool_decorator_description() -> None:
|
||||||
# test basic tool
|
# test basic tool
|
||||||
@tool
|
@tool
|
||||||
def foo(x: int) -> str:
|
def foo(x: int) -> str: # noqa: ARG001
|
||||||
"""Foo."""
|
"""Foo."""
|
||||||
return "hi"
|
return "hi"
|
||||||
|
|
||||||
@ -2501,7 +2501,7 @@ def test_tool_decorator_description() -> None:
|
|||||||
|
|
||||||
# test basic tool with description
|
# test basic tool with description
|
||||||
@tool(description="description")
|
@tool(description="description")
|
||||||
def foo_description(x: int) -> str:
|
def foo_description(x: int) -> str: # noqa: ARG001
|
||||||
"""Foo."""
|
"""Foo."""
|
||||||
return "hi"
|
return "hi"
|
||||||
|
|
||||||
@ -2520,7 +2520,7 @@ def test_tool_decorator_description() -> None:
|
|||||||
x: int
|
x: int
|
||||||
|
|
||||||
@tool(args_schema=ArgsSchema)
|
@tool(args_schema=ArgsSchema)
|
||||||
def foo_args_schema(x: int) -> str:
|
def foo_args_schema(x: int) -> str: # noqa: ARG001
|
||||||
return "hi"
|
return "hi"
|
||||||
|
|
||||||
assert foo_args_schema.description == "Bar."
|
assert foo_args_schema.description == "Bar."
|
||||||
@ -2532,7 +2532,7 @@ def test_tool_decorator_description() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@tool(description="description", args_schema=ArgsSchema)
|
@tool(description="description", args_schema=ArgsSchema)
|
||||||
def foo_args_schema_description(x: int) -> str:
|
def foo_args_schema_description(x: int) -> str: # noqa: ARG001
|
||||||
return "hi"
|
return "hi"
|
||||||
|
|
||||||
assert foo_args_schema_description.description == "description"
|
assert foo_args_schema_description.description == "description"
|
||||||
@ -2554,11 +2554,11 @@ def test_tool_decorator_description() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@tool(args_schema=args_json_schema)
|
@tool(args_schema=args_json_schema)
|
||||||
def foo_args_jsons_schema(x: int) -> str:
|
def foo_args_jsons_schema(x: int) -> str: # noqa: ARG001
|
||||||
return "hi"
|
return "hi"
|
||||||
|
|
||||||
@tool(description="description", args_schema=args_json_schema)
|
@tool(description="description", args_schema=args_json_schema)
|
||||||
def foo_args_jsons_schema_with_description(x: int) -> str:
|
def foo_args_jsons_schema_with_description(x: int) -> str: # noqa: ARG001
|
||||||
return "hi"
|
return "hi"
|
||||||
|
|
||||||
assert foo_args_jsons_schema.description == "JSON Schema."
|
assert foo_args_jsons_schema.description == "JSON Schema."
|
||||||
@ -2620,10 +2620,10 @@ def test_title_property_preserved() -> None:
|
|||||||
async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
|
async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
|
||||||
"""Verify that the inputs are not mutated when invoking a tool asynchronously."""
|
"""Verify that the inputs are not mutated when invoking a tool asynchronously."""
|
||||||
|
|
||||||
def sync_no_op(foo: int) -> str:
|
def sync_no_op(foo: int) -> str: # noqa: ARG001
|
||||||
return "good"
|
return "good"
|
||||||
|
|
||||||
async def async_no_op(foo: int) -> str:
|
async def async_no_op(foo: int) -> str: # noqa: ARG001
|
||||||
return "good"
|
return "good"
|
||||||
|
|
||||||
tool = StructuredTool(
|
tool = StructuredTool(
|
||||||
@ -2668,10 +2668,10 @@ async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
|
|||||||
def test_tool_invoke_does_not_mutate_inputs() -> None:
|
def test_tool_invoke_does_not_mutate_inputs() -> None:
|
||||||
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
|
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
|
||||||
|
|
||||||
def sync_no_op(foo: int) -> str:
|
def sync_no_op(foo: int) -> str: # noqa: ARG001
|
||||||
return "good"
|
return "good"
|
||||||
|
|
||||||
async def async_no_op(foo: int) -> str:
|
async def async_no_op(foo: int) -> str: # noqa: ARG001
|
||||||
return "good"
|
return "good"
|
||||||
|
|
||||||
tool = StructuredTool(
|
tool = StructuredTool(
|
||||||
|
@ -121,7 +121,7 @@ def dummy_structured_tool() -> StructuredTool:
|
|||||||
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
||||||
|
|
||||||
return StructuredTool.from_function(
|
return StructuredTool.from_function(
|
||||||
lambda x: None,
|
lambda _: None,
|
||||||
name="dummy_function",
|
name="dummy_function",
|
||||||
description="Dummy function.",
|
description="Dummy function.",
|
||||||
args_schema=Schema,
|
args_schema=Schema,
|
||||||
@ -143,7 +143,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
|||||||
"required": ["arg1", "arg2"],
|
"required": ["arg1", "arg2"],
|
||||||
}
|
}
|
||||||
return StructuredTool.from_function(
|
return StructuredTool.from_function(
|
||||||
lambda x: None,
|
lambda _: None,
|
||||||
name="dummy_function",
|
name="dummy_function",
|
||||||
description="Dummy function.",
|
description="Dummy function.",
|
||||||
args_schema=args_schema,
|
args_schema=args_schema,
|
||||||
|
@ -10,6 +10,7 @@ import uuid
|
|||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings, FakeEmbeddings
|
from langchain_core.embeddings import Embeddings, FakeEmbeddings
|
||||||
@ -25,6 +26,7 @@ class CustomAddTextsVectorstore(VectorStore):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.store: dict[str, Document] = {}
|
self.store: dict[str, Document] = {}
|
||||||
|
|
||||||
|
@override
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -51,6 +53,7 @@ class CustomAddTextsVectorstore(VectorStore):
|
|||||||
return [self.store[id] for id in ids if id in self.store]
|
return [self.store[id] for id in ids if id in self.store]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def from_texts( # type: ignore
|
def from_texts( # type: ignore
|
||||||
cls,
|
cls,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
@ -74,6 +77,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.store: dict[str, Document] = {}
|
self.store: dict[str, Document] = {}
|
||||||
|
|
||||||
|
@override
|
||||||
def add_documents(
|
def add_documents(
|
||||||
self,
|
self,
|
||||||
documents: list[Document],
|
documents: list[Document],
|
||||||
@ -95,6 +99,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
|
|||||||
return [self.store[id] for id in ids if id in self.store]
|
return [self.store[id] for id in ids if id in self.store]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def from_texts( # type: ignore
|
def from_texts( # type: ignore
|
||||||
cls,
|
cls,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
|
Loading…
Reference in New Issue
Block a user