Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-11 15:35:09 +00:00
langchain: Add ruff rules A (#31888)
See https://docs.astral.sh/ruff/rules/#flake8-builtins-a

Co-authored-by: Mason Daugherty <mason@langchain.dev>
parent 49c316667d
commit f06380516f
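For context, ruff's A rules (flake8-builtins) flag names that shadow Python builtins such as input, filter, object, all, and dir. Below is a minimal illustrative sketch, not taken from the diff that follows (the function names are hypothetical), of the two remediations this commit applies: rename the binding, or suppress the rule per line where the parameter name is part of a public signature.

def strip_text(input: str) -> str:  # ruff A002: argument "input" shadows a builtin
    return input.strip()


def strip_text_renamed(input_: str) -> str:
    # Remediation used for private helpers: rename the shadowing binding.
    return input_.strip()


def strip_text_public(input: str) -> str:  # noqa: A002
    # Remediation used for public signatures: keep the name, suppress A002 on that line.
    return input.strip()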
@@ -41,7 +41,7 @@ from langchain_core.runnables.utils import AddableDict
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
from pydantic import BaseModel, ConfigDict, model_validator
-from typing_extensions import Self
+from typing_extensions import Self, override

from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent_iterator import AgentExecutorIterator
@@ -1749,6 +1749,7 @@ class AgentExecutor(Chain):
else:
return intermediate_steps

+@override
def stream(
self,
input: Union[dict[str, Any], Any],
@@ -1779,6 +1780,7 @@ class AgentExecutor(Chain):
)
yield from iterator

+@override
async def astream(
self,
input: Union[dict[str, Any], Any],
@@ -20,7 +20,7 @@ from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensur
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, Field, model_validator
-from typing_extensions import Self
+from typing_extensions import Self, override

if TYPE_CHECKING:
import openai
@@ -272,6 +272,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
return cls(assistant_id=assistant.id, client=client, **kwargs)

+@override
def invoke(
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> OutputType:
@@ -399,6 +400,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)

+@override
async def ainvoke(
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> OutputType:
@@ -515,10 +517,10 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
}
return submit_tool_outputs

-def _create_run(self, input: dict) -> Any:
+def _create_run(self, input_dict: dict) -> Any:
params = {
k: v
-for k, v in input.items()
+for k, v in input_dict.items()
if k
in (
"instructions",
@@ -534,15 +536,15 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
}
return self.client.beta.threads.runs.create(
-input["thread_id"],
+input_dict["thread_id"],
assistant_id=self.assistant_id,
**params,
)

-def _create_thread_and_run(self, input: dict, thread: dict) -> Any:
+def _create_thread_and_run(self, input_dict: dict, thread: dict) -> Any:
params = {
k: v
-for k, v in input.items()
+for k, v in input_dict.items()
if k
in (
"instructions",
@@ -673,10 +675,10 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
}
return submit_tool_outputs

-async def _acreate_run(self, input: dict) -> Any:
+async def _acreate_run(self, input_dict: dict) -> Any:
params = {
k: v
-for k, v in input.items()
+for k, v in input_dict.items()
if k
in (
"instructions",
@@ -692,15 +694,15 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
}
return await self.async_client.beta.threads.runs.create(
-input["thread_id"],
+input_dict["thread_id"],
assistant_id=self.assistant_id,
**params,
)

-async def _acreate_thread_and_run(self, input: dict, thread: dict) -> Any:
+async def _acreate_thread_and_run(self, input_dict: dict, thread: dict) -> Any:
params = {
k: v
-for k, v in input.items()
+for k, v in input_dict.items()
if k
in (
"instructions",
@@ -35,6 +35,7 @@ from pydantic import (
field_validator,
model_validator,
)
+from typing_extensions import override

from langchain.schema import RUN_KEY

@@ -118,6 +119,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
# This is correct, but pydantic typings/mypy don't think so.
return create_model("ChainOutput", **{k: (Any, None) for k in self.output_keys})

+@override
def invoke(
self,
input: dict[str, Any],
@@ -171,6 +173,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs

+@override
async def ainvoke(
self,
input: dict[str, Any],
@@ -88,9 +88,11 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
if fix_invalid:

def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
-filter = cast(Optional[FilterDirective], get_parser().parse(raw_filter))
+filter_directive = cast(
+Optional[FilterDirective], get_parser().parse(raw_filter)
+)
fixed = fix_filter_directive(
-filter,
+filter_directive,
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
allowed_attributes=allowed_attributes,
@@ -107,7 +109,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):


def fix_filter_directive(
-filter: Optional[FilterDirective],
+filter: Optional[FilterDirective],  # noqa: A002
*,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
@@ -28,7 +28,7 @@ from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
from pydantic import BaseModel
-from typing_extensions import TypeAlias
+from typing_extensions import TypeAlias, override

__all__ = [
"init_chat_model",
@@ -673,6 +673,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
list[AnyMessage],
]

+@override
def invoke(
self,
input: LanguageModelInput,
@@ -681,6 +682,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) -> Any:
return self._model(config).invoke(input, config=config, **kwargs)

+@override
async def ainvoke(
self,
input: LanguageModelInput,
@@ -689,6 +691,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) -> Any:
return await self._model(config).ainvoke(input, config=config, **kwargs)

+@override
def stream(
self,
input: LanguageModelInput,
@@ -697,6 +700,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) -> Iterator[Any]:
yield from self._model(config).stream(input, config=config, **kwargs)

+@override
async def astream(
self,
input: LanguageModelInput,
@@ -802,6 +806,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
):
yield x

+@override
def transform(
self,
input: Iterator[LanguageModelInput],
@@ -810,6 +815,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) -> Iterator[Any]:
yield from self._model(config).transform(input, config=config, **kwargs)

+@override
async def atransform(
self,
input: AsyncIterator[LanguageModelInput],
@@ -853,6 +859,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
**kwargs: Any,
) -> AsyncIterator[RunLog]: ...

+@override
async def astream_log(
self,
input: Any,
@@ -883,6 +890,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
):
yield x

+@override
async def astream_events(
self,
input: Any,
@@ -27,6 +27,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.tools import BaseTool
from pydantic import ConfigDict, Field
+from typing_extensions import override

from langchain.chains.llm import LLMChain
from langchain.evaluation.agents.trajectory_eval_prompt import (
@@ -326,6 +327,7 @@ The following is the expected answer. Use this to measure correctness:
)
return cast(dict, self.output_parser.parse(raw_output))

+@override
def _evaluate_agent_trajectory(
self,
*,
@@ -368,6 +370,7 @@ The following is the expected answer. Use this to measure correctness:
return_only_outputs=True,
)

+@override
async def _aevaluate_agent_trajectory(
self,
*,
@@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field
+from typing_extensions import override

from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
@@ -278,7 +279,7 @@ Performance may be significantly worse with other models."
self,
prediction: str,
prediction_b: str,
-input: Optional[str],
+input_: Optional[str],
reference: Optional[str],
) -> dict:
"""Prepare the input for the chain.
@@ -286,21 +287,21 @@ Performance may be significantly worse with other models."
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
-input (str, optional): The input or task string.
+input_ (str, optional): The input or task string.
reference (str, optional): The reference string, if any.

Returns:
dict: The prepared input for the chain.

"""
-input_ = {
+input_dict = {
"prediction": prediction,
"prediction_b": prediction_b,
-"input": input,
+"input": input_,
}
if self.requires_reference:
-input_["reference"] = reference
-return input_
+input_dict["reference"] = reference
+return input_dict

def _prepare_output(self, result: dict) -> dict:
"""Prepare the output."""
@@ -309,6 +310,7 @@ Performance may be significantly worse with other models."
parsed[RUN_KEY] = result[RUN_KEY]
return parsed

+@override
def _evaluate_string_pairs(
self,
*,
@@ -351,6 +353,7 @@ Performance may be significantly worse with other models."
)
return self._prepare_output(result)

+@override
async def _aevaluate_string_pairs(
self,
*,
@@ -10,6 +10,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate
from pydantic import ConfigDict, Field
+from typing_extensions import override

from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
@@ -383,16 +384,16 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
self,
prediction: str,
reference: Optional[str],
-input: Optional[str],
+input_: Optional[str],
) -> dict:
"""Get the evaluation input."""
-input_ = {
-"input": input,
+input_dict = {
+"input": input_,
"output": prediction,
}
if self.requires_reference:
-input_["reference"] = reference
-return input_
+input_dict["reference"] = reference
+return input_dict

def _prepare_output(self, result: dict) -> dict:
"""Prepare the output."""
@@ -401,6 +402,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
parsed[RUN_KEY] = result[RUN_KEY]
return parsed

+@override
def _evaluate_strings(
self,
*,
@@ -456,6 +458,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
)
return self._prepare_output(result)

+@override
async def _aevaluate_strings(
self,
*,
@@ -5,6 +5,7 @@ from operator import eq
from typing import Any, Callable, Optional, Union, cast

from langchain_core.utils.json import parse_json_markdown
+from typing_extensions import override

from langchain.evaluation.schema import StringEvaluator

@@ -49,6 +50,7 @@ class JsonValidityEvaluator(StringEvaluator):
def evaluation_name(self) -> str:
return "json_validity"

+@override
def _evaluate_strings(
self,
prediction: str,
@@ -132,6 +134,7 @@ class JsonEqualityEvaluator(StringEvaluator):
return parse_json_markdown(string)
return string

+@override
def _evaluate_strings(
self,
prediction: str,
@@ -2,6 +2,7 @@ import json
from typing import Any, Callable, Optional, Union

from langchain_core.utils.json import parse_json_markdown
+from typing_extensions import override

from langchain.evaluation.schema import StringEvaluator

@@ -82,6 +83,7 @@ class JsonEditDistanceEvaluator(StringEvaluator):
return parse_json_markdown(node)
return node

+@override
def _evaluate_strings(
self,
prediction: str,
@@ -1,6 +1,7 @@
from typing import Any, Union

from langchain_core.utils.json import parse_json_markdown
+from typing_extensions import override

from langchain.evaluation.schema import StringEvaluator

@@ -85,6 +86,7 @@ class JsonSchemaEvaluator(StringEvaluator):
except ValidationError as e:
return {"score": False, "reasoning": repr(e)}

+@override
def _evaluate_strings(
self,
prediction: Union[str, Any],
@@ -11,6 +11,7 @@ from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from pydantic import ConfigDict
+from typing_extensions import override

from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
@@ -154,6 +155,7 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
parsed_result[RUN_KEY] = result[RUN_KEY]
return parsed_result

+@override
def _evaluate_strings(
self,
*,
@@ -189,6 +191,7 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
)
return self._prepare_output(result)

+@override
async def _aevaluate_strings(
self,
*,
@@ -296,6 +299,7 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
parsed_result[RUN_KEY] = result[RUN_KEY]
return parsed_result

+@override
def _evaluate_strings(
self,
*,
@@ -317,6 +321,7 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
)
return self._prepare_output(result)

+@override
async def _aevaluate_strings(
self,
*,
@@ -109,13 +109,13 @@ class _EvalArgsMixin:
def _check_evaluation_args(
self,
reference: Optional[str] = None,
-input: Optional[str] = None,
+input_: Optional[str] = None,
) -> None:
"""Check if the evaluation arguments are valid.

Args:
reference (Optional[str], optional): The reference label.
-input (Optional[str], optional): The input string.
+input_ (Optional[str], optional): The input string.
Raises:
ValueError: If the evaluator requires an input string but none is provided,
or if the evaluator requires a reference label but none is provided.
@@ -152,7 +152,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
*,
prediction: Union[str, Any],
reference: Optional[Union[str, Any]] = None,
-input: Optional[Union[str, Any]] = None,
+input: Optional[Union[str, Any]] = None,  # noqa: A002
**kwargs: Any,
) -> dict:
"""Evaluate Chain or LLM output, based on optional input and label.
@@ -175,7 +175,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
*,
prediction: Union[str, Any],
reference: Optional[Union[str, Any]] = None,
-input: Optional[Union[str, Any]] = None,
+input: Optional[Union[str, Any]] = None,  # noqa: A002
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.
@@ -206,7 +206,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
reference: Optional[str] = None,
-input: Optional[str] = None,
+input: Optional[str] = None,  # noqa: A002
**kwargs: Any,
) -> dict:
"""Evaluate Chain or LLM output, based on optional input and label.
@@ -219,7 +219,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation results containing the score or value.
"""  # noqa: E501
-self._check_evaluation_args(reference=reference, input=input)
+self._check_evaluation_args(reference=reference, input_=input)
return self._evaluate_strings(
prediction=prediction, reference=reference, input=input, **kwargs
)
@@ -229,7 +229,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
reference: Optional[str] = None,
-input: Optional[str] = None,
+input: Optional[str] = None,  # noqa: A002
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.
@@ -242,7 +242,7 @@ class StringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation results containing the score or value.
"""  # noqa: E501
-self._check_evaluation_args(reference=reference, input=input)
+self._check_evaluation_args(reference=reference, input_=input)
return await self._aevaluate_strings(
prediction=prediction, reference=reference, input=input, **kwargs
)
@@ -258,7 +258,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
-input: Optional[str] = None,
+input: Optional[str] = None,  # noqa: A002
**kwargs: Any,
) -> dict:
"""Evaluate the output string pairs.
@@ -279,7 +279,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
-input: Optional[str] = None,
+input: Optional[str] = None,  # noqa: A002
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate the output string pairs.
@@ -309,7 +309,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
-input: Optional[str] = None,
+input: Optional[str] = None,  # noqa: A002
**kwargs: Any,
) -> dict:
"""Evaluate the output string pairs.
@@ -323,7 +323,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
"""  # noqa: E501
-self._check_evaluation_args(reference=reference, input=input)
+self._check_evaluation_args(reference=reference, input_=input)
return self._evaluate_string_pairs(
prediction=prediction,
prediction_b=prediction_b,
@@ -338,7 +338,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
-input: Optional[str] = None,
+input: Optional[str] = None,  # noqa: A002
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate the output string pairs.
@@ -352,7 +352,7 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
"""  # noqa: E501
-self._check_evaluation_args(reference=reference, input=input)
+self._check_evaluation_args(reference=reference, input_=input)
return await self._aevaluate_string_pairs(
prediction=prediction,
prediction_b=prediction_b,
@@ -376,7 +376,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
agent_trajectory: Sequence[tuple[AgentAction, str]],
-input: str,
+input: str,  # noqa: A002
reference: Optional[str] = None,
**kwargs: Any,
) -> dict:
@@ -398,7 +398,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
agent_trajectory: Sequence[tuple[AgentAction, str]],
-input: str,
+input: str,  # noqa: A002
reference: Optional[str] = None,
**kwargs: Any,
) -> dict:
@@ -429,7 +429,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
agent_trajectory: Sequence[tuple[AgentAction, str]],
-input: str,
+input: str,  # noqa: A002
reference: Optional[str] = None,
**kwargs: Any,
) -> dict:
@@ -445,7 +445,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation result.
"""
-self._check_evaluation_args(reference=reference, input=input)
+self._check_evaluation_args(reference=reference, input_=input)
return self._evaluate_agent_trajectory(
prediction=prediction,
input=input,
@@ -459,7 +459,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
*,
prediction: str,
agent_trajectory: Sequence[tuple[AgentAction, str]],
-input: str,
+input: str,  # noqa: A002
reference: Optional[str] = None,
**kwargs: Any,
) -> dict:
@@ -475,7 +475,7 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation result.
"""
-self._check_evaluation_args(reference=reference, input=input)
+self._check_evaluation_args(reference=reference, input_=input)
return await self._aevaluate_agent_trajectory(
prediction=prediction,
input=input,
@@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field
+from typing_extensions import override

from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
@@ -292,7 +293,7 @@ Performance may be significantly worse with other models."
def _prepare_input(
self,
prediction: str,
-input: Optional[str],
+input_: Optional[str],
reference: Optional[str],
) -> dict:
"""Prepare the input for the chain.
@@ -300,20 +301,20 @@ Performance may be significantly worse with other models."
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
-input (str, optional): The input or task string.
+input_ (str, optional): The input or task string.
reference (str, optional): The reference string, if any.

Returns:
dict: The prepared input for the chain.

"""
-input_ = {
+input_dict = {
"prediction": prediction,
-"input": input,
+"input": input_,
}
if self.requires_reference:
-input_["reference"] = reference
-return input_
+input_dict["reference"] = reference
+return input_dict

def _prepare_output(self, result: dict) -> dict:
"""Prepare the output."""
@@ -324,6 +325,7 @@ Performance may be significantly worse with other models."
parsed["score"] = parsed["score"] / self.normalize_by
return parsed

+@override
def _evaluate_strings(
self,
*,
@@ -361,7 +363,8 @@ Performance may be significantly worse with other models."
)
return self._prepare_output(result)

-async def _aevaluate_string_pairs(
+@override
+async def _aevaluate_strings(
self,
*,
prediction: str,
@@ -10,6 +10,7 @@ from langchain_core.callbacks.manager import (
)
from langchain_core.utils import pre_init
from pydantic import Field
+from typing_extensions import override

from langchain.chains.base import Chain
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
@@ -260,6 +261,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
"""
return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}

+@override
def _evaluate_strings(
self,
*,
@@ -295,6 +297,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):

return self._prepare_output(result)

+@override
async def _aevaluate_strings(
self,
*,
@@ -40,7 +40,7 @@ def _get_client(

def push(
repo_full_name: str,
-object: Any,
+object: Any,  # noqa: A002
*,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
@@ -167,7 +167,7 @@ class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
self._pop_and_store_interaction(buffer)

def _pop_and_store_interaction(self, buffer: list[BaseMessage]) -> None:
-input = buffer.pop(0)
+input_ = buffer.pop(0)
output = buffer.pop(0)
timestamp = self._timestamps.pop(0).strftime(TIMESTAMP_FORMAT)
# Split AI output into smaller chunks to avoid creating documents
@@ -175,7 +175,7 @@ class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
ai_chunks = self._split_long_ai_text(str(output.content))
for index, chunk in enumerate(ai_chunks):
self.memory_retriever.save_context(
-{"Human": f"<{timestamp}/00> {str(input.content)}"},
+{"Human": f"<{timestamp}/00> {str(input_.content)}"},
{"AI": f"<{timestamp}/{index:02}> {chunk}"},
)
@@ -28,6 +28,7 @@ from langchain_core.runnables.utils import (
get_unique_config_specs,
)
from pydantic import model_validator
+from typing_extensions import override

T = TypeVar("T")
H = TypeVar("H", bound=Hashable)
@@ -86,6 +87,7 @@ class EnsembleRetriever(BaseRetriever):
values["weights"] = [1 / n_retrievers] * n_retrievers
return values

+@override
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> list[Document]:
@@ -119,6 +121,7 @@ class EnsembleRetriever(BaseRetriever):
)
return result

+@override
async def ainvoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> list[Document]:
@@ -88,23 +88,23 @@ class LocalFileStore(ByteStore):

return Path(full_path)

-def _mkdir_for_store(self, dir: Path) -> None:
+def _mkdir_for_store(self, dir_path: Path) -> None:
"""Makes a store directory path (including parents) with specified permissions

This is needed because `Path.mkdir()` is restricted by the current `umask`,
whereas the explicit `os.chmod()` used here is not.

Args:
-dir: (Path) The store directory to make
+dir_path: (Path) The store directory to make

Returns:
None
"""
-if not dir.exists():
-self._mkdir_for_store(dir.parent)
-dir.mkdir(exist_ok=True)
+if not dir_path.exists():
+self._mkdir_for_store(dir_path.parent)
+dir_path.mkdir(exist_ok=True)
if self.chmod_dir is not None:
-os.chmod(dir, self.chmod_dir)
+os.chmod(dir_path, self.chmod_dir)

def mget(self, keys: Sequence[str]) -> list[Optional[bytes]]:
"""Get the values associated with the given keys.
@@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"

[tool.ruff.lint]
-select = ["E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"]
+select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"]
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true
@@ -33,10 +33,8 @@ print("Hello fifty shades of gray mans!"[::-1])  # noqa: T201
)


-def _test_convo_output(
-input: str, expected_tool: str, expected_tool_input: str
-) -> None:
-result = ConvoOutputParser().parse(input.strip())
+def _test_convo_output(text: str, expected_tool: str, expected_tool_input: str) -> None:
+result = ConvoOutputParser().parse(text.strip())
assert isinstance(result, AgentAction)
assert result.tool == expected_tool
assert result.tool_input == expected_tool_input
@@ -25,7 +25,7 @@ def test_parse_invalid_grammar(x: str) -> None:
def test_parse_comparison() -> None:
comp = 'gte("foo", 2)'
expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
-for input in (
+for text in (
comp,
comp.replace('"', "'"),
comp.replace(" ", ""),
@@ -34,7 +34,7 @@ def test_parse_comparison() -> None:
comp.replace(",", ", "),
comp.replace("2", "2.0"),
):
-actual = DEFAULT_PARSER.parse(input)
+actual = DEFAULT_PARSER.parse(text)
assert expected == actual


@@ -43,7 +43,7 @@ def test_parse_operation() -> None:
eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar")
lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25)
expected = Operation(operator=Operator.AND, arguments=[eq, lt])
-for input in (
+for text in (
op,
op.replace('"', "'"),
op.replace(" ", ""),
@@ -52,7 +52,7 @@ def test_parse_operation() -> None:
op.replace(",", ", "),
op.replace("25", "250"),
):
-actual = DEFAULT_PARSER.parse(input)
+actual = DEFAULT_PARSER.parse(text)
assert expected == actual
@@ -152,7 +152,7 @@ def test_output_fixing_parser_output_type(


@pytest.mark.parametrize(
-"input,base_parser,retry_chain,expected",
+"completion,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@@ -171,7 +171,7 @@ def test_output_fixing_parser_output_type(
],
)
def test_output_fixing_parser_parse_with_retry_chain(
-input: str,
+completion: str,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
expected: T,
@@ -185,11 +185,11 @@ def test_output_fixing_parser_parse_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
-assert parser.parse(input) == expected
+assert parser.parse(completion) == expected


@pytest.mark.parametrize(
-"input,base_parser,retry_chain,expected",
+"completion,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@@ -208,7 +208,7 @@ def test_output_fixing_parser_parse_with_retry_chain(
],
)
async def test_output_fixing_parser_aparse_with_retry_chain(
-input: str,
+completion: str,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
expected: T,
@@ -221,7 +221,7 @@ async def test_output_fixing_parser_aparse_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
-assert (await parser.aparse(input)) == expected
+assert (await parser.aparse(completion)) == expected


def _extract_exception(
@@ -202,7 +202,7 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:


@pytest.mark.parametrize(
-"input,prompt,base_parser,retry_chain,expected",
+"completion,prompt,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@@ -215,7 +215,7 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
],
)
def test_retry_output_parser_parse_with_prompt_with_retry_chain(
-input: str,
+completion: str,
prompt: PromptValue,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
@@ -226,11 +226,11 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
-assert parser.parse_with_prompt(input, prompt) == expected
+assert parser.parse_with_prompt(completion, prompt) == expected


@pytest.mark.parametrize(
-"input,prompt,base_parser,retry_chain,expected",
+"completion,prompt,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@@ -243,7 +243,7 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(
],
)
async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
-input: str,
+completion: str,
prompt: PromptValue,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
@@ -255,11 +255,11 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
-assert (await parser.aparse_with_prompt(input, prompt)) == expected
+assert (await parser.aparse_with_prompt(completion, prompt)) == expected


@pytest.mark.parametrize(
-"input,prompt,base_parser,retry_chain,expected",
+"completion,prompt,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@@ -272,7 +272,7 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
],
)
def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
-input: str,
+completion: str,
prompt: PromptValue,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
@@ -284,11 +284,11 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
retry_chain=retry_chain,
legacy=False,
)
-assert parser.parse_with_prompt(input, prompt) == expected
+assert parser.parse_with_prompt(completion, prompt) == expected


@pytest.mark.parametrize(
-"input,prompt,base_parser,retry_chain,expected",
+"completion,prompt,base_parser,retry_chain,expected",
[
(
"2024/07/08",
@@ -301,7 +301,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
],
)
async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
-input: str,
+completion: str,
prompt: PromptValue,
base_parser: BaseOutputParser[T],
retry_chain: Runnable[dict[str, Any], str],
@@ -312,7 +312,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chai
retry_chain=retry_chain,
legacy=False,
)
-assert (await parser.aparse_with_prompt(input, prompt)) == expected
+assert (await parser.aparse_with_prompt(completion, prompt)) == expected


def _extract_exception(
@@ -56,16 +56,16 @@ def test_debug_is_settable_directly() -> None:
def test_debug_is_settable_via_setter() -> None:
from langchain_core.callbacks.manager import _get_debug

-from langchain import globals
+from langchain import globals as langchain_globals

-previous_value = globals._debug
+previous_value = langchain_globals._debug
previous_fn_reading = _get_debug()
assert previous_value == previous_fn_reading

# Flip the value of the flag.
set_debug(not previous_value)

-new_value = globals._debug
+new_value = langchain_globals._debug
new_fn_reading = _get_debug()

try:
@@ -115,17 +115,17 @@ def test_verbose_is_settable_directly() -> None:


def test_verbose_is_settable_via_setter() -> None:
-from langchain import globals
+from langchain import globals as langchain_globals
from langchain.chains.base import _get_verbosity

-previous_value = globals._verbose
+previous_value = langchain_globals._verbose
previous_fn_reading = _get_verbosity()
assert previous_value == previous_fn_reading

# Flip the value of the flag.
set_verbose(not previous_value)

-new_value = globals._verbose
+new_value = langchain_globals._verbose
new_fn_reading = _get_verbosity()

try:
@@ -26,9 +26,9 @@ def test_import_all() -> None:

mod = importlib.import_module(module_name)

-all = getattr(mod, "__all__", [])
+all_attrs = getattr(mod, "__all__", [])

-for name in all:
+for name in all_attrs:
# Attempt to import the name from the module
try:
obj = getattr(mod, name)
@@ -65,9 +65,9 @@ def test_import_all_using_dir() -> None:
except ModuleNotFoundError as e:
msg = f"Could not import {module_name}"
raise ModuleNotFoundError(msg) from e
-all = dir(mod)
+attributes = dir(mod)

-for name in all:
+for name in attributes:
if name.strip().startswith("_"):
continue
# Attempt to import the name from the module
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
|
||||
lint = [{ name = "ruff", specifier = ">=0.9.2,<1.0.0" }]
|
||||
lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13" }]
|
||||
test = [{ name = "langchain-core", editable = "../core" }]
|
||||
test-integration = []
|
||||
typing = [
|
||||
@ -3007,7 +3007,7 @@ dev = [
|
||||
]
|
||||
lint = [
|
||||
{ name = "langchain-core", editable = "../core" },
|
||||
{ name = "ruff", specifier = ">=0.9.2,<1.0.0" },
|
||||
{ name = "ruff", specifier = ">=0.12.2,<0.13" },
|
||||
]
|
||||
test = [
|
||||
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
|
||||
|