mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-20 11:31:58 +00:00
core: Add ruff rules PERF (#29375)
See https://docs.astral.sh/ruff/rules/#perflint-perf
This commit is contained in:
parent
8a33402016
commit
4f8ea13cea
@ -66,13 +66,7 @@ class InMemoryDocumentIndex(DocumentIndex):
|
|||||||
|
|
||||||
def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
|
def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
|
||||||
"""Get by ids."""
|
"""Get by ids."""
|
||||||
found_documents = []
|
return [self.store[id_] for id_ in ids if id_ in self.store]
|
||||||
|
|
||||||
for id_ in ids:
|
|
||||||
if id_ in self.store:
|
|
||||||
found_documents.append(self.store[id_])
|
|
||||||
|
|
||||||
return found_documents
|
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
|
@ -217,31 +217,21 @@ class AIMessage(BaseMessage):
|
|||||||
|
|
||||||
# Ensure "type" is properly set on all tool call-like dicts.
|
# Ensure "type" is properly set on all tool call-like dicts.
|
||||||
if tool_calls := values.get("tool_calls"):
|
if tool_calls := values.get("tool_calls"):
|
||||||
updated: list = []
|
values["tool_calls"] = [
|
||||||
for tc in tool_calls:
|
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
|
||||||
updated.append(
|
for tc in tool_calls
|
||||||
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
|
]
|
||||||
)
|
|
||||||
values["tool_calls"] = updated
|
|
||||||
if invalid_tool_calls := values.get("invalid_tool_calls"):
|
if invalid_tool_calls := values.get("invalid_tool_calls"):
|
||||||
updated = []
|
values["invalid_tool_calls"] = [
|
||||||
for tc in invalid_tool_calls:
|
create_invalid_tool_call(**{k: v for k, v in tc.items() if k != "type"})
|
||||||
updated.append(
|
for tc in invalid_tool_calls
|
||||||
create_invalid_tool_call(
|
]
|
||||||
**{k: v for k, v in tc.items() if k != "type"}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
values["invalid_tool_calls"] = updated
|
|
||||||
|
|
||||||
if tool_call_chunks := values.get("tool_call_chunks"):
|
if tool_call_chunks := values.get("tool_call_chunks"):
|
||||||
updated = []
|
values["tool_call_chunks"] = [
|
||||||
for tc in tool_call_chunks:
|
create_tool_call_chunk(**{k: v for k, v in tc.items() if k != "type"})
|
||||||
updated.append(
|
for tc in tool_call_chunks
|
||||||
create_tool_call_chunk(
|
]
|
||||||
**{k: v for k, v in tc.items() if k != "type"}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
values["tool_call_chunks"] = updated
|
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@ -557,11 +557,11 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""Return a list of prompts used by this Runnable."""
|
"""Return a list of prompts used by this Runnable."""
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
prompts = []
|
return [
|
||||||
for _, node in self.get_graph(config=config).nodes.items():
|
node.data
|
||||||
if isinstance(node.data, BasePromptTemplate):
|
for node in self.get_graph(config=config).nodes.values()
|
||||||
prompts.append(node.data)
|
if isinstance(node.data, BasePromptTemplate)
|
||||||
return prompts
|
]
|
||||||
|
|
||||||
def __or__(
|
def __or__(
|
||||||
self,
|
self,
|
||||||
@ -3183,9 +3183,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
**(kwargs if stepidx == 0 else {}),
|
**(kwargs if stepidx == 0 else {}),
|
||||||
)
|
)
|
||||||
# If an input failed, add it to the map
|
# If an input failed, add it to the map
|
||||||
for i, inp in zip(remaining_idxs, inputs):
|
failed_inputs_map.update(
|
||||||
if isinstance(inp, Exception):
|
{
|
||||||
failed_inputs_map[i] = inp
|
i: inp
|
||||||
|
for i, inp in zip(remaining_idxs, inputs)
|
||||||
|
if isinstance(inp, Exception)
|
||||||
|
}
|
||||||
|
)
|
||||||
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
|
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
|
||||||
# If all inputs have failed, stop processing
|
# If all inputs have failed, stop processing
|
||||||
if len(failed_inputs_map) == len(configs):
|
if len(failed_inputs_map) == len(configs):
|
||||||
@ -3314,9 +3318,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
**(kwargs if stepidx == 0 else {}),
|
**(kwargs if stepidx == 0 else {}),
|
||||||
)
|
)
|
||||||
# If an input failed, add it to the map
|
# If an input failed, add it to the map
|
||||||
for i, inp in zip(remaining_idxs, inputs):
|
failed_inputs_map.update(
|
||||||
if isinstance(inp, Exception):
|
{
|
||||||
failed_inputs_map[i] = inp
|
i: inp
|
||||||
|
for i, inp in zip(remaining_idxs, inputs)
|
||||||
|
if isinstance(inp, Exception)
|
||||||
|
}
|
||||||
|
)
|
||||||
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
|
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
|
||||||
# If all inputs have failed, stop processing
|
# If all inputs have failed, stop processing
|
||||||
if len(failed_inputs_map) == len(configs):
|
if len(failed_inputs_map) == len(configs):
|
||||||
|
@ -697,10 +697,11 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
|
|||||||
When drawing the graph, this node would be the origin.
|
When drawing the graph, this node would be the origin.
|
||||||
"""
|
"""
|
||||||
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
|
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
|
||||||
found: list[Node] = []
|
found: list[Node] = [
|
||||||
for node in graph.nodes.values():
|
node
|
||||||
if node.id not in exclude and node.id not in targets:
|
for node in graph.nodes.values()
|
||||||
found.append(node)
|
if node.id not in exclude and node.id not in targets
|
||||||
|
]
|
||||||
return found[0] if len(found) == 1 else None
|
return found[0] if len(found) == 1 else None
|
||||||
|
|
||||||
|
|
||||||
@ -712,8 +713,9 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
|
|||||||
When drawing the graph, this node would be the destination.
|
When drawing the graph, this node would be the destination.
|
||||||
"""
|
"""
|
||||||
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
|
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
|
||||||
found: list[Node] = []
|
found: list[Node] = [
|
||||||
for node in graph.nodes.values():
|
node
|
||||||
if node.id not in exclude and node.id not in sources:
|
for node in graph.nodes.values()
|
||||||
found.append(node)
|
if node.id not in exclude and node.id not in sources
|
||||||
|
]
|
||||||
return found[0] if len(found) == 1 else None
|
return found[0] if len(found) == 1 else None
|
||||||
|
@ -674,10 +674,12 @@ async def _astream_log_implementation(
|
|||||||
"value": copy.deepcopy(chunk),
|
"value": copy.deepcopy(chunk),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
for op in jsonpatch.JsonPatch.from_diff(
|
patches.extend(
|
||||||
prev_final_output, final_output, dumps=dumps
|
{**op, "path": f"/final_output{op['path']}"}
|
||||||
):
|
for op in jsonpatch.JsonPatch.from_diff(
|
||||||
patches.append({**op, "path": f"/final_output{op['path']}"})
|
prev_final_output, final_output, dumps=dumps
|
||||||
|
)
|
||||||
|
)
|
||||||
await stream.send_stream.send(RunLogPatch(*patches))
|
await stream.send_stream.send(RunLogPatch(*patches))
|
||||||
finally:
|
finally:
|
||||||
await stream.send_stream.aclose()
|
await stream.send_stream.aclose()
|
||||||
|
@ -672,21 +672,21 @@ def tool_example_to_messages(
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
messages: list[BaseMessage] = [HumanMessage(content=input)]
|
messages: list[BaseMessage] = [HumanMessage(content=input)]
|
||||||
openai_tool_calls = []
|
openai_tool_calls = [
|
||||||
for tool_call in tool_calls:
|
{
|
||||||
openai_tool_calls.append(
|
"id": str(uuid.uuid4()),
|
||||||
{
|
"type": "function",
|
||||||
"id": str(uuid.uuid4()),
|
"function": {
|
||||||
"type": "function",
|
# The name of the function right now corresponds to the name
|
||||||
"function": {
|
# of the pydantic model. This is implicit in the API right now,
|
||||||
# The name of the function right now corresponds to the name
|
# and will be improved over time.
|
||||||
# of the pydantic model. This is implicit in the API right now,
|
"name": tool_call.__class__.__name__,
|
||||||
# and will be improved over time.
|
"arguments": tool_call.model_dump_json(),
|
||||||
"name": tool_call.__class__.__name__,
|
},
|
||||||
"arguments": tool_call.model_dump_json(),
|
}
|
||||||
},
|
for tool_call in tool_calls
|
||||||
}
|
]
|
||||||
)
|
|
||||||
messages.append(
|
messages.append(
|
||||||
AIMessage(content="", additional_kwargs={"tool_calls": openai_tool_calls})
|
AIMessage(content="", additional_kwargs={"tool_calls": openai_tool_calls})
|
||||||
)
|
)
|
||||||
|
@ -83,6 +83,7 @@ ignore = [
|
|||||||
"COM812", # Messes with the formatter
|
"COM812", # Messes with the formatter
|
||||||
"FA100", # Can't activate since we exclude UP007 for now
|
"FA100", # Can't activate since we exclude UP007 for now
|
||||||
"ISC001", # Messes with the formatter
|
"ISC001", # Messes with the formatter
|
||||||
|
"PERF203", # Rarely useful
|
||||||
"PLR09", # Too many something (arg, statements, etc)
|
"PLR09", # Too many something (arg, statements, etc)
|
||||||
"TC001", # Doesn't play well with Pydantic
|
"TC001", # Doesn't play well with Pydantic
|
||||||
"TC002", # Doesn't play well with Pydantic
|
"TC002", # Doesn't play well with Pydantic
|
||||||
@ -98,7 +99,6 @@ ignore = [
|
|||||||
"DTZ",
|
"DTZ",
|
||||||
"FBT",
|
"FBT",
|
||||||
"FIX",
|
"FIX",
|
||||||
"PERF",
|
|
||||||
"PGH",
|
"PGH",
|
||||||
"PLC",
|
"PLC",
|
||||||
"PLE",
|
"PLE",
|
||||||
|
@ -1014,13 +1014,12 @@ def test_chat_prompt_template_variable_names() -> None:
|
|||||||
prompt.get_input_schema()
|
prompt.get_input_schema()
|
||||||
|
|
||||||
if record:
|
if record:
|
||||||
error_msg = []
|
error_msg = [
|
||||||
for warning in record:
|
f"Warning type: {warning.category.__name__}, "
|
||||||
error_msg.append(
|
f"Warning message: {warning.message}, "
|
||||||
f"Warning type: {warning.category.__name__}, "
|
f"Warning location: {warning.filename}:{warning.lineno}"
|
||||||
f"Warning message: {warning.message}, "
|
for warning in record
|
||||||
f"Warning location: {warning.filename}:{warning.lineno}"
|
]
|
||||||
)
|
|
||||||
msg = "\n".join(error_msg)
|
msg = "\n".join(error_msg)
|
||||||
else:
|
else:
|
||||||
msg = ""
|
msg = ""
|
||||||
|
@ -3476,9 +3476,7 @@ def test_deep_stream() -> None:
|
|||||||
|
|
||||||
stream = chain.stream({"question": "What up"})
|
stream = chain.stream({"question": "What up"})
|
||||||
|
|
||||||
chunks = []
|
chunks = list(stream)
|
||||||
for chunk in stream:
|
|
||||||
chunks.append(chunk)
|
|
||||||
|
|
||||||
assert len(chunks) == len("foo-lish")
|
assert len(chunks) == len("foo-lish")
|
||||||
assert "".join(chunks) == "foo-lish"
|
assert "".join(chunks) == "foo-lish"
|
||||||
@ -3502,9 +3500,7 @@ def test_deep_stream_assign() -> None:
|
|||||||
|
|
||||||
stream = chain.stream({"question": "What up"})
|
stream = chain.stream({"question": "What up"})
|
||||||
|
|
||||||
chunks = []
|
chunks = list(stream)
|
||||||
for chunk in stream:
|
|
||||||
chunks.append(chunk)
|
|
||||||
|
|
||||||
assert len(chunks) == len("foo-lish")
|
assert len(chunks) == len("foo-lish")
|
||||||
assert add(chunks) == {"str": "foo-lish"}
|
assert add(chunks) == {"str": "foo-lish"}
|
||||||
@ -3602,9 +3598,7 @@ async def test_deep_astream() -> None:
|
|||||||
|
|
||||||
stream = chain.astream({"question": "What up"})
|
stream = chain.astream({"question": "What up"})
|
||||||
|
|
||||||
chunks = []
|
chunks = [chunk async for chunk in stream]
|
||||||
async for chunk in stream:
|
|
||||||
chunks.append(chunk)
|
|
||||||
|
|
||||||
assert len(chunks) == len("foo-lish")
|
assert len(chunks) == len("foo-lish")
|
||||||
assert "".join(chunks) == "foo-lish"
|
assert "".join(chunks) == "foo-lish"
|
||||||
@ -3628,9 +3622,7 @@ async def test_deep_astream_assign() -> None:
|
|||||||
|
|
||||||
stream = chain.astream({"question": "What up"})
|
stream = chain.astream({"question": "What up"})
|
||||||
|
|
||||||
chunks = []
|
chunks = [chunk async for chunk in stream]
|
||||||
async for chunk in stream:
|
|
||||||
chunks.append(chunk)
|
|
||||||
|
|
||||||
assert len(chunks) == len("foo-lish")
|
assert len(chunks) == len("foo-lish")
|
||||||
assert add(chunks) == {"str": "foo-lish"}
|
assert add(chunks) == {"str": "foo-lish"}
|
||||||
@ -3726,9 +3718,7 @@ def test_runnable_sequence_transform() -> None:
|
|||||||
|
|
||||||
stream = chain.transform(llm.stream("Hi there!"))
|
stream = chain.transform(llm.stream("Hi there!"))
|
||||||
|
|
||||||
chunks = []
|
chunks = list(stream)
|
||||||
for chunk in stream:
|
|
||||||
chunks.append(chunk)
|
|
||||||
|
|
||||||
assert len(chunks) == len("foo-lish")
|
assert len(chunks) == len("foo-lish")
|
||||||
assert "".join(chunks) == "foo-lish"
|
assert "".join(chunks) == "foo-lish"
|
||||||
@ -3741,9 +3731,7 @@ async def test_runnable_sequence_atransform() -> None:
|
|||||||
|
|
||||||
stream = chain.atransform(llm.astream("Hi there!"))
|
stream = chain.atransform(llm.astream("Hi there!"))
|
||||||
|
|
||||||
chunks = []
|
chunks = [chunk async for chunk in stream]
|
||||||
async for chunk in stream:
|
|
||||||
chunks.append(chunk)
|
|
||||||
|
|
||||||
assert len(chunks) == len("foo-lish")
|
assert len(chunks) == len("foo-lish")
|
||||||
assert "".join(chunks) == "foo-lish"
|
assert "".join(chunks) == "foo-lish"
|
||||||
|
@ -1493,12 +1493,12 @@ async def test_chain_ordering() -> None:
|
|||||||
|
|
||||||
events = []
|
events = []
|
||||||
|
|
||||||
for _ in range(10):
|
try:
|
||||||
try:
|
for _ in range(10):
|
||||||
next_chunk = await iterable.__anext__()
|
next_chunk = await iterable.__anext__()
|
||||||
events.append(next_chunk)
|
events.append(next_chunk)
|
||||||
except Exception:
|
except Exception:
|
||||||
break
|
pass
|
||||||
|
|
||||||
events = _with_nulled_run_id(events)
|
events = _with_nulled_run_id(events)
|
||||||
for event in events:
|
for event in events:
|
||||||
@ -1610,12 +1610,12 @@ async def test_event_stream_with_retry() -> None:
|
|||||||
|
|
||||||
events = []
|
events = []
|
||||||
|
|
||||||
for _ in range(10):
|
try:
|
||||||
try:
|
for _ in range(10):
|
||||||
next_chunk = await iterable.__anext__()
|
next_chunk = await iterable.__anext__()
|
||||||
events.append(next_chunk)
|
events.append(next_chunk)
|
||||||
except Exception:
|
except Exception:
|
||||||
break
|
pass
|
||||||
|
|
||||||
events = _with_nulled_run_id(events)
|
events = _with_nulled_run_id(events)
|
||||||
for event in events:
|
for event in events:
|
||||||
|
@ -1453,12 +1453,12 @@ async def test_chain_ordering() -> None:
|
|||||||
|
|
||||||
events = []
|
events = []
|
||||||
|
|
||||||
for _ in range(10):
|
try:
|
||||||
try:
|
for _ in range(10):
|
||||||
next_chunk = await iterable.__anext__()
|
next_chunk = await iterable.__anext__()
|
||||||
events.append(next_chunk)
|
events.append(next_chunk)
|
||||||
except Exception:
|
except Exception:
|
||||||
break
|
pass
|
||||||
|
|
||||||
events = _with_nulled_run_id(events)
|
events = _with_nulled_run_id(events)
|
||||||
for event in events:
|
for event in events:
|
||||||
@ -1570,12 +1570,12 @@ async def test_event_stream_with_retry() -> None:
|
|||||||
|
|
||||||
events = []
|
events = []
|
||||||
|
|
||||||
for _ in range(10):
|
try:
|
||||||
try:
|
for _ in range(10):
|
||||||
next_chunk = await iterable.__anext__()
|
next_chunk = await iterable.__anext__()
|
||||||
events.append(next_chunk)
|
events.append(next_chunk)
|
||||||
except Exception:
|
except Exception:
|
||||||
break
|
pass
|
||||||
|
|
||||||
events = _with_nulled_run_id(events)
|
events = _with_nulled_run_id(events)
|
||||||
for event in events:
|
for event in events:
|
||||||
|
Loading…
Reference in New Issue
Block a user