Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-12 15:59:56 +00:00
langchain: Add ruff rules C4 (#31879)
All auto-fixes. See https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent 4134b36db8
commit fceebbb387
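The diff below applies the same handful of comprehension rewrites throughout the package. As a rough, hypothetical sketch (placeholder names, not taken from the LangChain sources), the C4 fixes replace redundant set()/dict()/list()/tuple() calls and wrapped comprehensions with their literal or comprehension forms:

    # Illustrative only: the post-fix idiom, with the pre-fix form in the trailing comment.
    tools = ["search", "calculator"]
    keys = ["a", "b"]

    names = {tool for tool in tools}               # was: set([tool for tool in tools])
    output = {}                                    # was: dict()
    memory_keys = []                               # was: list()
    default_domains = ()                           # was: tuple()
    payload = {"completion": "done", "error": ""}  # was: dict(completion="done", error="")
    has_followup = any(k in names for k in keys)   # was: any([k in names for k in keys])
    tools_copy = list(tools)                       # was: [tool for tool in tools]
    schema = dict.fromkeys(keys, (None, None))     # was: {k: (None, None) for k in keys}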
@@ -1126,9 +1126,9 @@ class AgentExecutor(Chain):
         agent = self.agent
         tools = self.tools
         allowed_tools = agent.get_allowed_tools()  # type: ignore[union-attr]
-        if allowed_tools is not None and set(allowed_tools) != set(
-            [tool.name for tool in tools]
-        ):
+        if allowed_tools is not None and set(allowed_tools) != {
+            tool.name for tool in tools
+        }:
             msg = (
                 f"Allowed tools ({allowed_tools}) different than "
                 f"provided tools ({[tool.name for tool in tools]})"
@@ -1318,16 +1318,15 @@ class AgentExecutor(Chain):
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
         return self._consume_next_step(
-            [
-                a
-                for a in self._iter_next_step(
+            list(
+                self._iter_next_step(
                     name_to_tool_map,
                     color_mapping,
                     inputs,
                     intermediate_steps,
                     run_manager,
                 )
-            ]
+            )
         )
 
     def _iter_next_step(
@@ -37,7 +37,7 @@ class SelfAskOutputParser(AgentOutputParser):
 
     def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
         last_line = text.split("\n")[-1]
-        if not any([follow in last_line for follow in self.followups]):
+        if not any(follow in last_line for follow in self.followups):
             if self.finish_string not in last_line:
                 msg = f"Could not parse output: {text}"
                 raise OutputParserException(msg)
@@ -24,7 +24,7 @@ class InvalidTool(BaseTool):
         run_manager: Optional[CallbackManagerForToolRun] = None,
     ) -> str:
         """Use the tool."""
-        available_tool_names_str = ", ".join([tool for tool in available_tool_names])
+        available_tool_names_str = ", ".join(list(available_tool_names))
         return (
             f"{requested_tool_name} is not a valid tool, "
             f"try one of [{available_tool_names_str}]."
@@ -37,7 +37,7 @@ class InvalidTool(BaseTool):
         run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
     ) -> str:
         """Use the tool asynchronously."""
-        available_tool_names_str = ", ".join([tool for tool in available_tool_names])
+        available_tool_names_str = ", ".join(list(available_tool_names))
         return (
             f"{requested_tool_name} is not a valid tool, "
             f"try one of [{available_tool_names_str}]."
@@ -352,7 +352,7 @@ try:
         headers: Optional[dict] = None,
         api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
         api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
-        limit_to_domains: Optional[Sequence[str]] = tuple(),
+        limit_to_domains: Optional[Sequence[str]] = (),
         **kwargs: Any,
     ) -> APIChain:
         """Load chain from just an LLM and the api docs."""
@@ -112,13 +112,15 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
         self, config: Optional[RunnableConfig] = None
     ) -> type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
-        return create_model("ChainInput", **{k: (Any, None) for k in self.input_keys})
+        return create_model("ChainInput", **dict.fromkeys(self.input_keys, (Any, None)))
 
     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
     ) -> type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
-        return create_model("ChainOutput", **{k: (Any, None) for k in self.output_keys})
+        return create_model(
+            "ChainOutput", **dict.fromkeys(self.output_keys, (Any, None))
+        )
 
     @override
     def invoke(
@@ -100,7 +100,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         if self.return_intermediate_steps:
             schema["intermediate_steps"] = (list[str], None)
         if self.metadata_keys:
-            schema.update({key: (Any, None) for key in self.metadata_keys})
+            schema.update(dict.fromkeys(self.metadata_keys, (Any, None)))
 
         return create_model("MapRerankOutput", **schema)
 
@@ -48,7 +48,7 @@ class SequentialChain(Chain):
         """Validate that the correct inputs exist for all chains."""
         chains = values["chains"]
         input_variables = values["input_variables"]
-        memory_keys = list()
+        memory_keys = []
         if "memory" in values and values["memory"] is not None:
             """Validate that prompt input variables are consistent."""
             memory_keys = values["memory"].memory_variables
@@ -69,7 +69,7 @@ def load_dataset(uri: str) -> list[dict]:
         raise ImportError(msg)
 
     dataset = load_dataset(f"LangChainDatasets/{uri}")
-    return [d for d in dataset["train"]]
+    return list(dataset["train"])
 
 
 _EVALUATOR_MAP: dict[
@@ -311,10 +311,10 @@ class SQLRecordManager(RecordManager):
                 ).values(records_to_upsert)
                 stmt = sqlite_insert_stmt.on_conflict_do_update(
                     [UpsertionRecord.key, UpsertionRecord.namespace],
-                    set_=dict(
-                        updated_at=sqlite_insert_stmt.excluded.updated_at,
-                        group_id=sqlite_insert_stmt.excluded.group_id,
-                    ),
+                    set_={
+                        "updated_at": sqlite_insert_stmt.excluded.updated_at,
+                        "group_id": sqlite_insert_stmt.excluded.group_id,
+                    },
                 )
             elif self.dialect == "postgresql":
                 from sqlalchemy.dialects.postgresql import Insert as PgInsertType
@@ -327,10 +327,10 @@ class SQLRecordManager(RecordManager):
                 )
                 stmt = pg_insert_stmt.on_conflict_do_update(  # type: ignore[assignment]
                     "uix_key_namespace",  # Name of constraint
-                    set_=dict(
-                        updated_at=pg_insert_stmt.excluded.updated_at,
-                        group_id=pg_insert_stmt.excluded.group_id,
-                    ),
+                    set_={
+                        "updated_at": pg_insert_stmt.excluded.updated_at,
+                        "group_id": pg_insert_stmt.excluded.group_id,
+                    },
                 )
             else:
                 msg = f"Unsupported dialect {self.dialect}"
@@ -393,10 +393,10 @@ class SQLRecordManager(RecordManager):
                 ).values(records_to_upsert)
                 stmt = sqlite_insert_stmt.on_conflict_do_update(
                     [UpsertionRecord.key, UpsertionRecord.namespace],
-                    set_=dict(
-                        updated_at=sqlite_insert_stmt.excluded.updated_at,
-                        group_id=sqlite_insert_stmt.excluded.group_id,
-                    ),
+                    set_={
+                        "updated_at": sqlite_insert_stmt.excluded.updated_at,
+                        "group_id": sqlite_insert_stmt.excluded.group_id,
+                    },
                 )
             elif self.dialect == "postgresql":
                 from sqlalchemy.dialects.postgresql import Insert as PgInsertType
@@ -409,10 +409,10 @@ class SQLRecordManager(RecordManager):
                 )
                 stmt = pg_insert_stmt.on_conflict_do_update(  # type: ignore[assignment]
                     "uix_key_namespace",  # Name of constraint
-                    set_=dict(
-                        updated_at=pg_insert_stmt.excluded.updated_at,
-                        group_id=pg_insert_stmt.excluded.group_id,
-                    ),
+                    set_={
+                        "updated_at": pg_insert_stmt.excluded.updated_at,
+                        "group_id": pg_insert_stmt.excluded.group_id,
+                    },
                 )
             else:
                 msg = f"Unsupported dialect {self.dialect}"
@@ -432,7 +432,7 @@ class SQLRecordManager(RecordManager):
                 )
             )
             records = filtered_query.all()
-            found_keys = set(r.key for r in records)
+            found_keys = {r.key for r in records}
             return [k in found_keys for k in keys]
 
     async def aexists(self, keys: Sequence[str]) -> list[bool]:
@@ -8,7 +8,7 @@ class SimpleMemory(BaseMemory):
     ever change between prompts.
     """
 
-    memories: dict[str, Any] = dict()
+    memories: dict[str, Any] = {}
 
     @property
     def memory_variables(self) -> list[str]:
@@ -49,7 +49,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
     def parse(self, text: str) -> dict[str, Any]:
         """Parse the output of an LLM call."""
         texts = text.split("\n\n")
-        output = dict()
+        output = {}
         for txt, parser in zip(texts, self.parsers):
             output.update(parser.parse(txt.strip()))
         return output
@@ -82,19 +82,19 @@ class OutputFixingParser(BaseOutputParser[T]):
                 else:
                     try:
                         completion = self.retry_chain.invoke(
-                            dict(
-                                instructions=self.parser.get_format_instructions(),
-                                completion=completion,
-                                error=repr(e),
-                            )
+                            {
+                                "instructions": self.parser.get_format_instructions(),  # noqa: E501
+                                "completion": completion,
+                                "error": repr(e),
+                            }
                         )
                     except (NotImplementedError, AttributeError):
                         # Case: self.parser does not have get_format_instructions
                         completion = self.retry_chain.invoke(
-                            dict(
-                                completion=completion,
-                                error=repr(e),
-                            )
+                            {
+                                "completion": completion,
+                                "error": repr(e),
+                            }
                         )
 
                 msg = "Failed to parse"
@@ -120,19 +120,19 @@ class OutputFixingParser(BaseOutputParser[T]):
                 else:
                     try:
                         completion = await self.retry_chain.ainvoke(
-                            dict(
-                                instructions=self.parser.get_format_instructions(),
-                                completion=completion,
-                                error=repr(e),
-                            )
+                            {
+                                "instructions": self.parser.get_format_instructions(),  # noqa: E501
+                                "completion": completion,
+                                "error": repr(e),
+                            }
                         )
                     except (NotImplementedError, AttributeError):
                         # Case: self.parser does not have get_format_instructions
                         completion = await self.retry_chain.ainvoke(
-                            dict(
-                                completion=completion,
-                                error=repr(e),
-                            )
+                            {
+                                "completion": completion,
+                                "error": repr(e),
+                            }
                         )
 
                 msg = "Failed to parse"
@@ -116,10 +116,10 @@ class RetryOutputParser(BaseOutputParser[T]):
             )
         else:
             completion = self.retry_chain.invoke(
-                dict(
-                    prompt=prompt_value.to_string(),
-                    completion=completion,
-                )
+                {
+                    "prompt": prompt_value.to_string(),
+                    "completion": completion,
+                }
             )
 
         msg = "Failed to parse"
@@ -153,10 +153,10 @@ class RetryOutputParser(BaseOutputParser[T]):
             )
         else:
             completion = await self.retry_chain.ainvoke(
-                dict(
-                    prompt=prompt_value.to_string(),
-                    completion=completion,
-                )
+                {
+                    "prompt": prompt_value.to_string(),
+                    "completion": completion,
+                }
             )
 
         msg = "Failed to parse"
@@ -244,11 +244,11 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
             )
         else:
             completion = self.retry_chain.invoke(
-                dict(
-                    completion=completion,
-                    prompt=prompt_value.to_string(),
-                    error=repr(e),
-                )
+                {
+                    "completion": completion,
+                    "prompt": prompt_value.to_string(),
+                    "error": repr(e),
+                }
             )
 
         msg = "Failed to parse"
@@ -273,11 +273,11 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
             )
         else:
             completion = await self.retry_chain.ainvoke(
-                dict(
-                    prompt=prompt_value.to_string(),
-                    completion=completion,
-                    error=repr(e),
-                )
+                {
+                    "prompt": prompt_value.to_string(),
+                    "completion": completion,
+                    "error": repr(e),
+                }
             )
 
         msg = "Failed to parse"
@@ -43,7 +43,7 @@ class YamlOutputParser(BaseOutputParser[T]):
 
     def get_format_instructions(self) -> str:
         # Copy schema to avoid altering original Pydantic schema.
-        schema = {k: v for k, v in self.pydantic_object.schema().items()}
+        schema = dict(self.pydantic_object.schema().items())
 
         # Remove extraneous fields.
         reduced_schema = schema
@@ -24,7 +24,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
     vectorstore: VectorStore
     """The vectorstore to store documents and determine salience."""
 
-    search_kwargs: dict = Field(default_factory=lambda: dict(k=100))
+    search_kwargs: dict = Field(default_factory=lambda: {"k": 100})
     """Keyword arguments to pass to the vectorstore similarity search."""
 
     # TODO: abstract as a queue
@@ -623,7 +623,7 @@ def _load_run_evaluators(
     input_key, prediction_key, reference_key = None, None, None
     if config.evaluators or (
         config.custom_evaluators
-        and any([isinstance(e, StringEvaluator) for e in config.custom_evaluators])
+        and any(isinstance(e, StringEvaluator) for e in config.custom_evaluators)
     ):
         input_key, prediction_key, reference_key = _get_keys(
             config, run_inputs, run_outputs, example_outputs
@@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
 ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
 
 [tool.ruff.lint]
-select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
+select = ["A", "C4", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
 pydocstyle.convention = "google"
 pyupgrade.keep-runtime-typing = true
 
@@ -205,7 +205,7 @@ def test_agent_stream() -> None:
         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
     )
 
-    output = [a for a in agent.stream("when was langchain made")]
+    output = list(agent.stream("when was langchain made"))
     assert output == [
         {
             "actions": [
@@ -24,7 +24,7 @@ def test_resolve_criteria_enum(criterion: Criteria) -> None:
 def test_resolve_criteria_list_enum() -> None:
     val = resolve_pairwise_criteria(list(Criteria))
     assert isinstance(val, dict)
-    assert set(val.keys()) == set(c.value for c in list(Criteria))
+    assert set(val.keys()) == {c.value for c in list(Criteria)}
 
 
 def test_PairwiseStringResultOutputParser_parse() -> None:
@@ -288,11 +288,11 @@ def test_index_simple_delete_full(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.store.get(uid).page_content  # type: ignore[union-attr]
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"mutated document 1", "This is another document."}
 
     # Attempt to index again verify that nothing changes
@@ -364,11 +364,11 @@ async def test_aindex_simple_delete_full(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.store.get(uid).page_content  # type: ignore[union-attr]
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"mutated document 1", "This is another document."}
 
     # Attempt to index again verify that nothing changes
@@ -657,11 +657,11 @@ def test_incremental_delete(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
        vector_store.store.get(uid).page_content  # type: ignore[union-attr]
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"This is another document.", "This is a test document."}
 
     # Attempt to index again verify that nothing changes
@@ -716,11 +716,11 @@ def test_incremental_delete(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.store.get(uid).page_content  # type: ignore[union-attr]
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {
         "mutated document 1",
         "mutated document 2",
@@ -784,11 +784,11 @@ def test_incremental_indexing_with_batch_size(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.store.get(uid).page_content  # type: ignore[union-attr]
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"1", "2", "3", "4"}
 
 
@@ -834,11 +834,11 @@ def test_incremental_delete_with_batch_size(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.store.get(uid).page_content  # type: ignore[union-attr]
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"1", "2", "3", "4"}
 
     # Attempt to index again verify that nothing changes
@@ -980,11 +980,11 @@ async def test_aincremental_delete(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.store.get(uid).page_content  # type: ignore[union-attr]
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"This is another document.", "This is a test document."}
 
     # Attempt to index again verify that nothing changes
@@ -1039,11 +1039,11 @@ async def test_aincremental_delete(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.store.get(uid).page_content  # type: ignore[union-attr]
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {
         "mutated document 1",
         "mutated document 2",
@@ -51,7 +51,7 @@ async def test_generic_fake_chat_model_stream() -> None:
         _AnyIdAIMessageChunk(content="goodbye"),
     ]
 
-    chunks = [chunk for chunk in model.stream("meow")]
+    chunks = list(model.stream("meow"))
     assert chunks == [
         _AnyIdAIMessageChunk(content="hello"),
         _AnyIdAIMessageChunk(content=" "),
@@ -28,7 +28,7 @@ def test_required_dependencies(uv_conf: Mapping[str, Any]) -> None:
     """
     # Get the dependencies from the [tool.poetry.dependencies] section
     dependencies = uv_conf["project"]["dependencies"]
-    required_dependencies = set(Requirement(dep).name for dep in dependencies)
+    required_dependencies = {Requirement(dep).name for dep in dependencies}
 
     assert sorted(required_dependencies) == sorted(
         [
|
||||
"""
|
||||
|
||||
dependencies = uv_conf["dependency-groups"]["test"]
|
||||
test_group_deps = set(Requirement(dep).name for dep in dependencies)
|
||||
test_group_deps = {Requirement(dep).name for dep in dependencies}
|
||||
|
||||
assert sorted(test_group_deps) == sorted(
|
||||
[
|
||||
|