core: Add ruff rules PERF (#29375)

See https://docs.astral.sh/ruff/rules/#perflint-perf
This commit is contained in:
Christophe Bornet 2025-04-01 19:34:56 +02:00 committed by GitHub
parent 8a33402016
commit 4f8ea13cea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 92 additions and 109 deletions

View File

@@ -66,13 +66,7 @@ class InMemoryDocumentIndex(DocumentIndex):
def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
"""Get by ids."""
found_documents = []
for id_ in ids:
if id_ in self.store:
found_documents.append(self.store[id_])
return found_documents
return [self.store[id_] for id_ in ids if id_ in self.store]
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun

View File

@@ -217,31 +217,21 @@ class AIMessage(BaseMessage):
# Ensure "type" is properly set on all tool call-like dicts.
if tool_calls := values.get("tool_calls"):
updated: list = []
for tc in tool_calls:
updated.append(
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
)
values["tool_calls"] = updated
values["tool_calls"] = [
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
for tc in tool_calls
]
if invalid_tool_calls := values.get("invalid_tool_calls"):
updated = []
for tc in invalid_tool_calls:
updated.append(
create_invalid_tool_call(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["invalid_tool_calls"] = updated
values["invalid_tool_calls"] = [
create_invalid_tool_call(**{k: v for k, v in tc.items() if k != "type"})
for tc in invalid_tool_calls
]
if tool_call_chunks := values.get("tool_call_chunks"):
updated = []
for tc in tool_call_chunks:
updated.append(
create_tool_call_chunk(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["tool_call_chunks"] = updated
values["tool_call_chunks"] = [
create_tool_call_chunk(**{k: v for k, v in tc.items() if k != "type"})
for tc in tool_call_chunks
]
return values

View File

@@ -557,11 +557,11 @@ class Runnable(Generic[Input, Output], ABC):
"""Return a list of prompts used by this Runnable."""
from langchain_core.prompts.base import BasePromptTemplate
prompts = []
for _, node in self.get_graph(config=config).nodes.items():
if isinstance(node.data, BasePromptTemplate):
prompts.append(node.data)
return prompts
return [
node.data
for node in self.get_graph(config=config).nodes.values()
if isinstance(node.data, BasePromptTemplate)
]
def __or__(
self,
@@ -3183,9 +3183,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
**(kwargs if stepidx == 0 else {}),
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
if isinstance(inp, Exception):
failed_inputs_map[i] = inp
failed_inputs_map.update(
{
i: inp
for i, inp in zip(remaining_idxs, inputs)
if isinstance(inp, Exception)
}
)
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs):
@@ -3314,9 +3318,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
**(kwargs if stepidx == 0 else {}),
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
if isinstance(inp, Exception):
failed_inputs_map[i] = inp
failed_inputs_map.update(
{
i: inp
for i, inp in zip(remaining_idxs, inputs)
if isinstance(inp, Exception)
}
)
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs):

View File

@@ -697,10 +697,11 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
When drawing the graph, this node would be the origin.
"""
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
found: list[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in targets:
found.append(node)
found: list[Node] = [
node
for node in graph.nodes.values()
if node.id not in exclude and node.id not in targets
]
return found[0] if len(found) == 1 else None
@@ -712,8 +713,9 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
When drawing the graph, this node would be the destination.
"""
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
found: list[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in sources:
found.append(node)
found: list[Node] = [
node
for node in graph.nodes.values()
if node.id not in exclude and node.id not in sources
]
return found[0] if len(found) == 1 else None

View File

@@ -674,10 +674,12 @@ async def _astream_log_implementation(
"value": copy.deepcopy(chunk),
}
)
for op in jsonpatch.JsonPatch.from_diff(
prev_final_output, final_output, dumps=dumps
):
patches.append({**op, "path": f"/final_output{op['path']}"})
patches.extend(
{**op, "path": f"/final_output{op['path']}"}
for op in jsonpatch.JsonPatch.from_diff(
prev_final_output, final_output, dumps=dumps
)
)
await stream.send_stream.send(RunLogPatch(*patches))
finally:
await stream.send_stream.aclose()

View File

@@ -672,21 +672,21 @@ def tool_example_to_messages(
)
"""
messages: list[BaseMessage] = [HumanMessage(content=input)]
openai_tool_calls = []
for tool_call in tool_calls:
openai_tool_calls.append(
{
"id": str(uuid.uuid4()),
"type": "function",
"function": {
# The name of the function right now corresponds to the name
# of the pydantic model. This is implicit in the API right now,
# and will be improved over time.
"name": tool_call.__class__.__name__,
"arguments": tool_call.model_dump_json(),
},
}
)
openai_tool_calls = [
{
"id": str(uuid.uuid4()),
"type": "function",
"function": {
# The name of the function right now corresponds to the name
# of the pydantic model. This is implicit in the API right now,
# and will be improved over time.
"name": tool_call.__class__.__name__,
"arguments": tool_call.model_dump_json(),
},
}
for tool_call in tool_calls
]
messages.append(
AIMessage(content="", additional_kwargs={"tool_calls": openai_tool_calls})
)

View File

@@ -83,6 +83,7 @@ ignore = [
"COM812", # Messes with the formatter
"FA100", # Can't activate since we exclude UP007 for now
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"PLR09", # Too many something (arg, statements, etc)
"TC001", # Doesn't play well with Pydantic
"TC002", # Doesn't play well with Pydantic
@@ -98,7 +99,6 @@ ignore = [
"DTZ",
"FBT",
"FIX",
"PERF",
"PGH",
"PLC",
"PLE",

View File

@@ -1014,13 +1014,12 @@ def test_chat_prompt_template_variable_names() -> None:
prompt.get_input_schema()
if record:
error_msg = []
for warning in record:
error_msg.append(
f"Warning type: {warning.category.__name__}, "
f"Warning message: {warning.message}, "
f"Warning location: {warning.filename}:{warning.lineno}"
)
error_msg = [
f"Warning type: {warning.category.__name__}, "
f"Warning message: {warning.message}, "
f"Warning location: {warning.filename}:{warning.lineno}"
for warning in record
]
msg = "\n".join(error_msg)
else:
msg = ""

View File

@@ -3476,9 +3476,7 @@ def test_deep_stream() -> None:
stream = chain.stream({"question": "What up"})
chunks = []
for chunk in stream:
chunks.append(chunk)
chunks = list(stream)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@@ -3502,9 +3500,7 @@ def test_deep_stream_assign() -> None:
stream = chain.stream({"question": "What up"})
chunks = []
for chunk in stream:
chunks.append(chunk)
chunks = list(stream)
assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"}
@@ -3602,9 +3598,7 @@ async def test_deep_astream() -> None:
stream = chain.astream({"question": "What up"})
chunks = []
async for chunk in stream:
chunks.append(chunk)
chunks = [chunk async for chunk in stream]
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@@ -3628,9 +3622,7 @@ async def test_deep_astream_assign() -> None:
stream = chain.astream({"question": "What up"})
chunks = []
async for chunk in stream:
chunks.append(chunk)
chunks = [chunk async for chunk in stream]
assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"}
@@ -3726,9 +3718,7 @@ def test_runnable_sequence_transform() -> None:
stream = chain.transform(llm.stream("Hi there!"))
chunks = []
for chunk in stream:
chunks.append(chunk)
chunks = list(stream)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@@ -3741,9 +3731,7 @@ async def test_runnable_sequence_atransform() -> None:
stream = chain.atransform(llm.astream("Hi there!"))
chunks = []
async for chunk in stream:
chunks.append(chunk)
chunks = [chunk async for chunk in stream]
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"

View File

@@ -1493,12 +1493,12 @@ async def test_chain_ordering() -> None:
events = []
for _ in range(10):
try:
try:
for _ in range(10):
next_chunk = await iterable.__anext__()
events.append(next_chunk)
except Exception:
break
except Exception:
pass
events = _with_nulled_run_id(events)
for event in events:
@@ -1610,12 +1610,12 @@ async def test_event_stream_with_retry() -> None:
events = []
for _ in range(10):
try:
try:
for _ in range(10):
next_chunk = await iterable.__anext__()
events.append(next_chunk)
except Exception:
break
except Exception:
pass
events = _with_nulled_run_id(events)
for event in events:

View File

@@ -1453,12 +1453,12 @@ async def test_chain_ordering() -> None:
events = []
for _ in range(10):
try:
try:
for _ in range(10):
next_chunk = await iterable.__anext__()
events.append(next_chunk)
except Exception:
break
except Exception:
pass
events = _with_nulled_run_id(events)
for event in events:
@@ -1570,12 +1570,12 @@ async def test_event_stream_with_retry() -> None:
events = []
for _ in range(10):
try:
try:
for _ in range(10):
next_chunk = await iterable.__anext__()
events.append(next_chunk)
except Exception:
break
except Exception:
pass
events = _with_nulled_run_id(events)
for event in events: