mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 02:33:05 +00:00
feat(langchain): add ruff rules PL (#32079)
See https://docs.astral.sh/ruff/rules/#pylint-pl
This commit is contained in:
parent
0f39155f62
commit
3496e1739e
@ -357,7 +357,7 @@ def __getattr__(name: str) -> Any:
|
|||||||
|
|
||||||
return ElasticVectorSearch
|
return ElasticVectorSearch
|
||||||
# For backwards compatibility
|
# For backwards compatibility
|
||||||
if name == "SerpAPIChain" or name == "SerpAPIWrapper":
|
if name in {"SerpAPIChain", "SerpAPIWrapper"}:
|
||||||
from langchain_community.utilities import SerpAPIWrapper
|
from langchain_community.utilities import SerpAPIWrapper
|
||||||
|
|
||||||
_warn_on_import(
|
_warn_on_import(
|
||||||
|
@ -40,13 +40,13 @@ def format_xml(
|
|||||||
# Escape XML tags in tool names and inputs using custom delimiters
|
# Escape XML tags in tool names and inputs using custom delimiters
|
||||||
tool = _escape(action.tool)
|
tool = _escape(action.tool)
|
||||||
tool_input = _escape(str(action.tool_input))
|
tool_input = _escape(str(action.tool_input))
|
||||||
observation = _escape(str(observation))
|
observation_ = _escape(str(observation))
|
||||||
else:
|
else:
|
||||||
tool = action.tool
|
tool = action.tool
|
||||||
tool_input = str(action.tool_input)
|
tool_input = str(action.tool_input)
|
||||||
observation = str(observation)
|
observation_ = str(observation)
|
||||||
log += (
|
log += (
|
||||||
f"<tool>{tool}</tool><tool_input>{tool_input}"
|
f"<tool>{tool}</tool><tool_input>{tool_input}"
|
||||||
f"</tool_input><observation>{observation}</observation>"
|
f"</tool_input><observation>{observation_}</observation>"
|
||||||
)
|
)
|
||||||
return log
|
return log
|
||||||
|
@ -582,7 +582,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
|
|||||||
major_version = int(openai.version.VERSION.split(".")[0])
|
major_version = int(openai.version.VERSION.split(".")[0])
|
||||||
minor_version = int(openai.version.VERSION.split(".")[1])
|
minor_version = int(openai.version.VERSION.split(".")[1])
|
||||||
version_gte_1_14 = (major_version > 1) or (
|
version_gte_1_14 = (major_version > 1) or (
|
||||||
major_version == 1 and minor_version >= 14
|
major_version == 1 and minor_version >= 14 # noqa: PLR2004
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = self.client.beta.threads.messages.list(
|
messages = self.client.beta.threads.messages.list(
|
||||||
@ -739,7 +739,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
|
|||||||
major_version = int(openai.version.VERSION.split(".")[0])
|
major_version = int(openai.version.VERSION.split(".")[0])
|
||||||
minor_version = int(openai.version.VERSION.split(".")[1])
|
minor_version = int(openai.version.VERSION.split(".")[1])
|
||||||
version_gte_1_14 = (major_version > 1) or (
|
version_gte_1_14 = (major_version > 1) or (
|
||||||
major_version == 1 and minor_version >= 14
|
major_version == 1 and minor_version >= 14 # noqa: PLR2004
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = await self.async_client.beta.threads.messages.list(
|
messages = await self.async_client.beta.threads.messages.list(
|
||||||
|
@ -24,6 +24,9 @@ if TYPE_CHECKING:
|
|||||||
from langchain_community.docstore.base import Docstore
|
from langchain_community.docstore.base import Docstore
|
||||||
|
|
||||||
|
|
||||||
|
_LOOKUP_AND_SEARCH_TOOLS = {"Lookup", "Search"}
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
"0.1.0",
|
"0.1.0",
|
||||||
message=AGENT_DEPRECATION_WARNING,
|
message=AGENT_DEPRECATION_WARNING,
|
||||||
@ -52,11 +55,11 @@ class ReActDocstoreAgent(Agent):
|
|||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
validate_tools_single_input(cls.__name__, tools)
|
validate_tools_single_input(cls.__name__, tools)
|
||||||
super()._validate_tools(tools)
|
super()._validate_tools(tools)
|
||||||
if len(tools) != 2:
|
if len(tools) != len(_LOOKUP_AND_SEARCH_TOOLS):
|
||||||
msg = f"Exactly two tools must be specified, but got {tools}"
|
msg = f"Exactly two tools must be specified, but got {tools}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
tool_names = {tool.name for tool in tools}
|
tool_names = {tool.name for tool in tools}
|
||||||
if tool_names != {"Lookup", "Search"}:
|
if tool_names != _LOOKUP_AND_SEARCH_TOOLS:
|
||||||
msg = f"Tool names should be Lookup and Search, got {tool_names}"
|
msg = f"Tool names should be Lookup and Search, got {tool_names}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@ -47,8 +47,8 @@ def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
|
|||||||
scheme, domain = _extract_scheme_and_domain(url)
|
scheme, domain = _extract_scheme_and_domain(url)
|
||||||
|
|
||||||
for allowed_domain in limit_to_domains:
|
for allowed_domain in limit_to_domains:
|
||||||
allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_domain)
|
allowed_scheme, allowed_domain_ = _extract_scheme_and_domain(allowed_domain)
|
||||||
if scheme == allowed_scheme and domain == allowed_domain:
|
if scheme == allowed_scheme and domain == allowed_domain_:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -193,8 +193,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
"multiple llm_chain input_variables"
|
"multiple llm_chain input_variables"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
else:
|
elif values["document_variable_name"] not in llm_chain_variables:
|
||||||
if values["document_variable_name"] not in llm_chain_variables:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"document_variable_name {values['document_variable_name']} was "
|
f"document_variable_name {values['document_variable_name']} was "
|
||||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||||
|
@ -161,8 +161,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
"multiple llm_chain input_variables"
|
"multiple llm_chain input_variables"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
else:
|
elif values["document_variable_name"] not in llm_chain_variables:
|
||||||
if values["document_variable_name"] not in llm_chain_variables:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"document_variable_name {values['document_variable_name']} was "
|
f"document_variable_name {values['document_variable_name']} was "
|
||||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||||
|
@ -325,10 +325,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
_token_max,
|
_token_max,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
result_docs = []
|
result_docs = [
|
||||||
for docs in new_result_doc_list:
|
collapse_docs(docs_, _collapse_docs_func, **kwargs)
|
||||||
new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs)
|
for docs_ in new_result_doc_list
|
||||||
result_docs.append(new_doc)
|
]
|
||||||
num_tokens = length_func(result_docs, **kwargs)
|
num_tokens = length_func(result_docs, **kwargs)
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.collapse_max_retries and retries == self.collapse_max_retries:
|
if self.collapse_max_retries and retries == self.collapse_max_retries:
|
||||||
@ -364,10 +364,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
_token_max,
|
_token_max,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
result_docs = []
|
result_docs = [
|
||||||
for docs in new_result_doc_list:
|
await acollapse_docs(docs_, _collapse_docs_func, **kwargs)
|
||||||
new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs)
|
for docs_ in new_result_doc_list
|
||||||
result_docs.append(new_doc)
|
]
|
||||||
num_tokens = length_func(result_docs, **kwargs)
|
num_tokens = length_func(result_docs, **kwargs)
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.collapse_max_retries and retries == self.collapse_max_retries:
|
if self.collapse_max_retries and retries == self.collapse_max_retries:
|
||||||
|
@ -140,8 +140,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
"multiple llm_chain input_variables"
|
"multiple llm_chain input_variables"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
else:
|
elif values["document_variable_name"] not in llm_chain_variables:
|
||||||
if values["document_variable_name"] not in llm_chain_variables:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"document_variable_name {values['document_variable_name']} was "
|
f"document_variable_name {values['document_variable_name']} was "
|
||||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||||
|
@ -180,8 +180,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
"multiple llm_chain_variables"
|
"multiple llm_chain_variables"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
else:
|
elif values["document_variable_name"] not in llm_chain_variables:
|
||||||
if values["document_variable_name"] not in llm_chain_variables:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"document_variable_name {values['document_variable_name']} was "
|
f"document_variable_name {values['document_variable_name']} was "
|
||||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||||
|
@ -322,7 +322,7 @@ class Crawler:
|
|||||||
|
|
||||||
if node_name == "#text" and ancestor_exception and ancestor_node:
|
if node_name == "#text" and ancestor_exception and ancestor_node:
|
||||||
text = strings[node_value[index]]
|
text = strings[node_value[index]]
|
||||||
if text == "|" or text == "•":
|
if text in {"|", "•"}:
|
||||||
continue
|
continue
|
||||||
ancestor_node.append({"type": "type", "value": text})
|
ancestor_node.append({"type": "type", "value": text})
|
||||||
else:
|
else:
|
||||||
@ -367,7 +367,7 @@ class Crawler:
|
|||||||
element_node_value = strings[text_index]
|
element_node_value = strings[text_index]
|
||||||
|
|
||||||
# remove redundant elements
|
# remove redundant elements
|
||||||
if ancestor_exception and (node_name != "a" and node_name != "button"):
|
if ancestor_exception and (node_name not in {"a", "button"}):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
elements_in_view_port.append(
|
elements_in_view_port.append(
|
||||||
@ -423,10 +423,7 @@ class Crawler:
|
|||||||
# not very elegant, more like a placeholder
|
# not very elegant, more like a placeholder
|
||||||
if (
|
if (
|
||||||
(converted_node_name != "button" or meta == "")
|
(converted_node_name != "button" or meta == "")
|
||||||
and converted_node_name != "link"
|
and converted_node_name not in {"link", "input", "img", "textarea"}
|
||||||
and converted_node_name != "input"
|
|
||||||
and converted_node_name != "img"
|
|
||||||
and converted_node_name != "textarea"
|
|
||||||
) and inner_text.strip() == "":
|
) and inner_text.strip() == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -51,8 +51,7 @@ def _format_url(url: str, path_params: dict) -> str:
|
|||||||
sep = ","
|
sep = ","
|
||||||
new_val = ""
|
new_val = ""
|
||||||
new_val += sep.join(kv_strs)
|
new_val += sep.join(kv_strs)
|
||||||
else:
|
elif param[0] == ".":
|
||||||
if param[0] == ".":
|
|
||||||
new_val = f".{val}"
|
new_val = f".{val}"
|
||||||
elif param[0] == ";":
|
elif param[0] == ";":
|
||||||
new_val = f";{clean_param}={val}"
|
new_val = f";{clean_param}={val}"
|
||||||
@ -224,7 +223,7 @@ class SimpleRequestChain(Chain):
|
|||||||
_text = f"Calling endpoint {_pretty_name} with arguments:\n" + _pretty_args
|
_text = f"Calling endpoint {_pretty_name} with arguments:\n" + _pretty_args
|
||||||
_run_manager.on_text(_text)
|
_run_manager.on_text(_text)
|
||||||
api_response: Response = self.request_method(name, args)
|
api_response: Response = self.request_method(name, args)
|
||||||
if api_response.status_code != 200:
|
if api_response.status_code != requests.codes.ok:
|
||||||
response = (
|
response = (
|
||||||
f"{api_response.status_code}: {api_response.reason}"
|
f"{api_response.status_code}: {api_response.reason}"
|
||||||
f"\nFor {name} "
|
f"\nFor {name} "
|
||||||
|
@ -87,7 +87,7 @@ class HypotheticalDocumentEmbedder:
|
|||||||
)
|
)
|
||||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
|
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
|
||||||
|
|
||||||
return H(*args, **kwargs) # type: ignore[return-value]
|
return H(*args, **kwargs) # type: ignore[return-value] # noqa: PLE0101
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(cls, *args: Any, **kwargs: Any) -> Any:
|
def from_llm(cls, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
@ -89,7 +89,6 @@ def _infer_model_and_provider(
|
|||||||
if provider is None and ":" in model:
|
if provider is None and ":" in model:
|
||||||
provider, model_name = _parse_model_string(model)
|
provider, model_name = _parse_model_string(model)
|
||||||
else:
|
else:
|
||||||
provider = provider
|
|
||||||
model_name = model
|
model_name = model
|
||||||
|
|
||||||
if not provider:
|
if not provider:
|
||||||
|
@ -89,7 +89,7 @@ _warned_about_sha1: bool = False
|
|||||||
|
|
||||||
def _warn_about_sha1_encoder() -> None:
|
def _warn_about_sha1_encoder() -> None:
|
||||||
"""Emit a one-time warning about SHA-1 collision weaknesses."""
|
"""Emit a one-time warning about SHA-1 collision weaknesses."""
|
||||||
global _warned_about_sha1
|
global _warned_about_sha1 # noqa: PLW0603
|
||||||
if not _warned_about_sha1:
|
if not _warned_about_sha1:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Using default key encoder: SHA-1 is *not* collision-resistant. "
|
"Using default key encoder: SHA-1 is *not* collision-resistant. "
|
||||||
|
@ -36,6 +36,8 @@ from langchain.evaluation.agents.trajectory_eval_prompt import (
|
|||||||
)
|
)
|
||||||
from langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain
|
from langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain
|
||||||
|
|
||||||
|
_MAX_SCORE = 5
|
||||||
|
|
||||||
|
|
||||||
class TrajectoryEval(TypedDict):
|
class TrajectoryEval(TypedDict):
|
||||||
"""A named tuple containing the score and reasoning for a trajectory."""
|
"""A named tuple containing the score and reasoning for a trajectory."""
|
||||||
@ -86,10 +88,10 @@ class TrajectoryOutputParser(BaseOutputParser):
|
|||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
score = int(_score.group(1))
|
score = int(_score.group(1))
|
||||||
# If the score is not in the range 1-5, raise an exception.
|
# If the score is not in the range 1-5, raise an exception.
|
||||||
if not 1 <= score <= 5:
|
if not 1 <= score <= _MAX_SCORE:
|
||||||
msg = f"Score is not a digit in the range 1-5: {text}"
|
msg = f"Score is not a digit in the range 1-5: {text}"
|
||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
normalized_score = (score - 1) / 4
|
normalized_score = (score - 1) / (_MAX_SCORE - 1)
|
||||||
return TrajectoryEval(score=normalized_score, reasoning=reasoning)
|
return TrajectoryEval(score=normalized_score, reasoning=reasoning)
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ def set_verbose(
|
|||||||
# have migrated to using `set_verbose()` here.
|
# have migrated to using `set_verbose()` here.
|
||||||
langchain.verbose = value
|
langchain.verbose = value
|
||||||
|
|
||||||
global _verbose
|
global _verbose # noqa: PLW0603
|
||||||
_verbose = value
|
_verbose = value
|
||||||
|
|
||||||
|
|
||||||
@ -69,7 +69,6 @@ def get_verbose() -> bool:
|
|||||||
# directing them to use `set_verbose()` when they import `langchain.verbose`.
|
# directing them to use `set_verbose()` when they import `langchain.verbose`.
|
||||||
old_verbose = langchain.verbose
|
old_verbose = langchain.verbose
|
||||||
|
|
||||||
global _verbose
|
|
||||||
return _verbose or old_verbose
|
return _verbose or old_verbose
|
||||||
|
|
||||||
|
|
||||||
@ -94,7 +93,7 @@ def set_debug(
|
|||||||
# have migrated to using `set_debug()` here.
|
# have migrated to using `set_debug()` here.
|
||||||
langchain.debug = value
|
langchain.debug = value
|
||||||
|
|
||||||
global _debug
|
global _debug # noqa: PLW0603
|
||||||
_debug = value
|
_debug = value
|
||||||
|
|
||||||
|
|
||||||
@ -122,7 +121,6 @@ def get_debug() -> bool:
|
|||||||
# directing them to use `set_debug()` when they import `langchain.debug`.
|
# directing them to use `set_debug()` when they import `langchain.debug`.
|
||||||
old_debug = langchain.debug
|
old_debug = langchain.debug
|
||||||
|
|
||||||
global _debug
|
|
||||||
return _debug or old_debug
|
return _debug or old_debug
|
||||||
|
|
||||||
|
|
||||||
@ -147,7 +145,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
|
|||||||
# once all users have migrated to using `set_llm_cache()` here.
|
# once all users have migrated to using `set_llm_cache()` here.
|
||||||
langchain.llm_cache = value
|
langchain.llm_cache = value
|
||||||
|
|
||||||
global _llm_cache
|
global _llm_cache # noqa: PLW0603
|
||||||
_llm_cache = value
|
_llm_cache = value
|
||||||
|
|
||||||
|
|
||||||
@ -179,5 +177,4 @@ def get_llm_cache() -> "BaseCache":
|
|||||||
# to use `set_llm_cache()` when they import `langchain.llm_cache`.
|
# to use `set_llm_cache()` when they import `langchain.llm_cache`.
|
||||||
old_llm_cache = langchain.llm_cache
|
old_llm_cache = langchain.llm_cache
|
||||||
|
|
||||||
global _llm_cache
|
|
||||||
return _llm_cache or old_llm_cache
|
return _llm_cache or old_llm_cache
|
||||||
|
@ -5,6 +5,8 @@ from typing import Any
|
|||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
from langchain_core.utils import pre_init
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
|
_MIN_PARSERS = 2
|
||||||
|
|
||||||
|
|
||||||
class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
|
class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
|
||||||
"""Combine multiple output parsers into one."""
|
"""Combine multiple output parsers into one."""
|
||||||
@ -19,7 +21,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
|
|||||||
def validate_parsers(cls, values: dict[str, Any]) -> dict[str, Any]:
|
def validate_parsers(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Validate the parsers."""
|
"""Validate the parsers."""
|
||||||
parsers = values["parsers"]
|
parsers = values["parsers"]
|
||||||
if len(parsers) < 2:
|
if len(parsers) < _MIN_PARSERS:
|
||||||
msg = "Must have at least two parsers"
|
msg = "Must have at least two parsers"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
for parser in parsers:
|
for parser in parsers:
|
||||||
|
@ -79,7 +79,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
|
|||||||
def parse(self, request: str) -> dict[str, Any]:
|
def parse(self, request: str) -> dict[str, Any]:
|
||||||
stripped_request_params = None
|
stripped_request_params = None
|
||||||
splitted_request = request.strip().split(":")
|
splitted_request = request.strip().split(":")
|
||||||
if len(splitted_request) != 2:
|
if len(splitted_request) != 2: # noqa: PLR2004
|
||||||
msg = f"Request '{request}' is not correctly formatted. \
|
msg = f"Request '{request}' is not correctly formatted. \
|
||||||
Please refer to the format instructions."
|
Please refer to the format instructions."
|
||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
@ -127,8 +127,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
|
|||||||
filtered_df[stripped_request_params],
|
filtered_df[stripped_request_params],
|
||||||
request_type,
|
request_type,
|
||||||
)()
|
)()
|
||||||
else:
|
elif request_type == "column":
|
||||||
if request_type == "column":
|
|
||||||
result[request_params] = self.dataframe[request_params]
|
result[request_params] = self.dataframe[request_params]
|
||||||
elif request_type == "row":
|
elif request_type == "row":
|
||||||
result[request_params] = self.dataframe.iloc[int(request_params)]
|
result[request_params] = self.dataframe.iloc[int(request_params)]
|
||||||
|
@ -70,8 +70,7 @@ class LLMChainFilter(BaseDocumentCompressor):
|
|||||||
output = output_[self.llm_chain.output_key]
|
output = output_[self.llm_chain.output_key]
|
||||||
if self.llm_chain.prompt.output_parser is not None:
|
if self.llm_chain.prompt.output_parser is not None:
|
||||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||||
else:
|
elif isinstance(output_, bool):
|
||||||
if isinstance(output_, bool):
|
|
||||||
include_doc = output_
|
include_doc = output_
|
||||||
if include_doc:
|
if include_doc:
|
||||||
filtered_docs.append(doc)
|
filtered_docs.append(doc)
|
||||||
@ -101,8 +100,7 @@ class LLMChainFilter(BaseDocumentCompressor):
|
|||||||
output = output_[self.llm_chain.output_key]
|
output = output_[self.llm_chain.output_key]
|
||||||
if self.llm_chain.prompt.output_parser is not None:
|
if self.llm_chain.prompt.output_parser is not None:
|
||||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||||
else:
|
elif isinstance(output_, bool):
|
||||||
if isinstance(output_, bool):
|
|
||||||
include_doc = output_
|
include_doc = output_
|
||||||
if include_doc:
|
if include_doc:
|
||||||
filtered_docs.append(doc)
|
filtered_docs.append(doc)
|
||||||
|
@ -147,11 +147,9 @@ select = [
|
|||||||
"A", # flake8-builtins
|
"A", # flake8-builtins
|
||||||
"B", # flake8-bugbear
|
"B", # flake8-bugbear
|
||||||
"ASYNC", # flake8-async
|
"ASYNC", # flake8-async
|
||||||
"B", # flake8-bugbear
|
|
||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"DTZ", # flake8-datetimez
|
"DTZ", # flake8-datetimez
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
@ -164,9 +162,10 @@ select = [
|
|||||||
"ICN", # flake8-import-conventions
|
"ICN", # flake8-import-conventions
|
||||||
"INT", # flake8-gettext
|
"INT", # flake8-gettext
|
||||||
"ISC", # isort-comprehensions
|
"ISC", # isort-comprehensions
|
||||||
|
"PERF", # flake8-perf
|
||||||
"PGH", # pygrep-hooks
|
"PGH", # pygrep-hooks
|
||||||
"PIE", # flake8-pie
|
"PIE", # flake8-pie
|
||||||
"PERF", # flake8-perf
|
"PL", # pylint
|
||||||
"PT", # flake8-pytest-style
|
"PT", # flake8-pytest-style
|
||||||
"PTH", # flake8-use-pathlib
|
"PTH", # flake8-use-pathlib
|
||||||
"PYI", # flake8-pyi
|
"PYI", # flake8-pyi
|
||||||
@ -175,9 +174,9 @@ select = [
|
|||||||
"RSE", # flake8-rst-docstrings
|
"RSE", # flake8-rst-docstrings
|
||||||
"RUF", # ruff
|
"RUF", # ruff
|
||||||
"S", # flake8-bandit
|
"S", # flake8-bandit
|
||||||
"SLF", # flake8-self
|
|
||||||
"SLOT", # flake8-slots
|
"SLOT", # flake8-slots
|
||||||
"SIM", # flake8-simplify
|
"SIM", # flake8-simplify
|
||||||
|
"SLF", # flake8-self
|
||||||
"T10", # flake8-debugger
|
"T10", # flake8-debugger
|
||||||
"T20", # flake8-print
|
"T20", # flake8-print
|
||||||
"TID", # flake8-tidy-imports
|
"TID", # flake8-tidy-imports
|
||||||
@ -198,12 +197,14 @@ ignore = [
|
|||||||
"COM812", # Messes with the formatter
|
"COM812", # Messes with the formatter
|
||||||
"ISC001", # Messes with the formatter
|
"ISC001", # Messes with the formatter
|
||||||
"PERF203", # Rarely useful
|
"PERF203", # Rarely useful
|
||||||
|
"PLR09", # Too many something (args, statements, etc)
|
||||||
"S112", # Rarely useful
|
"S112", # Rarely useful
|
||||||
"RUF012", # Doesn't play well with Pydantic
|
"RUF012", # Doesn't play well with Pydantic
|
||||||
"SLF001", # Private member access
|
"SLF001", # Private member access
|
||||||
"UP007", # pyupgrade: non-pep604-annotation-union
|
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||||
|
|
||||||
# TODO
|
# TODO rules
|
||||||
|
"PLC0415", # pylint: import-outside-top-level
|
||||||
"TRY301", # tryceratops: raise-within-try
|
"TRY301", # tryceratops: raise-within-try
|
||||||
]
|
]
|
||||||
unfixable = ["B028"] # People should intentionally tune the stacklevel
|
unfixable = ["B028"] # People should intentionally tune the stacklevel
|
||||||
@ -220,6 +221,7 @@ pyupgrade.keep-runtime-typing = true
|
|||||||
"S101", # Tests need assertions
|
"S101", # Tests need assertions
|
||||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||||
"SLF001", # Private member access in tests
|
"SLF001", # Private member access in tests
|
||||||
|
"PLR2004", # Magic value comparisons
|
||||||
]
|
]
|
||||||
"langchain/chains/constitutional_ai/principles.py" = [
|
"langchain/chains/constitutional_ai/principles.py" = [
|
||||||
"E501", # Line too long
|
"E501", # Line too long
|
||||||
|
@ -130,8 +130,7 @@ def pytest_collection_modifyitems(
|
|||||||
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"),
|
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"),
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
elif only_extended:
|
||||||
if only_extended:
|
|
||||||
item.add_marker(
|
item.add_marker(
|
||||||
pytest.mark.skip(reason="Skipping not an extended test."),
|
pytest.mark.skip(reason="Skipping not an extended test."),
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import pytest as pytest
|
import pytest
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from langchain.retrievers.multi_query import LineListOutputParser, _unique_documents
|
from langchain.retrievers.multi_query import LineListOutputParser, _unique_documents
|
||||||
|
@ -10,6 +10,8 @@ class AnyStr(str):
|
|||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
return isinstance(other, str)
|
return isinstance(other, str)
|
||||||
|
|
||||||
|
__hash__ = str.__hash__
|
||||||
|
|
||||||
|
|
||||||
# The code below creates version of pydantic models
|
# The code below creates version of pydantic models
|
||||||
# that will work in unit tests with AnyStr as id field
|
# that will work in unit tests with AnyStr as id field
|
||||||
|
Loading…
Reference in New Issue
Block a user