Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-28 09:28:48 +00:00
core: Add various ruff rules (#26836)
Adds the following rule sets: ASYNC, COM, DJ, EXE, FLY, FURB, ICN, INT, LOG, NPY, PD, Q, RSE, SLOT, T10, TID, YTT.

Co-authored-by: Erick Friis <erick@langchain.dev>
parent 5c826faece
commit 16f5fdb38b
@@ -53,7 +53,7 @@ class FakeEmbeddings(Embeddings, BaseModel):
     def _get_embedding(self) -> list[float]:
         import numpy as np  # type: ignore[import-not-found, import-untyped]
 
-        return list(np.random.normal(size=self.size))
+        return list(np.random.default_rng().normal(size=self.size))
 
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         return [self._get_embedding() for _ in texts]
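For context: NPY002 flags the legacy global `np.random` functions in favor of the Generator API, which is the rewrite applied above. A minimal illustrative sketch (not taken from the commit; assumes NumPy is installed):

```python
import numpy as np

# Legacy module-level API: draws from a hidden global RandomState (what NPY002 flags).
legacy_sample = np.random.normal(size=3)

# Generator API: an explicit generator object, no shared global state.
rng = np.random.default_rng()
modern_sample = rng.normal(size=3)

print(legacy_sample, modern_sample)
```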
@@ -109,8 +109,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
         import numpy as np  # type: ignore[import-not-found, import-untyped]
 
         # set the seed for the random generator
-        np.random.seed(seed)
-        return list(np.random.normal(size=self.size))
+        rng = np.random.default_rng(seed)
+        return list(rng.normal(size=self.size))
 
     def _get_seed(self, text: str) -> int:
         """Get a seed for the random generator, using the hash of the text."""
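The deterministic variant passes the per-text seed straight to `default_rng`, which keeps the embedding reproducible without mutating global state. A small sketch of why that holds (illustrative only):

```python
import numpy as np

# Two generators built from the same seed yield identical sequences, so a
# seeded local Generator can replace the old np.random.seed(...) call.
rng_a = np.random.default_rng(seed=42)
rng_b = np.random.default_rng(seed=42)
assert list(rng_a.normal(size=4)) == list(rng_b.normal(size=4))
```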
@@ -237,7 +237,7 @@ class BaseLanguageModel(
         """Not implemented on this class."""
         # Implement this on child class if there is a way of steering the model to
         # generate responses that match a given schema.
-        raise NotImplementedError()
+        raise NotImplementedError
 
     @deprecated("0.1.7", alternative="invoke", removal="1.0")
     @abstractmethod
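The many `raise NotImplementedError()` → `raise NotImplementedError` rewrites throughout this diff come from RSE102 (unnecessary parentheses on a raised exception). Both forms raise an instance of the class; a minimal sketch:

```python
# Python instantiates the exception class automatically when a bare class is
# raised, so these two functions behave identically.
def old_style() -> None:
    raise NotImplementedError()


def new_style() -> None:
    raise NotImplementedError
```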
@@ -977,7 +977,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        raise NotImplementedError()
+        raise NotImplementedError
 
     async def _astream(
         self,
@@ -1112,7 +1112,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         ],
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
-        raise NotImplementedError()
+        raise NotImplementedError
 
     def with_structured_output(
         self,
@@ -698,7 +698,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         Returns:
             An iterator of GenerationChunks.
         """
-        raise NotImplementedError()
+        raise NotImplementedError
 
     async def _astream(
         self,
@@ -151,7 +151,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
         Returns:
             The parsed JSON object.
         """
-        raise NotImplementedError()
+        raise NotImplementedError
 
 
 class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
@@ -207,7 +207,7 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
         Returns:
             The parsed tool calls.
         """
-        raise NotImplementedError()
+        raise NotImplementedError
 
 
 class JsonOutputKeyToolsParser(JsonOutputToolsParser):
@@ -106,7 +106,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
         Returns:
             The diff between the previous and current parsed output.
         """
-        raise NotImplementedError()
+        raise NotImplementedError
 
     def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
         prev_parsed = None
@@ -1336,7 +1336,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
         Args:
             file_path: path to file.
         """
-        raise NotImplementedError()
+        raise NotImplementedError
 
     def pretty_repr(self, html: bool = False) -> str:
         """Human-readable representation.
@@ -464,4 +464,4 @@ class FewShotChatMessagePromptTemplate(
         Returns:
             A pretty representation of the prompt template.
         """
-        raise NotImplementedError()
+        raise NotImplementedError
@@ -132,4 +132,4 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
         Returns:
             A pretty representation of the prompt.
         """
-        raise NotImplementedError()
+        raise NotImplementedError
@@ -1,3 +1,4 @@
+import asyncio
 import base64
 import re
 from dataclasses import asdict
@@ -289,9 +290,14 @@ async def _render_mermaid_using_pyppeteer(
     img_bytes = await page.screenshot({"fullPage": False})
     await browser.close()
 
+    def write_to_file(path: str, bytes: bytes) -> None:
+        with open(path, "wb") as file:
+            file.write(bytes)
+
     if output_file_path is not None:
-        with open(output_file_path, "wb") as file:
-            file.write(img_bytes)
+        await asyncio.get_event_loop().run_in_executor(
+            None, write_to_file, output_file_path, img_bytes
+        )
 
     return img_bytes
 
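The ASYNC checks flag blocking file I/O inside a coroutine, which is why the write moves into a small helper handed to the default executor. A self-contained sketch of the pattern (hypothetical function names, not the library's code):

```python
import asyncio


def _write_bytes(path: str, data: bytes) -> None:
    # Ordinary blocking write; safe here because it runs in a worker thread.
    with open(path, "wb") as file:
        file.write(data)


async def save_bytes(path: str, data: bytes) -> None:
    # Hand the blocking call to the default ThreadPoolExecutor so the event
    # loop is never stalled on disk I/O.
    await asyncio.get_event_loop().run_in_executor(None, _write_bytes, path, data)


asyncio.run(save_bytes("out.bin", b"\x00\x01"))
```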
@@ -453,7 +453,7 @@ def secret_from_env(
             return SecretStr(os.environ[key])
         if isinstance(default, str):
             return SecretStr(default)
-        elif isinstance(default, type(None)):
+        elif default is None:
             return None
         else:
             if error_message:
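The `isinstance(default, type(None))` → `default is None` rewrite is the idiomatic None check; ruff's FURB168 targets exactly this pattern, though which rule fired here is my assumption. The two checks are equivalent:

```python
default = None

# isinstance against type(None) works, but identity comparison states the
# intent directly and is what the linter suggests.
assert isinstance(default, type(None))
assert default is None
```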
@@ -44,8 +44,42 @@ python = ">=3.12.4"
 [tool.poetry.extras]
 
 [tool.ruff.lint]
-select = [ "B", "C4", "E", "EM", "F", "I", "N", "PIE", "SIM", "T201", "UP", "W",]
-ignore = [ "UP007", "W293",]
+select = [
+    "ASYNC",
+    "B",
+    "C4",
+    "COM",
+    "DJ",
+    "E",
+    "EM",
+    "EXE",
+    "F",
+    "FLY",
+    "FURB",
+    "I",
+    "ICN",
+    "INT",
+    "LOG",
+    "N",
+    "NPY",
+    "PD",
+    "PIE",
+    "Q",
+    "RSE",
+    "SIM",
+    "SLOT",
+    "T10",
+    "T201",
+    "TID",
+    "UP",
+    "W",
+    "YTT"
+]
+ignore = [
+    "COM812", # Messes with the formatter
+    "UP007", # Incompatible with pydantic + Python 3.9
+    "W293", #
+]
 
 [tool.coverage.run]
 omit = [ "tests/*",]
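To give a feel for what some of the newly selected families catch, here is a tiny hypothetical module several of them would flag (rule codes are my reading of the ruff docs, not taken from the commit):

```python
import pdb  # T100 (flake8-debugger): debugger imports/calls left in code


def report(values: list[int]) -> None:
    pdb.set_trace()     # T100: forgotten breakpoint
    print(sum(values))  # T201: print statement in library code
```

COM812 is kept in the ignore list because trailing-comma enforcement conflicts with the formatter, and UP007 stays ignored for pydantic/Python 3.9 compatibility, as the inline comments note.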
@@ -17,7 +17,7 @@ def test_add_message_implementation_only() -> None:
 
         def clear(self) -> None:
             """Clear the store."""
-            raise NotImplementedError()
+            raise NotImplementedError
 
     store: list[BaseMessage] = []
     chat_history = SampleChatHistory(store=store)
@@ -50,7 +50,7 @@ def test_bulk_message_implementation_only() -> None:
 
         def clear(self) -> None:
             """Clear the store."""
-            raise NotImplementedError()
+            raise NotImplementedError
 
     chat_history = BulkAddHistory(store=store)
     chat_history.add_message(HumanMessage(content="Hello"))
@@ -165,7 +165,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
             **kwargs: Any,
         ) -> ChatResult:
             """Top Level call"""
-            raise NotImplementedError()
+            raise NotImplementedError
 
         def _stream(
             self,
@@ -210,7 +210,7 @@ async def test_astream_implementation_uses_astream() -> None:
             **kwargs: Any,
         ) -> ChatResult:
             """Top Level call"""
-            raise NotImplementedError()
+            raise NotImplementedError
 
         async def _astream(  # type: ignore
             self,
@@ -161,7 +161,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
             **kwargs: Any,
         ) -> LLMResult:
             """Top Level call"""
-            raise NotImplementedError()
+            raise NotImplementedError
 
         def _stream(
             self,
@@ -198,7 +198,7 @@ async def test_astream_implementation_uses_astream() -> None:
             **kwargs: Any,
         ) -> LLMResult:
             """Top Level call"""
-            raise NotImplementedError()
+            raise NotImplementedError
 
         async def _astream(
             self,
@@ -59,7 +59,7 @@ def test_base_transform_output_parser() -> None:
 
         def parse(self, text: str) -> str:
             """Parse a single string into a specific format."""
-            raise NotImplementedError()
+            raise NotImplementedError
 
         def parse_result(
             self, result: list[Generation], *, partial: bool = False
@@ -61,13 +61,13 @@ def chain() -> Runnable:
 
 
 def _raise_error(inputs: dict) -> str:
-    raise ValueError()
+    raise ValueError
 
 
 def _dont_raise_error(inputs: dict) -> str:
     if "exception" in inputs:
         return "bar"
-    raise ValueError()
+    raise ValueError
 
 
 @pytest.fixture()
@@ -99,11 +99,11 @@ def _runnable(inputs: dict) -> str:
     if inputs["text"] == "foo":
         return "first"
     if "exception" not in inputs:
-        raise ValueError()
+        raise ValueError
     if inputs["text"] == "bar":
         return "second"
     if isinstance(inputs["exception"], ValueError):
-        raise RuntimeError()
+        raise RuntimeError
     return "third"
 
 
@@ -251,13 +251,13 @@ def _generate(input: Iterator) -> Iterator[str]:
 
 
 def _generate_immediate_error(input: Iterator) -> Iterator[str]:
-    raise ValueError()
+    raise ValueError
     yield ""
 
 
 def _generate_delayed_error(input: Iterator) -> Iterator[str]:
     yield ""
-    raise ValueError()
+    raise ValueError
 
 
 def test_fallbacks_stream() -> None:
@@ -279,13 +279,13 @@ async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
 
 
 async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
-    raise ValueError()
+    raise ValueError
    yield ""
 
 
 async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
     yield ""
-    raise ValueError()
+    raise ValueError
 
 
 async def test_fallbacks_astream() -> None:
@@ -356,7 +356,7 @@ def test_runnable_get_graph_with_invalid_input_type() -> None:
         @property
         @override
         def InputType(self) -> type:
-            raise TypeError()
+            raise TypeError
 
         @override
         def invoke(
@@ -381,7 +381,7 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
         @property
         @override
         def OutputType(self) -> type:
-            raise TypeError()
+            raise TypeError
 
         @override
         def invoke(
@@ -653,7 +653,7 @@ def test_with_types_with_type_generics() -> None:
 
     def foo(x: int) -> None:
         """Add one to the input."""
-        raise NotImplementedError()
+        raise NotImplementedError
 
     # Try specifying some
     RunnableLambda(foo).with_types(
@@ -3980,7 +3980,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
         def invoke(
             self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
         ) -> Any:
-            raise NotImplementedError()
+            raise NotImplementedError
 
         def _batch(
             self,
@@ -4101,7 +4101,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
         def invoke(
             self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
         ) -> Any:
-            raise NotImplementedError()
+            raise NotImplementedError
 
         async def _abatch(
             self,
@@ -5352,7 +5352,7 @@ async def test_listeners_async() -> None:
     assert value2 in shared_state.values(), "Value not found in the dictionary."
 
 
-async def test_closing_iterator_doesnt_raise_error() -> None:
+def test_closing_iterator_doesnt_raise_error() -> None:
     """Test that closing an iterator calls on_chain_end rather than on_chain_error."""
     import time
 
@@ -5361,9 +5361,10 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
     from langchain_core.output_parsers import StrOutputParser
 
     on_chain_error_triggered = False
+    on_chain_end_triggered = False
 
     class MyHandler(BaseCallbackHandler):
-        async def on_chain_error(
+        def on_chain_error(
             self,
             error: BaseException,
             *,
@@ -5376,6 +5377,17 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
             nonlocal on_chain_error_triggered
             on_chain_error_triggered = True
 
+        def on_chain_end(
+            self,
+            outputs: dict[str, Any],
+            *,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            **kwargs: Any,
+        ) -> None:
+            nonlocal on_chain_end_triggered
+            on_chain_end_triggered = True
+
     llm = GenericFakeChatModel(messages=iter(["hi there"]))
     chain = llm | StrOutputParser()
     chain_ = chain.with_config({"callbacks": [MyHandler()]})
@@ -5386,6 +5398,7 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
     # Wait for a bit to make sure that the callback is called.
     time.sleep(0.05)
     assert on_chain_error_triggered is False
+    assert on_chain_end_triggered is True
 
 
 def test_pydantic_protected_namespaces() -> None:
@@ -2067,7 +2067,7 @@ class StreamingRunnable(Runnable[Input, Output]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> Iterator[Output]:
-        raise NotImplementedError()
+        raise NotImplementedError
 
     async def astream(
         self,
@@ -19,7 +19,7 @@ from langchain_core.runnables.utils import (
     [
         (lambda x: x * 2, "lambda x: x * 2"),
         (lambda a, b: a + b, "lambda a, b: a + b"),
-        (lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),
+        (lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),  # noqa: FURB136
     ],
 )
 def test_get_lambda_source(func: Callable, expected_source: str) -> None:
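The noqa added above silences FURB136, which suggests rewriting an if-expression that picks the larger of two values as `max()`; the test needs the literal lambda text, so the original form is kept. A sketch of the suggested rewrite (illustrative only):

```python
def clamp_if_expr(x: int) -> int:
    return x if x > 0 else 0  # the pattern FURB136 flags


def clamp_max(x: int) -> int:
    return max(x, 0)  # the rewrite FURB136 suggests


assert clamp_if_expr(-5) == clamp_max(-5) == 0
assert clamp_if_expr(7) == clamp_max(7) == 7
```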
@@ -5,6 +5,8 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
 
 
 class AnyStr(str):
+    __slots__ = ()
+
     def __eq__(self, other: Any) -> bool:
         return isinstance(other, str)
 
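The `__slots__ = ()` addition satisfies the SLOT checks (SLOT000 asks `str` subclasses to declare empty slots so instances do not grow a per-object `__dict__`). A minimal sketch of the effect (illustrative only):

```python
class PlainStr(str):
    pass


class SlottedStr(str):
    __slots__ = ()


# The un-slotted subclass carries an instance __dict__; the slotted one does not.
assert hasattr(PlainStr("x"), "__dict__")
assert not hasattr(SlottedStr("x"), "__dict__")
```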
@@ -346,7 +346,7 @@ class TestGetBufferString(unittest.TestCase):
             self.chat_msg,
             self.tool_calls_msg,
         ]
-        expected_output = "\n".join(
+        expected_output = "\n".join(  # noqa: FLY002
             [
                 "Human: human",
                 "AI: ai",
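FLY002 (static-join-to-f-string) points out that joining nothing but string literals can be written as a single literal; the test keeps the join for readability, hence the suppression. The equivalence, as a sketch:

```python
# What FLY002 flags: a join over literal strings only...
joined = "\n".join(["Human: human", "AI: ai"])

# ...is the same as writing the combined literal directly.
assert joined == "Human: human\nAI: ai"
```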
@@ -401,7 +401,7 @@ def test_structured_tool_from_function_docstring() -> None:
             bar: the bar value
             baz: the baz value
         """
-        raise NotImplementedError()
+        raise NotImplementedError
 
     structured_tool = StructuredTool.from_function(foo)
     assert structured_tool.name == "foo"
@@ -435,7 +435,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
             bar: int
             baz: List[str]
         """
-        raise NotImplementedError()
+        raise NotImplementedError
 
     structured_tool = StructuredTool.from_function(foo)
     assert structured_tool.name == "foo"
@@ -781,7 +781,7 @@ def test_structured_tool_from_function() -> None:
             bar: the bar value
             baz: the baz value
         """
-        raise NotImplementedError()
+        raise NotImplementedError
 
     structured_tool = StructuredTool.from_function(foo)
     assert structured_tool.name == "foo"
@@ -854,7 +854,7 @@ def test_validation_error_handling_non_validation_error(
             self,
             tool_input: Union[str, dict],
         ) -> Union[str, dict[str, Any]]:
-            raise NotImplementedError()
+            raise NotImplementedError
 
         def _run(self) -> str:
             return "dummy"
@@ -916,7 +916,7 @@ async def test_async_validation_error_handling_non_validation_error(
             self,
             tool_input: Union[str, dict],
         ) -> Union[str, dict[str, Any]]:
-            raise NotImplementedError()
+            raise NotImplementedError
 
         def _run(self) -> str:
             return "dummy"
@@ -39,7 +39,7 @@ async def test_inmemory_similarity_search() -> None:
     output = await store.asimilarity_search("bar", k=2)
     assert output == [
         _any_id_document(page_content="bar"),
-        _any_id_document(page_content="baz"),
+        _any_id_document(page_content="foo"),
     ]
 
 
@@ -81,7 +81,7 @@ async def test_inmemory_mmr() -> None:
     output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
     assert len(output) == len(texts)
     assert output[0] == _any_id_document(page_content="foo")
-    assert output[1] == _any_id_document(page_content="foy")
+    assert output[1] == _any_id_document(page_content="fou")
 
     # Check async version
     output = await docsearch.amax_marginal_relevance_search(
@@ -89,7 +89,7 @@ async def test_inmemory_mmr() -> None:
     )
     assert len(output) == len(texts)
     assert output[0] == _any_id_document(page_content="foo")
-    assert output[1] == _any_id_document(page_content="foy")
+    assert output[1] == _any_id_document(page_content="fou")
 
 
 async def test_inmemory_dump_load(tmp_path: Path) -> None:
@@ -63,7 +63,7 @@ class CustomAddTextsVectorstore(VectorStore):
     def similarity_search(
         self, query: str, k: int = 4, **kwargs: Any
     ) -> list[Document]:
-        raise NotImplementedError()
+        raise NotImplementedError
 
 
 class CustomAddDocumentsVectorstore(VectorStore):
@@ -107,7 +107,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
     def similarity_search(
         self, query: str, k: int = 4, **kwargs: Any
     ) -> list[Document]:
-        raise NotImplementedError()
+        raise NotImplementedError
 
 
 @pytest.mark.parametrize(