core: Add ruff rules PERF (#29375)

See https://docs.astral.sh/ruff/rules/#perflint-perf
This commit is contained in:
Christophe Bornet 2025-04-01 19:34:56 +02:00 committed by GitHub
parent 8a33402016
commit 4f8ea13cea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 92 additions and 109 deletions

View File

@@ -66,13 +66,7 @@ class InMemoryDocumentIndex(DocumentIndex):
def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]: def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
"""Get by ids.""" """Get by ids."""
found_documents = [] return [self.store[id_] for id_ in ids if id_ in self.store]
for id_ in ids:
if id_ in self.store:
found_documents.append(self.store[id_])
return found_documents
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun

View File

@@ -217,31 +217,21 @@ class AIMessage(BaseMessage):
# Ensure "type" is properly set on all tool call-like dicts. # Ensure "type" is properly set on all tool call-like dicts.
if tool_calls := values.get("tool_calls"): if tool_calls := values.get("tool_calls"):
updated: list = [] values["tool_calls"] = [
for tc in tool_calls: create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
updated.append( for tc in tool_calls
create_tool_call(**{k: v for k, v in tc.items() if k != "type"}) ]
)
values["tool_calls"] = updated
if invalid_tool_calls := values.get("invalid_tool_calls"): if invalid_tool_calls := values.get("invalid_tool_calls"):
updated = [] values["invalid_tool_calls"] = [
for tc in invalid_tool_calls: create_invalid_tool_call(**{k: v for k, v in tc.items() if k != "type"})
updated.append( for tc in invalid_tool_calls
create_invalid_tool_call( ]
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["invalid_tool_calls"] = updated
if tool_call_chunks := values.get("tool_call_chunks"): if tool_call_chunks := values.get("tool_call_chunks"):
updated = [] values["tool_call_chunks"] = [
for tc in tool_call_chunks: create_tool_call_chunk(**{k: v for k, v in tc.items() if k != "type"})
updated.append( for tc in tool_call_chunks
create_tool_call_chunk( ]
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["tool_call_chunks"] = updated
return values return values

View File

@@ -557,11 +557,11 @@ class Runnable(Generic[Input, Output], ABC):
"""Return a list of prompts used by this Runnable.""" """Return a list of prompts used by this Runnable."""
from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.base import BasePromptTemplate
prompts = [] return [
for _, node in self.get_graph(config=config).nodes.items(): node.data
if isinstance(node.data, BasePromptTemplate): for node in self.get_graph(config=config).nodes.values()
prompts.append(node.data) if isinstance(node.data, BasePromptTemplate)
return prompts ]
def __or__( def __or__(
self, self,
@@ -3183,9 +3183,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
**(kwargs if stepidx == 0 else {}), **(kwargs if stepidx == 0 else {}),
) )
# If an input failed, add it to the map # If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs): failed_inputs_map.update(
if isinstance(inp, Exception): {
failed_inputs_map[i] = inp i: inp
for i, inp in zip(remaining_idxs, inputs)
if isinstance(inp, Exception)
}
)
inputs = [inp for inp in inputs if not isinstance(inp, Exception)] inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing # If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs): if len(failed_inputs_map) == len(configs):
@@ -3314,9 +3318,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
**(kwargs if stepidx == 0 else {}), **(kwargs if stepidx == 0 else {}),
) )
# If an input failed, add it to the map # If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs): failed_inputs_map.update(
if isinstance(inp, Exception): {
failed_inputs_map[i] = inp i: inp
for i, inp in zip(remaining_idxs, inputs)
if isinstance(inp, Exception)
}
)
inputs = [inp for inp in inputs if not isinstance(inp, Exception)] inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing # If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs): if len(failed_inputs_map) == len(configs):

View File

