langchain: Add ruff rules A (#31888)

See https://docs.astral.sh/ruff/rules/#flake8-builtins-a

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet 2025-07-07 16:20:27 +02:00 committed by GitHub
parent 49c316667d
commit f06380516f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 147 additions and 102 deletions

View File

@ -41,7 +41,7 @@ from langchain_core.runnables.utils import AddableDict
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self
from typing_extensions import Self, override
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent_iterator import AgentExecutorIterator
@ -1749,6 +1749,7 @@ class AgentExecutor(Chain):
else:
return intermediate_steps
@override
def stream(
self,
input: Union[dict[str, Any], Any],
@ -1779,6 +1780,7 @@ class AgentExecutor(Chain):
)
yield from iterator
@override
async def astream(
self,
input: Union[dict[str, Any], Any],

View File

@ -20,7 +20,7 @@ from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensur
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from typing_extensions import Self, override
if TYPE_CHECKING:
import openai
@ -272,6 +272,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
return cls(assistant_id=assistant.id, client=client, **kwargs)
@override
def invoke(
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> OutputType:
@ -399,6 +400,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)
@override
async def ainvoke(
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> OutputType:
@ -515,10 +517,10 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
}
return submit_tool_outputs
def _create_run(self, input: dict) -> Any:
def _create_run(self, input_dict: dict) -> Any:
params = {
k: v
for k, v in input.items()
for k, v in input_dict.items()
if k
in (
"instructions",
@ -534,15 +536,15 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
}
return self.client.beta.threads.runs.create(
input["thread_id"],
input_dict["thread_id"],
assistant_id=self.assistant_id,
**params,
)
def _create_thread_and_run(self, input: dict, thread: dict) -> Any:
def _create_thread_and_run(self, input_dict: dict, thread: dict) -> Any:
params = {
k: v
for k, v in input.items()
for k, v in input_dict.items()
if k
in (
"instructions",
@ -673,10 +675,10 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
}
return submit_tool_outputs
async def _acreate_run(self, input: dict) -> Any:
async def _acreate_run(self, input_dict: dict) -> Any:
params = {
k: v
for k, v in input.items()
for k, v in input_dict.items()
if k
in (
"instructions",
@ -692,15 +694,15 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
}
return await self.async_client.beta.threads.runs.create(
input["thread_id"],
input_dict["thread_id"],
assistant_id=self.assistant_id,
**params,
)
async def _acreate_thread_and_run(self, input: dict, thread: dict) -> Any:
async def _acreate_thread_and_run(self, input_dict: dict, thread: dict) -> Any:
params = {
k: v
for k, v in input.items()
for k, v in input_dict.items()
if k
in (
"instructions",

View File

@ -35,6 +35,7 @@ from pydantic import (
field_validator,
model_validator,
)
from typing_extensions import override
from langchain.schema import RUN_KEY
@ -118,6 +119,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
# This is correct, but pydantic typings/mypy don't think so.
return create_model("ChainOutput", **{k: (Any, None) for k in self.output_keys})
@override
def invoke(
self,
input: dict[str, Any],
@ -171,6 +173,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
@override
async def ainvoke(
self,
input: dict[str, Any],

View File

@ -88,9 +88,11 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
if fix_invalid:
def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
filter = cast(Optional[FilterDirective], get_parser().parse(raw_filter))
filter_directive = cast(
Optional[FilterDirective], get_parser().parse(raw_filter)
)
fixed = fix_filter_directive(
filter,
filter_directive,
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
allowed_attributes=allowed_attributes,
@ -107,7 +109,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
def fix_filter_directive(
filter: Optional[FilterDirective],
filter: Optional[FilterDirective], # noqa: A002
*,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,

View File

@ -28,7 +28,7 @@ from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
from pydantic import BaseModel
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, override
__all__ = [
"init_chat_model",
@ -673,6 +673,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
list[AnyMessage],
]
@override
def invoke(
self,
input: LanguageModelInput,
@ -681,6 +682,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) -> Any:
return self._model(config).invoke(input, config=config, **kwargs)
@override
async def ainvoke(
self,
input: LanguageModelInput,
@ -689,6 +691,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) -> Any:
return await self._model(config).ainvoke(input, config=config, **kwargs)
@override
def stream(
self,
input: LanguageModelInput,
@ -697,6 +700,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) -> Iterator[Any]:
yield from self._model(config).stream(input, config=config, **kwargs)
@override
async def astream(
self,
input: LanguageModelInput,
@ -802,6 +806,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
):
yield x
@override
def transform(
self,
input: Iterator[LanguageModelInput],
@ -810,6 +815,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) -> Iterator[Any]:
yield from self._model(config).transform(input, config=config, **kwargs)
@override
async def atransform(
self,
input: AsyncIterator[LanguageModelInput],
@ -853,6 +859,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
**kwargs: Any,
) -> AsyncIterator[RunLog]: ...
@override
async def astream_log(
self,
input: Any,
@ -883,6 +890,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
):
yield x
@override
async def astream_events(
self,
input: Any,

View File

@ -27,6 +27,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.tools import BaseTool
from pydantic import ConfigDict, Field
from typing_extensions import override
from langchain.chains.llm import LLMChain
from langchain.evaluation.agents.trajectory_eval_prompt import (
@ -326,6 +327,7 @@ The following is the expected answer. Use this to measure correctness:
)
return cast(dict, self.output_parser.parse(raw_output))
@override
def _evaluate_agent_trajectory(
self,
*,
@ -368,6 +370,7 @@ The following is the expected answer. Use this to measure correctness:
return_only_outputs=True,
)
@override
async def _aevaluate_agent_trajectory(
self,
*,

View File

@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field
from typing_extensions import override
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
@ -278,7 +279,7 @@ Performance may be significantly worse with other models."
self,
prediction: str,
prediction_b: str,
input: Optional[str],
input_: Optional[str],
reference: Optional[str],
) -> dict:
"""Prepare the input for the chain.
@ -286,21 +287,21 @@ Performance may be significantly worse with other models."
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
input (str, optional): The input or task string.
input_ (str, optional): The input or task string.
reference (str, optional): The reference string, if any.
Returns:
dict: The prepared input for the chain.
"""
input_ = {
input_dict = {
"prediction": prediction,
"prediction_b": prediction_b,
"input": input,
"input": input_,
}
if self.requires_reference:
input_["reference"] = reference
return input_
input_dict["reference"] = reference
return input_dict
def _prepare_output(self, result: dict) -> dict:
"""Prepare the output."""
@ -309,6 +310,7 @@ Performance may be significantly worse with other models."
parsed[RUN_KEY] = result[RUN_KEY]
return parsed
@override
def _evaluate_string_pairs(
self,
*,
@ -351,6 +353,7 @@ Performance may be significantly worse with other models."
)
return self._prepare_output(result)
@override
async def _aevaluate_string_pairs(
self,
*,

View File

@ -10,6 +10,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate
from pydantic import ConfigDict, Field
from typing_extensions import override
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
@ -383,16 +384,16 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
self,
prediction: str,
reference: Optional[str],
input: Optional[str],
input_: Optional[str],
) -> dict:
"""Get the evaluation input."""
input_ = {
"input": input,
input_dict = {
"input": input_,
"output": prediction,
}
if self.requires_reference:
input_["reference"] = reference
return input_
input_dict["reference"] = reference
return input_dict
def _prepare_output(self, result: dict) -> dict:
"""Prepare the output."""
@ -401,6 +402,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
parsed[RUN_KEY] = result[RUN_KEY]
return parsed
@override
def _evaluate_strings(
self,
*,
@ -456,6 +458,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
)
return self._prepare_output(result)
@override
async def _aevaluate_strings(
self,
*,

View File

@ -5,6 +5,7 @@ from operator import eq
from typing import Any, Callable, Optional, Union, cast
from langchain_core.utils.json import parse_json_markdown
from typing_extensions import override
from langchain.evaluation.schema import StringEvaluator
@ -49,6 +50,7 @@ class JsonValidityEvaluator(StringEvaluator):
def evaluation_name(self) -> str:
return "json_validity"
@override
def _evaluate_strings(
self,
prediction: str,
@ -132,6 +134,7 @@ class JsonEqualityEvaluator(StringEvaluator):
return parse_json_markdown(string)
return string
@override
def _evaluate_strings(
self,
prediction: str,

View File

@ -2,6 +2,7 @@ import json
from typing import Any, Callable, Optional, Union
from langchain_core.utils.json import parse_json_markdown
from typing_extensions import override
from langchain.evaluation.schema import StringEvaluator
@ -82,6 +83,7 @@ class JsonEditDistanceEvaluator(StringEvaluator):
return parse_json_markdown(node)
return node
@override
def _evaluate_strings(
self,
prediction: str,

View File

@ -1,6 +1,7 @@
from typing import Any, Union
from langchain_core.utils.json import parse_json_markdown
from typing_extensions import override
from langchain.evaluation.schema import StringEvaluator
@ -85,6 +86,7 @@ class JsonSchemaEvaluator(StringEvaluator):
except ValidationError as e:
return {"score": False, "reasoning": repr(e)}
@override
def _evaluate_strings(
self,
prediction: Union[str, Any],

View File

@ -11,6 +11,7 @@ from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from pydantic import ConfigDict
from typing_extensions import override
from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
@ -154,6 +155,7 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
parsed_result[RUN_KEY] = result[RUN_KEY]
return parsed_result
@override
def _evaluate_strings(
self,
*,
@ -189,6 +191,7 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
)
return self._prepare_output(result)
@override
async def _aevaluate_strings(
self,
*,
@ -296,6 +299,7 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
parsed_result[RUN_KEY] = result[RUN_KEY]
return parsed_result
@override
def _evaluate_strings(
self,
*,
@ -317,6 +321,7 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
)
return self._prepare_output(result)
@override
async def _aevaluate_strings(
self,
*,

View File

@ -109,13 +109,13 @@ class _EvalArgsMixin:
def _check_evaluation_args(
self,
reference: Optional[str] = None,
input: Optional[str] = None,
input_: Optional[str] = None,
) -> None:
"""Check if the evaluation arguments are valid.
Args:
reference (Optional[str], optional): The reference label.
input (Optional[str], optional): The input string.
input_ (Optional[str], optional): The input string.
Raises:
ValueError: If the evaluator requires an input string but none is provided,
or if the evaluator requires a reference label but none is provided.
@ -152,7 +152,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
*,
prediction: Union[str, Any],
reference: Optional[Union[str, Any]] = None,
input: Optional[Union[str, Any]] = None,
input: Optional[Union[str, Any]] = None, # noqa: A002
**kwargs: Any,
) -> dict:
"""Evaluate Chain or LLM output, based on optional input and label.
@ -175,7 +175,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
*,
prediction: Union[str, Any],
reference: Optional[Union[str, Any]] = None,
input: Optional[Union[str, Any]] = None,
input: Optional[Union[str, Any]] = None, # noqa: A002
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.
@ -206,7 +206,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
reference: Optional[str] = None,
input: Optional[str] = None,
input: Optional[str] = None, # noqa: A002
**kwargs: Any,
) -> dict:
"""Evaluate Chain or LLM output, based on optional input and label.
@ -219,7 +219,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation results containing the score or value.
""" # noqa: E501
self._check_evaluation_args(reference=reference, input=input)
self._check_evaluation_args(reference=reference, input_=input)
return self._evaluate_strings(
prediction=prediction, reference=reference, input=input, **kwargs
)
@ -229,7 +229,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
reference: Optional[str] = None,
input: Optional[str] = None,
input: Optional[str] = None, # noqa: A002
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.
@ -242,7 +242,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation results containing the score or value.
""" # noqa: E501
self._check_evaluation_args(reference=reference, input=input)
self._check_evaluation_args(reference=reference, input_=input)
return await self._aevaluate_strings(
prediction=prediction, reference=reference, input=input, **kwargs
)
@ -258,7 +258,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
input: Optional[str] = None,
input: Optional[str] = None, # noqa: A002
**kwargs: Any,
) -> dict:
"""Evaluate the output string pairs.
@ -279,7 +279,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
input: Optional[str] = None,
input: Optional[str] = None, # noqa: A002
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate the output string pairs.
@ -309,7 +309,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
input: Optional[str] = None,
input: Optional[str] = None, # noqa: A002
**kwargs: Any,
) -> dict:
"""Evaluate the output string pairs.
@ -323,7 +323,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501
self._check_evaluation_args(reference=reference, input=input)
self._check_evaluation_args(reference=reference, input_=input)
return self._evaluate_string_pairs(
prediction=prediction,
prediction_b=prediction_b,
@ -338,7 +338,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
input: Optional[str] = None,
input: Optional[str] = None, # noqa: A002
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate the output string pairs.
@ -352,7 +352,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501
self._check_evaluation_args(reference=reference, input=input)
self._check_evaluation_args(reference=reference, input_=input)
return await self._aevaluate_string_pairs(
prediction=prediction,
prediction_b=prediction_b,
@ -376,7 +376,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
agent_trajectory: Sequence[tuple[AgentAction, str]],
input: str,
input: str, # noqa: A002
reference: Optional[str] = None,
**kwargs: Any,
) -> dict:
@ -398,7 +398,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
agent_trajectory: Sequence[tuple[AgentAction, str]],
input: str,
input: str, # noqa: A002
reference: Optional[str] = None,
**kwargs: Any,
) -> dict:
@ -429,7 +429,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
agent_trajectory: Sequence[tuple[AgentAction, str]],
input: str,
input: str, # noqa: A002
reference: Optional[str] = None,
**kwargs: Any,
) -> dict:
@ -445,7 +445,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation result.
"""
self._check_evaluation_args(reference=reference, input=input)
self._check_evaluation_args(reference=reference, input_=input)
return self._evaluate_agent_trajectory(
prediction=prediction,
input=input,
@ -459,7 +459,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
agent_trajectory: Sequence[tuple[AgentAction, str]],
input: str,
input: str, # noqa: A002
reference: Optional[str] = None,
**kwargs: Any,
) -> dict:
@ -475,7 +475,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation result.
"""
self._check_evaluation_args(reference=reference, input=input)
self._check_evaluation_args(reference=reference, input_=input)
return await self._aevaluate_agent_trajectory(
prediction=prediction,
input=input,

View File

@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field
from typing_extensions import override
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
@ -292,7 +293,7 @@ Performance may be significantly worse with other models."
def _prepare_input(
self,
prediction: str,
input: Optional[str],
input_: Optional[str],
reference: Optional[str],
) -> dict:
"""Prepare the input for the chain.
@ -300,20 +301,20 @@ Performance may be significantly worse with other models."
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
input (str, optional): The input or task string.
input_ (str, optional): The input or task string.
reference (str, optional): The reference string, if any.
Returns:
dict: The prepared input for the chain.
"""
input_ = {
input_dict = {
"prediction": prediction,
"input": input,
"input": input_,
}
if self.requires_reference:
input_["reference"] = reference
return input_
input_dict["reference"] = reference
return input_dict
def _prepare_output(self, result: dict) -> dict:
"""Prepare the output."""
@ -324,6 +325,7 @@ Performance may be significantly worse with other models."
parsed["score"] = parsed["score"] / self.normalize_by
return parsed
@override
def _evaluate_strings(
self,
*,
@ -361,7 +363,8 @@ Performance may be significantly worse with other models."
)
return self._prepare_output(result)
async def _aevaluate_string_pairs(
@override
async def _aevaluate_strings(
self,
*,
prediction: str,

View File

@ -10,6 +10,7 @@ from langchain_core.callbacks.manager import (
)
from langchain_core.utils import pre_init
from pydantic import Field
from typing_extensions import override
from langchain.chains.base import Chain
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
@ -260,6 +261,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
"""
return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
@override
def _evaluate_strings(
self,
*,
@ -295,6 +297,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
return self._prepare_output(result)
@override
async def _aevaluate_strings(
self,
*,

View File

@ -40,7 +40,7 @@ def _get_client(
def push(
repo_full_name: str,
object: Any,
object: Any, # noqa: A002
*,
api_url: Optional[str] = None,
api_key: Optional[str] = None,

View File

@ -167,7 +167,7 @@ class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
self._pop_and_store_interaction(buffer)
def _pop_and_store_interaction(self, buffer: list[BaseMessage]) -> None:
input = buffer.pop(0)
input_ = buffer.pop(0)
output = buffer.pop(0)
timestamp = self._timestamps.pop(0).strftime(TIMESTAMP_FORMAT)
# Split AI output into smaller chunks to avoid creating documents
@ -175,7 +175,7 @@ class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
ai_chunks = self._split_long_ai_text(str(output.content))
for index, chunk in enumerate(ai_chunks):
self.memory_retriever.save_context(
{"Human": f"<{timestamp}/00> {str(input.content)}"},
{"Human": f"<{timestamp}/00> {str(input_.content)}"},
{"AI": f"<{timestamp}/{index:02}> {chunk}"},
)

View File

@ -28,6 +28,7 @@ from langchain_core.runnables.utils import (
get_unique_config_specs,
)
from pydantic import model_validator
from typing_extensions import override
T = TypeVar("T")
H = TypeVar("H", bound=Hashable)
@ -86,6 +87,7 @@ class EnsembleRetriever(BaseRetriever):
values["weights"] = [1 / n_retrievers] * n_retrievers
return values
@override
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> list[Document]:
@ -119,6 +121,7 @@ class EnsembleRetriever(BaseRetriever):
)
return result
@override
async def ainvoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> list[Document]:

View File

@ -88,23 +88,23 @@ class LocalFileStore(ByteStore):
return Path(full_path)
def _mkdir_for_store(self, dir: Path) -> None:
def _mkdir_for_store(self, dir_path: Path) -> None:
"""Makes a store directory path (including parents) with specified permissions
This is needed because `Path.mkdir()` is restricted by the current `umask`,
whereas the explicit `os.chmod()` used here is not.
Args:
dir: (Path) The store directory to make
dir_path: (Path) The store directory to make
Returns:
None
"""
if not dir.exists():
self._mkdir_for_store(dir.parent)
dir.mkdir(exist_ok=True)
if not dir_path.exists():
self._mkdir_for_store(dir_path.parent)
dir_path.mkdir(exist_ok=True)
if self.chmod_dir is not None:
os.chmod(dir, self.chmod_dir)
os.chmod(dir_path, self.chmod_dir)
def mget(self, keys: Sequence[str]) -> list[Optional[bytes]]:
"""Get the values associated with the given keys.

View File

@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint]
select = ["E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"]
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"]
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true

View File

@ -33,10 +33,8 @@ print("Hello fifty shades of gray mans!"[::-1]) # noqa: T201
)
def _test_convo_output(
input: str, expected_tool: str, expected_tool_input: str
) -> None:
result = ConvoOutputParser().parse(input.strip())
def _test_convo_output(text: str, expected_tool: str, expected_tool_input: str) -> None:
result = ConvoOutputParser().parse(text.strip())
assert isinstance(result, AgentAction)
assert result.tool == expected_tool
assert result.tool_input == expected_tool_input

View File

@ -25,7 +25,7 @@ def test_parse_invalid_grammar(x: str) -> None:
def test_parse_comparison() -> None:
comp = 'gte("foo", 2)'
expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
for input in (
for text in (
comp,
comp.replace('"', "'"),
comp.replace(" ", ""),
@ -34,7 +34,7 @@ def test_parse_comparison() -> None:
comp.replace(",", ", "),
comp.replace("2", "2.0"),
):
actual = DEFAULT_PARSER.parse(input)
actual = DEFAULT_PARSER.parse(text)
assert expected == actual
@ -43,7 +43,7 @@ def test_parse_operation() -> None:
eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar")
lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25)
expected = Operation(operator=Operator.AND, arguments=[eq, lt])
for input in (
for text in (
op,
op.replace('"', "'"),
op.replace(" ", ""),
@ -52,7 +52,7 @@ def test_parse_operation() -> None:
op.replace(",", ", "),
op.replace("25", "250"),
):
actual = DEFAULT_PARSER.parse(input)
actual = DEFAULT_PARSER.parse(text)
assert expected == actual

View File

@ -152,7 +152,7 @@ def test_output_fixing_parser_output_type(
@pytest.mark.parametrize(
"input,base_parser,retry_chain,expected",
"completion,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@ -171,7 +171,7 @@ def test_output_fixing_parser_output_type(
],
)
def test_output_fixing_parser_parse_with_retry_chain(
input: str,
completion: str,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
expected: T,
@ -185,11 +185,11 @@ def test_output_fixing_parser_parse_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
assert parser.parse(input) == expected
assert parser.parse(completion) == expected
@pytest.mark.parametrize(
"input,base_parser,retry_chain,expected",
"completion,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@ -208,7 +208,7 @@ def test_output_fixing_parser_parse_with_retry_chain(
],
)
async def test_output_fixing_parser_aparse_with_retry_chain(
input: str,
completion: str,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
expected: T,
@ -221,7 +221,7 @@ async def test_output_fixing_parser_aparse_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
assert (await parser.aparse(input)) == expected
assert (await parser.aparse(completion)) == expected
def _extract_exception(

View File

@ -202,7 +202,7 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
@pytest.mark.parametrize(
"input,prompt,base_parser,retry_chain,expected",
"completion,prompt,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@ -215,7 +215,7 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
],
)
def test_retry_output_parser_parse_with_prompt_with_retry_chain(
input: str,
completion: str,
prompt: PromptValue,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
@ -226,11 +226,11 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
assert parser.parse_with_prompt(input, prompt) == expected
assert parser.parse_with_prompt(completion, prompt) == expected
@pytest.mark.parametrize(
"input,prompt,base_parser,retry_chain,expected",
"completion,prompt,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@ -243,7 +243,7 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(
],
)
async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
input: str,
completion: str,
prompt: PromptValue,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
@ -255,11 +255,11 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
assert (await parser.aparse_with_prompt(input, prompt)) == expected
assert (await parser.aparse_with_prompt(completion, prompt)) == expected
@pytest.mark.parametrize(
"input,prompt,base_parser,retry_chain,expected",
"completion,prompt,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@ -272,7 +272,7 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
],
)
def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
input: str,
completion: str,
prompt: PromptValue,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
@ -284,11 +284,11 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
assert parser.parse_with_prompt(input, prompt) == expected
assert parser.parse_with_prompt(completion, prompt) == expected
@pytest.mark.parametrize(
"input,prompt,base_parser,retry_chain,expected",
"completion,prompt,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@ -301,7 +301,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
],
)
async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
input: str,
completion: str,
prompt: PromptValue,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
@ -312,7 +312,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chai
retry_chain=retry_chain,
legacy=False,
)
assert (await parser.aparse_with_prompt(input, prompt)) == expected
assert (await parser.aparse_with_prompt(completion, prompt)) == expected
def _extract_exception(

View File

@ -56,16 +56,16 @@ def test_debug_is_settable_directly() -> None:
def test_debug_is_settable_via_setter() -> None:
from langchain_core.callbacks.manager import _get_debug
from langchain import globals
from langchain import globals as langchain_globals
previous_value = globals._debug
previous_value = langchain_globals._debug
previous_fn_reading = _get_debug()
assert previous_value == previous_fn_reading
# Flip the value of the flag.
set_debug(not previous_value)
new_value = globals._debug
new_value = langchain_globals._debug
new_fn_reading = _get_debug()
try:
@ -115,17 +115,17 @@ def test_verbose_is_settable_directly() -> None:
def test_verbose_is_settable_via_setter() -> None:
from langchain import globals
from langchain import globals as langchain_globals
from langchain.chains.base import _get_verbosity
previous_value = globals._verbose
previous_value = langchain_globals._verbose
previous_fn_reading = _get_verbosity()
assert previous_value == previous_fn_reading
# Flip the value of the flag.
set_verbose(not previous_value)
new_value = globals._verbose
new_value = langchain_globals._verbose
new_fn_reading = _get_verbosity()
try:

View File

@ -26,9 +26,9 @@ def test_import_all() -> None:
mod = importlib.import_module(module_name)
all = getattr(mod, "__all__", [])
all_attrs = getattr(mod, "__all__", [])
for name in all:
for name in all_attrs:
# Attempt to import the name from the module
try:
obj = getattr(mod, name)
@ -65,9 +65,9 @@ def test_import_all_using_dir() -> None:
except ModuleNotFoundError as e:
msg = f"Could not import {module_name}"
raise ModuleNotFoundError(msg) from e
all = dir(mod)
attributes = dir(mod)
for name in all:
for name in attributes:
if name.strip().startswith("_"):
continue
# Attempt to import the name from the module

View File

@ -2981,7 +2981,7 @@ requires-dist = [
[package.metadata.requires-dev]
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
lint = [{ name = "ruff", specifier = ">=0.9.2,<1.0.0" }]
lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13" }]
test = [{ name = "langchain-core", editable = "../core" }]
test-integration = []
typing = [
@ -3007,7 +3007,7 @@ dev = [
]
lint = [
{ name = "langchain-core", editable = "../core" },
{ name = "ruff", specifier = ">=0.9.2,<1.0.0" },
{ name = "ruff", specifier = ">=0.12.2,<0.13" },
]
test = [
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },