mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 13:07:58 +00:00
core: Add ruff rules PLR (#30696)
Add ruff rules [PLR](https://docs.astral.sh/ruff/rules/#refactor-plr) Except PLR09xxx and PLR2004. Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
committed by
GitHub
parent
68361f9c2d
commit
4cc7bc6c93
@@ -53,8 +53,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
if "name" in kwargs:
|
||||
name = kwargs["name"]
|
||||
else:
|
||||
if serialized:
|
||||
elif serialized:
|
||||
name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
|
||||
else:
|
||||
name = "<unknown>"
|
||||
|
@@ -364,17 +364,14 @@ async def _ahandle_event_for_handler(
|
||||
event = getattr(handler, event_name)
|
||||
if asyncio.iscoroutinefunction(event):
|
||||
await event(*args, **kwargs)
|
||||
else:
|
||||
if handler.run_inline:
|
||||
elif handler.run_inline:
|
||||
event(*args, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
cast(
|
||||
"Callable",
|
||||
functools.partial(
|
||||
copy_context().run, event, *args, **kwargs
|
||||
),
|
||||
functools.partial(copy_context().run, event, *args, **kwargs),
|
||||
),
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
@@ -2426,8 +2423,7 @@ def _configure(
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
callback_manager.add_handler(var_handler, inheritable)
|
||||
else:
|
||||
if not any(
|
||||
elif not any(
|
||||
isinstance(handler, handler_class)
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
|
@@ -37,8 +37,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
if "name" in kwargs:
|
||||
name = kwargs["name"]
|
||||
else:
|
||||
if serialized:
|
||||
elif serialized:
|
||||
name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
|
||||
else:
|
||||
name = "<unknown>"
|
||||
|
@@ -316,7 +316,7 @@ def index(
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if (cleanup == "incremental" or cleanup == "scoped_full") and source_id_key is None:
|
||||
if (cleanup in {"incremental", "scoped_full"}) and source_id_key is None:
|
||||
msg = (
|
||||
"Source id key is required when cleanup mode is incremental or scoped_full."
|
||||
)
|
||||
@@ -379,7 +379,7 @@ def index(
|
||||
source_id_assigner(doc) for doc in hashed_docs
|
||||
]
|
||||
|
||||
if cleanup == "incremental" or cleanup == "scoped_full":
|
||||
if cleanup in {"incremental", "scoped_full"}:
|
||||
# source ids are required.
|
||||
for source_id, hashed_doc in zip(source_ids, hashed_docs):
|
||||
if source_id is None:
|
||||
@@ -622,7 +622,7 @@ async def aindex(
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if (cleanup == "incremental" or cleanup == "scoped_full") and source_id_key is None:
|
||||
if (cleanup in {"incremental", "scoped_full"}) and source_id_key is None:
|
||||
msg = (
|
||||
"Source id key is required when cleanup mode is incremental or scoped_full."
|
||||
)
|
||||
@@ -667,8 +667,7 @@ async def aindex(
|
||||
# In such a case, we use the load method and convert it to an async
|
||||
# iterator.
|
||||
async_doc_iterator = _to_async_iterator(docs_source.load())
|
||||
else:
|
||||
if hasattr(docs_source, "__aiter__"):
|
||||
elif hasattr(docs_source, "__aiter__"):
|
||||
async_doc_iterator = docs_source # type: ignore[assignment]
|
||||
else:
|
||||
async_doc_iterator = _to_async_iterator(docs_source)
|
||||
@@ -694,7 +693,7 @@ async def aindex(
|
||||
source_id_assigner(doc) for doc in hashed_docs
|
||||
]
|
||||
|
||||
if cleanup == "incremental" or cleanup == "scoped_full":
|
||||
if cleanup in {"incremental", "scoped_full"}:
|
||||
# If the cleanup mode is incremental, source ids are required.
|
||||
for source_id, hashed_doc in zip(source_ids, hashed_docs):
|
||||
if source_id is None:
|
||||
|
@@ -955,8 +955,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
)
|
||||
chunks.append(chunk)
|
||||
result = generate_from_stream(iter(chunks))
|
||||
else:
|
||||
if inspect.signature(self._generate).parameters.get("run_manager"):
|
||||
elif inspect.signature(self._generate).parameters.get("run_manager"):
|
||||
result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
@@ -1028,8 +1027,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
)
|
||||
chunks.append(chunk)
|
||||
result = generate_from_stream(iter(chunks))
|
||||
else:
|
||||
if inspect.signature(self._agenerate).parameters.get("run_manager"):
|
||||
elif inspect.signature(self._agenerate).parameters.get("run_manager"):
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
|
@@ -170,10 +170,9 @@ def merge_content(
|
||||
# If both are lists
|
||||
merged = merge_lists(cast("list", merged), content) # type: ignore
|
||||
# If the first content is a list, and the second content is a string
|
||||
else:
|
||||
# If the last element of the first content is a string
|
||||
# Add the second content to the last element
|
||||
if merged and isinstance(merged[-1], str):
|
||||
elif merged and isinstance(merged[-1], str):
|
||||
merged[-1] += content
|
||||
# If second content is an empty string, treat as a no-op
|
||||
elif content == "":
|
||||
|
@@ -1030,8 +1030,7 @@ def convert_to_openai_messages(
|
||||
content = message.content
|
||||
else:
|
||||
content = [{"type": "text", "text": message.content}]
|
||||
else:
|
||||
if text_format == "string" and all(
|
||||
elif text_format == "string" and all(
|
||||
isinstance(block, str) or block.get("type") == "text"
|
||||
for block in message.content
|
||||
):
|
||||
@@ -1075,9 +1074,7 @@ def convert_to_openai_messages(
|
||||
# Anthropic
|
||||
if source := block.get("source"):
|
||||
if missing := [
|
||||
k
|
||||
for k in ("media_type", "type", "data")
|
||||
if k not in source
|
||||
k for k in ("media_type", "type", "data") if k not in source
|
||||
]:
|
||||
err = (
|
||||
f"Unrecognized content block at "
|
||||
@@ -1192,9 +1189,7 @@ def convert_to_openai_messages(
|
||||
"text": json.dumps(block["json"]),
|
||||
}
|
||||
)
|
||||
elif (
|
||||
block.get("type") == "guard_content"
|
||||
) or "guard_content" in block:
|
||||
elif (block.get("type") == "guard_content") or "guard_content" in block:
|
||||
if (
|
||||
"guard_content" not in block
|
||||
or "text" not in block["guard_content"]
|
||||
@@ -1213,9 +1208,7 @@ def convert_to_openai_messages(
|
||||
content.append({"type": "text", "text": text})
|
||||
# VertexAI format
|
||||
elif block.get("type") == "media":
|
||||
if missing := [
|
||||
k for k in ("mime_type", "data") if k not in block
|
||||
]:
|
||||
if missing := [k for k in ("mime_type", "data") if k not in block]:
|
||||
err = (
|
||||
f"Unrecognized content block at "
|
||||
f"messages[{i}].content[{j}] has 'type': "
|
||||
@@ -1235,9 +1228,7 @@ def convert_to_openai_messages(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": (
|
||||
f"data:{block['mime_type']};base64,{b64_image}"
|
||||
)
|
||||
"url": (f"data:{block['mime_type']};base64,{b64_image}")
|
||||
},
|
||||
}
|
||||
)
|
||||
|
@@ -118,12 +118,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
else:
|
||||
if self.args_only:
|
||||
elif self.args_only:
|
||||
try:
|
||||
return json.loads(
|
||||
function_call["arguments"], strict=self.strict
|
||||
)
|
||||
return json.loads(function_call["arguments"], strict=self.strict)
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
msg = f"Could not parse function call data: {exc}"
|
||||
raise OutputParserException(msg) from exc
|
||||
|
@@ -9,10 +9,9 @@ from typing import Any, Callable, Literal
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
import langchain_core.utils.mustache as mustache
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils import get_colored_text, mustache
|
||||
from langchain_core.utils.formatting import formatter
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
|
@@ -351,9 +351,7 @@ class Graph:
|
||||
"""
|
||||
self.nodes.pop(node.id)
|
||||
self.edges = [
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.source != node.id and edge.target != node.id
|
||||
edge for edge in self.edges if node.id not in (edge.source, edge.target)
|
||||
]
|
||||
|
||||
def add_edge(
|
||||
|
@@ -401,7 +401,7 @@ def _render_mermaid_using_api(
|
||||
f"?type={file_type}&bgColor={background_color}"
|
||||
)
|
||||
response = requests.get(image_url, timeout=10)
|
||||
if response.status_code == 200:
|
||||
if response.status_code == requests.codes.ok:
|
||||
img_bytes = response.content
|
||||
if output_file_path is not None:
|
||||
Path(output_file_path).write_bytes(response.content)
|
||||
|
@@ -79,8 +79,7 @@ class RootListenersTracer(BaseTracer):
|
||||
if run.error is None:
|
||||
if self._arg_on_end is not None:
|
||||
call_func_with_variable_args(self._arg_on_end, run, self.config)
|
||||
else:
|
||||
if self._arg_on_error is not None:
|
||||
elif self._arg_on_error is not None:
|
||||
call_func_with_variable_args(self._arg_on_error, run, self.config)
|
||||
|
||||
|
||||
@@ -143,8 +142,5 @@ class AsyncRootListenersTracer(AsyncBaseTracer):
|
||||
if run.error is None:
|
||||
if self._arg_on_end is not None:
|
||||
await acall_func_with_variable_args(self._arg_on_end, run, self.config)
|
||||
else:
|
||||
if self._arg_on_error is not None:
|
||||
await acall_func_with_variable_args(
|
||||
self._arg_on_error, run, self.config
|
||||
)
|
||||
elif self._arg_on_error is not None:
|
||||
await acall_func_with_variable_args(self._arg_on_error, run, self.config)
|
||||
|
@@ -7,6 +7,8 @@ from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from langchain_core.utils.input import get_bolded_text, get_colored_text
|
||||
|
||||
MILLISECONDS_IN_SECOND = 1000
|
||||
|
||||
|
||||
def try_json_stringify(obj: Any, fallback: str) -> str:
|
||||
"""Try to stringify an object to JSON.
|
||||
@@ -36,10 +38,10 @@ def elapsed(run: Any) -> str:
|
||||
|
||||
"""
|
||||
elapsed_time = run.end_time - run.start_time
|
||||
milliseconds = elapsed_time.total_seconds() * 1000
|
||||
if milliseconds < 1000:
|
||||
return f"{milliseconds:.0f}ms"
|
||||
return f"{(milliseconds / 1000):.2f}s"
|
||||
seconds = elapsed_time.total_seconds()
|
||||
if seconds < 1:
|
||||
return f"{seconds * MILLISECONDS_IN_SECOND:.0f}ms"
|
||||
return f"{seconds:.2f}s"
|
||||
|
||||
|
||||
class FunctionCallbackHandler(BaseTracer):
|
||||
|
@@ -85,7 +85,7 @@ def extract_sub_links(
|
||||
try:
|
||||
parsed_link = urlparse(link)
|
||||
# Some may be absolute links like https://to/path
|
||||
if parsed_link.scheme == "http" or parsed_link.scheme == "https":
|
||||
if parsed_link.scheme in {"http", "https"}:
|
||||
absolute_path = link
|
||||
# Some may have omitted the protocol like //to/path
|
||||
elif link.startswith("//"):
|
||||
|
@@ -78,15 +78,14 @@ def parse_partial_json(s: str, *, strict: bool = False) -> Any:
|
||||
escaped = not escaped
|
||||
else:
|
||||
escaped = False
|
||||
else:
|
||||
if char == '"':
|
||||
elif char == '"':
|
||||
is_inside_string = True
|
||||
escaped = False
|
||||
elif char == "{":
|
||||
stack.append("}")
|
||||
elif char == "[":
|
||||
stack.append("]")
|
||||
elif char == "}" or char == "]":
|
||||
elif char in {"}", "]"}:
|
||||
if stack and stack[-1] == char:
|
||||
stack.pop()
|
||||
else:
|
||||
|
@@ -459,8 +459,7 @@ def render(
|
||||
# Then we don't need to tokenize it
|
||||
# But it does need to be a generator
|
||||
tokens: Iterator[tuple[str, str]] = (token for token in template)
|
||||
else:
|
||||
if template in g_token_cache:
|
||||
elif template in g_token_cache:
|
||||
tokens = (token for token in g_token_cache[template])
|
||||
else:
|
||||
# Otherwise make a generator
|
||||
|
@@ -103,7 +103,7 @@ ignore = [
|
||||
"FBT001",
|
||||
"FBT002",
|
||||
"PGH003",
|
||||
"PLR",
|
||||
"PLR2004",
|
||||
"RUF",
|
||||
"SLF",
|
||||
]
|
||||
|
@@ -112,11 +112,8 @@ def pytest_collection_modifyitems(
|
||||
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
|
||||
)
|
||||
break
|
||||
else:
|
||||
if only_extended:
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason="Skipping not an extended test.")
|
||||
)
|
||||
elif only_extended:
|
||||
item.add_marker(pytest.mark.skip(reason="Skipping not an extended test."))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@@ -1,4 +1,4 @@
|
||||
import langchain_core.tracers.schemas as schemas
|
||||
from langchain_core.tracers import schemas
|
||||
from langchain_core.tracers.schemas import __all__ as schemas_all
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user