@@ -697,10 +697,11 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
When drawing the graph, this node would be the origin. When drawing the graph, this node would be the origin.
""" """
targets = {edge.target for edge in graph.edges if edge.source not in exclude} targets = {edge.target for edge in graph.edges if edge.source not in exclude}
found: list[Node] = [] found: list[Node] = [
for node in graph.nodes.values(): node
if node.id not in exclude and node.id not in targets: for node in graph.nodes.values()
found.append(node) if node.id not in exclude and node.id not in targets
]
return found[0] if len(found) == 1 else None return found[0] if len(found) == 1 else None
@@ -712,8 +713,9 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
When drawing the graph, this node would be the destination. When drawing the graph, this node would be the destination.
""" """
sources = {edge.source for edge in graph.edges if edge.target not in exclude} sources = {edge.source for edge in graph.edges if edge.target not in exclude}
found: list[Node] = [] found: list[Node] = [
for node in graph.nodes.values(): node
if node.id not in exclude and node.id not in sources: for node in graph.nodes.values()
found.append(node) if node.id not in exclude and node.id not in sources
]
return found[0] if len(found) == 1 else None return found[0] if len(found) == 1 else None

View File

@@ -674,10 +674,12 @@ async def _astream_log_implementation(
"value": copy.deepcopy(chunk), "value": copy.deepcopy(chunk),
} }
) )
for op in jsonpatch.JsonPatch.from_diff( patches.extend(
prev_final_output, final_output, dumps=dumps {**op, "path": f"/final_output{op['path']}"}
): for op in jsonpatch.JsonPatch.from_diff(
patches.append({**op, "path": f"/final_output{op['path']}"}) prev_final_output, final_output, dumps=dumps
)
)
await stream.send_stream.send(RunLogPatch(*patches)) await stream.send_stream.send(RunLogPatch(*patches))
finally: finally:
await stream.send_stream.aclose() await stream.send_stream.aclose()

View File

@@ -672,21 +672,21 @@ def tool_example_to_messages(
) )
""" """
messages: list[BaseMessage] = [HumanMessage(content=input)] messages: list[BaseMessage] = [HumanMessage(content=input)]
openai_tool_calls = [] openai_tool_calls = [
for tool_call in tool_calls: {
openai_tool_calls.append( "id": str(uuid.uuid4()),
{ "type": "function",
"id": str(uuid.uuid4()), "function": {
"type": "function", # The name of the function right now corresponds to the name
"function": { # of the pydantic model. This is implicit in the API right now,
# The name of the function right now corresponds to the name # and will be improved over time.
# of the pydantic model. This is implicit in the API right now, "name": tool_call.__class__.__name__,
# and will be improved over time. "arguments": tool_call.model_dump_json(),
"name": tool_call.__class__.__name__, },
"arguments": tool_call.model_dump_json(), }
}, for tool_call in tool_calls
} ]
)
messages.append( messages.append(
AIMessage(content="", additional_kwargs={"tool_calls": openai_tool_calls}) AIMessage(content="", additional_kwargs={"tool_calls": openai_tool_calls})
) )

View File

@@ -83,6 +83,7 @@ ignore = [
"COM812", # Messes with the formatter "COM812", # Messes with the formatter
"FA100", # Can't activate since we exclude UP007 for now "FA100", # Can't activate since we exclude UP007 for now
"ISC001", # Messes with the formatter "ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"PLR09", # Too many something (arg, statements, etc) "PLR09", # Too many something (arg, statements, etc)
"TC001", # Doesn't play well with Pydantic "TC001", # Doesn't play well with Pydantic
"TC002", # Doesn't play well with Pydantic "TC002", # Doesn't play well with Pydantic
@@ -98,7 +99,6 @@ ignore = [
"DTZ", "DTZ",
"FBT", "FBT",
"FIX", "FIX",
"PERF",
"PGH", "PGH",
"PLC", "PLC",
"PLE", "PLE",

View File

@@ -1014,13 +1014,12 @@ def test_chat_prompt_template_variable_names() -> None:
prompt.get_input_schema() prompt.get_input_schema()
if record: if record:
error_msg = [] error_msg = [
for warning in record: f"Warning type: {warning.category.__name__}, "
error_msg.append( f"Warning message: {warning.message}, "
f"Warning type: {warning.category.__name__}, " f"Warning location: {warning.filename}:{warning.lineno}"
f"Warning message: {warning.message}, " for warning in record
f"Warning location: {warning.filename}:{warning.lineno}" ]
)
msg = "\n".join(error_msg) msg = "\n".join(error_msg)
else: else:
msg = "" msg = ""

View File

@@ -3476,9 +3476,7 @@ def test_deep_stream() -> None:
stream = chain.stream({"question": "What up"}) stream = chain.stream({"question": "What up"})
chunks = [] chunks = list(stream)
for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish") assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish" assert "".join(chunks) == "foo-lish"
@@ -3502,9 +3500,7 @@ def test_deep_stream_assign() -> None:
stream = chain.stream({"question": "What up"}) stream = chain.stream({"question": "What up"})
chunks = [] chunks = list(stream)
for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish") assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"} assert add(chunks) == {"str": "foo-lish"}
@@ -3602,9 +3598,7 @@ async def test_deep_astream() -> None:
stream = chain.astream({"question": "What up"}) stream = chain.astream({"question": "What up"})
chunks = [] chunks = [chunk async for chunk in stream]
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish") assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish" assert "".join(chunks) == "foo-lish"
@@ -3628,9 +3622,7 @@ async def test_deep_astream_assign() -> None:
stream = chain.astream({"question": "What up"}) stream = chain.astream({"question": "What up"})
chunks = [] chunks = [chunk async for chunk in stream]
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish") assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"} assert add(chunks) == {"str": "foo-lish"}
@@ -3726,9 +3718,7 @@ def test_runnable_sequence_transform() -> None:
stream = chain.transform(llm.stream("Hi there!")) stream = chain.transform(llm.stream("Hi there!"))
chunks = [] chunks = list(stream)
for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish") assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish" assert "".join(chunks) == "foo-lish"
@@ -3741,9 +3731,7 @@ async def test_runnable_sequence_atransform() -> None:
stream = chain.atransform(llm.astream("Hi there!")) stream = chain.atransform(llm.astream("Hi there!"))
chunks = [] chunks = [chunk async for chunk in stream]
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish") assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish" assert "".join(chunks) == "foo-lish"

View File

@@ -1493,12 +1493,12 @@ async def test_chain_ordering() -> None:
events = [] events = []
for _ in range(10): try:
try: for _ in range(10):
next_chunk = await iterable.__anext__() next_chunk = await iterable.__anext__()
events.append(next_chunk) events.append(next_chunk)
except Exception: except Exception:
break pass
events = _with_nulled_run_id(events) events = _with_nulled_run_id(events)
for event in events: for event in events:
@@ -1610,12 +1610,12 @@ async def test_event_stream_with_retry() -> None:
events = [] events = []
for _ in range(10): try:
try: for _ in range(10):
next_chunk = await iterable.__anext__() next_chunk = await iterable.__anext__()
events.append(next_chunk) events.append(next_chunk)
except Exception: except Exception:
break pass
events = _with_nulled_run_id(events) events = _with_nulled_run_id(events)
for event in events: for event in events:

View File

@@ -1453,12 +1453,12 @@ async def test_chain_ordering() -> None:
events = [] events = []
for _ in range(10): try:
try: for _ in range(10):
next_chunk = await iterable.__anext__() next_chunk = await iterable.__anext__()
events.append(next_chunk) events.append(next_chunk)
except Exception: except Exception:
break pass
events = _with_nulled_run_id(events) events = _with_nulled_run_id(events)
for event in events: for event in events:
@@ -1570,12 +1570,12 @@ async def test_event_stream_with_retry() -> None:
events = [] events = []
for _ in range(10): try:
try: for _ in range(10):
next_chunk = await iterable.__anext__() next_chunk = await iterable.__anext__()
events.append(next_chunk) events.append(next_chunk)
except Exception: except Exception:
break pass
events = _with_nulled_run_id(events) events = _with_nulled_run_id(events)
for event in events: for event in events: