chore(langchain): add ruff rules TC (#31921)

See https://docs.astral.sh/ruff/rules/#flake8-type-checking-tc
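
The TC rules flag imports that are needed only for type annotations. Two of them drive most of this diff: TC006 (runtime-cast-value) wants the type argument of `typing.cast` quoted so it is never evaluated at runtime, which is why every `cast(X, ...)` below becomes `cast("X", ...)`; TC005 (empty-type-checking-block) removes `if TYPE_CHECKING:` blocks left without imports. A minimal sketch of the quoting pattern (the module and the `Decimal` type are illustrative, not taken from this commit):

    from typing import TYPE_CHECKING, cast

    if TYPE_CHECKING:
        # Imported for static analysis only; never executed at runtime.
        from decimal import Decimal

    def as_decimal(value: object) -> "Decimal":
        # The quoted target is resolved by the type checker, not at runtime,
        # so the Decimal import above can stay typing-only.
        return cast("Decimal", value)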
Christophe Bornet 2025-07-27 00:27:26 +02:00 committed by GitHub
parent 5ecbb5f277
commit a2ad5aca41
23 changed files with 54 additions and 53 deletions

View File

@@ -1185,7 +1185,7 @@ class AgentExecutor(Chain):
         to reflect the changes made in the root_validator.
         """
         if isinstance(self.agent, Runnable):
-            return cast(RunnableAgentType, self.agent)
+            return cast("RunnableAgentType", self.agent)
         return self.agent

     def save(self, file_path: Union[Path, str]) -> None:

View File

@@ -73,7 +73,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
                 other.pop().cancel()

             # Extract the value of the first completed task
-            token_or_done = cast(Union[str, Literal[True]], done.pop().result())
+            token_or_done = cast("Union[str, Literal[True]]", done.pop().result())

             # If the extracted value is the boolean True, the done event was set
             if token_or_done is True:

View File

@@ -411,7 +411,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):

         return self.invoke(
             inputs,
-            cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
+            cast("RunnableConfig", {k: v for k, v in config.items() if v is not None}),
             return_only_outputs=return_only_outputs,
             include_run_info=include_run_info,
         )
@@ -461,7 +461,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
         }
         return await self.ainvoke(
             inputs,
-            cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
+            cast("RunnableConfig", {k: v for k, v in config.items() if k is not None}),
             return_only_outputs=return_only_outputs,
             include_run_info=include_run_info,
         )

View File

@@ -229,7 +229,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         docs: list[Document],
         results: Sequence[Union[str, list[str], dict[str, str]]],
     ) -> tuple[str, dict]:
-        typed_results = cast(list[dict], results)
+        typed_results = cast("list[dict]", results)
         sorted_res = sorted(
             zip(typed_results, docs),
             key=lambda x: -int(x[0][self.rank_key]),

View File

@@ -145,7 +145,7 @@ class LLMChain(Chain):
                 **self.llm_kwargs,
             )
         results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
-            cast(list, prompts),
+            cast("list", prompts),
             {"callbacks": callbacks},
         )
         generations: list[list[Generation]] = []
@@ -172,7 +172,7 @@ class LLMChain(Chain):
                 **self.llm_kwargs,
             )
         results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
-            cast(list, prompts),
+            cast("list", prompts),
             {"callbacks": callbacks},
         )
         generations: list[list[Generation]] = []

View File

@@ -76,11 +76,11 @@ def create_qa_with_structure_chain(
         raise ValueError(msg)
     if isinstance(schema, type) and is_basemodel_subclass(schema):
         if hasattr(schema, "model_json_schema"):
-            schema_dict = cast(dict, schema.model_json_schema())
+            schema_dict = cast("dict", schema.model_json_schema())
         else:
-            schema_dict = cast(dict, schema.schema())
+            schema_dict = cast("dict", schema.schema())
     else:
-        schema_dict = cast(dict, schema)
+        schema_dict = cast("dict", schema)
     function = {
         "name": schema_dict["title"],
         "description": schema_dict["description"],

View File

@@ -91,7 +91,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
         def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
             filter_directive = cast(
-                Optional[FilterDirective],
+                "Optional[FilterDirective]",
                 get_parser().parse(raw_filter),
             )
             return fix_filter_directive(
@@ -144,7 +144,7 @@ def fix_filter_directive(
             return None
         args = [
             cast(
-                FilterDirective,
+                "FilterDirective",
                 fix_filter_directive(
                     arg,
                     allowed_comparators=allowed_comparators,

View File

@@ -137,7 +137,7 @@ class LLMRouterChain(RouterChain):
         prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
         return cast(
-            dict[str, Any],
+            "dict[str, Any]",
             self.llm_chain.prompt.output_parser.parse(prediction),
         )
@@ -149,7 +149,7 @@ class LLMRouterChain(RouterChain):
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         return cast(
-            dict[str, Any],
+            "dict[str, Any]",
             await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
         )

View File

@@ -322,7 +322,7 @@ def init_chat_model(
     if not configurable_fields:
         return _init_chat_model_helper(
-            cast(str, model),
+            cast("str", model),
             model_provider=model_provider,
             **kwargs,
         )
@@ -632,7 +632,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         **kwargs: Any,
     ) -> _ConfigurableModel:
         """Bind config to a Runnable, returning a new Runnable."""
-        config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs))
+        config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
         model_params = self._model_params(config)
         remaining_config = {k: v for k, v in config.items() if k != "configurable"}
         remaining_config["configurable"] = {
@@ -781,7 +781,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         if config is None or isinstance(config, dict) or len(config) <= 1:
             if isinstance(config, list):
                 config = config[0]
-            yield from self._model(cast(RunnableConfig, config)).batch_as_completed(  # type: ignore[call-overload]
+            yield from self._model(cast("RunnableConfig", config)).batch_as_completed(  # type: ignore[call-overload]
                 inputs,
                 config=config,
                 return_exceptions=return_exceptions,
@@ -811,7 +811,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             if isinstance(config, list):
                 config = config[0]
             async for x in self._model(
-                cast(RunnableConfig, config),
+                cast("RunnableConfig", config),
             ).abatch_as_completed(  # type: ignore[call-overload]
                 inputs,
                 config=config,

View File

@@ -1,12 +1,9 @@
-from typing import TYPE_CHECKING, Any
+from typing import Any

 from langchain_core.document_loaders import Blob, BlobLoader

 from langchain._api import create_importer

-if TYPE_CHECKING:
-    pass

 # Create a way to dynamically look up deprecated imports.
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.

View File

@@ -80,7 +80,7 @@ def _value_serializer(value: Sequence[float]) -> bytes:

 def _value_deserializer(serialized_value: bytes) -> list[float]:
     """Deserialize a value."""
-    return cast(list[float], json.loads(serialized_value.decode()))
+    return cast("list[float]", json.loads(serialized_value.decode()))


 # The warning is global; track emission, so it appears only once.
@@ -192,7 +192,7 @@ class CacheBackedEmbeddings(Embeddings):
             vectors[index] = updated_vector
         return cast(
-            list[list[float]],
+            "list[list[float]]",
             vectors,
         )  # Nones should have been resolved by now
@@ -230,7 +230,7 @@ class CacheBackedEmbeddings(Embeddings):
             vectors[index] = updated_vector
         return cast(
-            list[list[float]],
+            "list[list[float]]",
             vectors,
         )  # Nones should have been resolved by now

View File

@@ -301,7 +301,7 @@ The following is the expected answer. Use this to measure correctness:
             chain_input,
             callbacks=_run_manager.get_child(),
         )
-        return cast(dict, self.output_parser.parse(raw_output))
+        return cast("dict", self.output_parser.parse(raw_output))

     async def _acall(
         self,
@@ -326,7 +326,7 @@ The following is the expected answer. Use this to measure correctness:
             chain_input,
             callbacks=_run_manager.get_child(),
         )
-        return cast(dict, self.output_parser.parse(raw_output))
+        return cast("dict", self.output_parser.parse(raw_output))

     @override
     def _evaluate_agent_trajectory(

View File

@@ -166,7 +166,7 @@ class JsonEqualityEvaluator(StringEvaluator):
             dict: A dictionary containing the evaluation score.
         """
         parsed = self._parse_json(prediction)
-        label = self._parse_json(cast(str, reference))
+        label = self._parse_json(cast("str", reference))
         if isinstance(label, list):
             if not isinstance(parsed, list):
                 return {"score": 0}

View File

@@ -82,7 +82,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
             if len(output) == 0:
                 continue
             compressed_docs.append(
-                Document(page_content=cast(str, output), metadata=doc.metadata),
+                Document(page_content=cast("str", output), metadata=doc.metadata),
             )
         return compressed_docs

View File

@@ -236,7 +236,7 @@ class EnsembleRetriever(BaseRetriever):
         # Enforce that retrieved docs are Documents for each list in retriever_docs
         for i in range(len(retriever_docs)):
             retriever_docs[i] = [
-                Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc
+                Document(page_content=cast("str", doc)) if isinstance(doc, str) else doc
                 for doc in retriever_docs[i]
             ]

View File

@@ -214,24 +214,24 @@ def _wrap_in_chain_factory(
         return lambda: lcf
     if callable(llm_or_chain_factory):
         if is_traceable_function(llm_or_chain_factory):
-            runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
+            runnable_ = as_runnable(cast("Callable", llm_or_chain_factory))
             return lambda: runnable_
         try:
             _model = llm_or_chain_factory()  # type: ignore[call-arg]
         except TypeError:
             # It's an arbitrary function, wrap it in a RunnableLambda
-            user_func = cast(Callable, llm_or_chain_factory)
+            user_func = cast("Callable", llm_or_chain_factory)
             sig = inspect.signature(user_func)
             logger.info("Wrapping function %s as RunnableLambda.", sig)
             wrapped = RunnableLambda(user_func)
             return lambda: wrapped
-        constructor = cast(Callable, llm_or_chain_factory)
+        constructor = cast("Callable", llm_or_chain_factory)
         if isinstance(_model, BaseLanguageModel):
             # It's not uncommon to do an LLM constructor instead of raw LLM,
             # so we'll unpack it for the user.
             return _model
-        if is_traceable_function(cast(Callable, _model)):
-            runnable_ = as_runnable(cast(Callable, _model))
+        if is_traceable_function(cast("Callable", _model)):
+            runnable_ = as_runnable(cast("Callable", _model))
             return lambda: runnable_
         if not isinstance(_model, Runnable):
             # This is unlikely to happen - a constructor for a model function
@@ -1104,7 +1104,7 @@ class _DatasetRunContainer:
     ) -> dict:
         results: dict = {}
         for example, output in zip(self.examples, batch_results):
-            row_result = cast(_RowResult, all_eval_results.get(str(example.id), {}))
+            row_result = cast("_RowResult", all_eval_results.get(str(example.id), {}))
             results[str(example.id)] = {
                 "input": example.inputs,
                 "feedback": row_result.get("feedback", []),
@@ -1131,7 +1131,7 @@ class _DatasetRunContainer:
                 result = evaluator(runs_list, self.examples)
                 if isinstance(result, EvaluationResult):
                     result = result.dict()
-                aggregate_feedback.append(cast(dict, result))
+                aggregate_feedback.append(cast("dict", result))
                 executor.submit(
                     self.client.create_feedback,
                     **result,
@@ -1148,7 +1148,7 @@ class _DatasetRunContainer:
         all_eval_results: dict = {}
         all_runs: dict = {}
         for c in self.configs:
-            for callback in cast(list, c["callbacks"]):
+            for callback in cast("list", c["callbacks"]):
                 if isinstance(callback, EvaluatorCallbackHandler):
                     eval_results = callback.logged_eval_results
                     for (_, example_id), v in eval_results.items():
@@ -1171,7 +1171,7 @@ class _DatasetRunContainer:
                     },
                 )
                 all_runs[str(callback.example_id)] = run
-        return cast(dict[str, _RowResult], all_eval_results), all_runs
+        return cast("dict[str, _RowResult]", all_eval_results), all_runs

     def _collect_test_results(
         self,

View File

@@ -145,8 +145,8 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
 [tool.ruff.lint]
 select = [
     "A",      # flake8-builtins
-    "B",      # flake8-bugbear
     "ASYNC",  # flake8-async
+    "B",      # flake8-bugbear
     "C4",     # flake8-comprehensions
     "COM",    # flake8-commas
     "D1",     # pydocstyle: missing docstring
@@ -174,11 +174,12 @@ select = [
     "RSE",    # flake8-rst-docstrings
     "RUF",    # ruff
     "S",      # flake8-bandit
-    "SLOT",   # flake8-slots
     "SIM",    # flake8-simplify
     "SLF",    # flake8-self
+    "SLOT",   # flake8-slots
     "T10",    # flake8-debugger
     "T20",    # flake8-print
+    "TC",     # flake8-type-checking
     "TID",    # flake8-tidy-imports
     "TRY",    # tryceratops
     "UP",     # pyupgrade
@@ -192,6 +193,9 @@ ignore = [
     "PLR09",   # Too many something (args, statements, etc)
     "S112",    # Rarely useful
     "RUF012",  # Doesn't play well with Pydantic
+    "TC001",   # Doesn't play well with Pydantic
+    "TC002",   # Doesn't play well with Pydantic
+    "TC003",   # Doesn't play well with Pydantic
     "UP007",   # pyupgrade: non-pep604-annotation-union

     # TODO rules
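
Note: TC001, TC002, and TC003 would move typing-only first-party, third-party, and standard-library imports under `if TYPE_CHECKING:`. They are ignored here because Pydantic resolves annotations at runtime to build validators, so those imports must stay real. A minimal illustration of the conflict (hypothetical model, not from this repo):

    from datetime import datetime  # TC003 would hoist this into TYPE_CHECKING

    from pydantic import BaseModel

    class Event(BaseModel):
        # Pydantic evaluates this annotation when the class is created,
        # so `datetime` must be importable at runtime.
        when: datetime

Quoting the cast argument (TC006) carries no such risk, since `cast` never inspects its first argument at runtime.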

View File

@@ -40,7 +40,7 @@ async def test_init_chat_model_chain() -> None:
 class TestStandard(ChatModelIntegrationTests):
     @property
     def chat_model_class(self) -> type[BaseChatModel]:
-        return cast(type[BaseChatModel], init_chat_model)
+        return cast("type[BaseChatModel]", init_chat_model)

     @property
     def chat_model_params(self) -> dict:

View File

@@ -571,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None:
     def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
         """A parser."""
-        return cast(Union[AgentFinish, AgentAction], next(parser_responses))
+        return cast("Union[AgentFinish, AgentAction]", next(parser_responses))

     @tool
     def find_pet(pet: str) -> str:
@@ -683,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None:
     def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
         """A parser."""
-        return cast(Union[AgentFinish, AgentAction], next(parser_responses))
+        return cast("Union[AgentFinish, AgentAction]", next(parser_responses))

     @tool
     def find_pet(pet: str) -> str:

View File

@@ -82,7 +82,7 @@ def test_parse_disallowed_operator() -> None:

 def _test_parse_value(x: Any) -> None:
-    parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})')))
+    parsed = cast("Comparison", (DEFAULT_PARSER.parse(f'eq("x", {x})')))
     actual = parsed.value
     assert actual == x
@@ -104,14 +104,14 @@ def test_parse_list_value(x: list) -> None:
 @pytest.mark.parametrize("x", ['""', '" "', '"foo"', "'foo'"])
 def test_parse_string_value(x: str) -> None:
-    parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
+    parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
     assert actual == x[1:-1]


 @pytest.mark.parametrize("x", ["true", "True", "TRUE", "false", "False", "FALSE"])
 def test_parse_bool_value(x: str) -> None:
-    parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
+    parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
     expected = x.lower() == "true"
     assert actual == expected
@@ -127,7 +127,7 @@ def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None:

 @pytest.mark.parametrize("x", ['"2022-10-20"', "'2022-10-20'", "2022-10-20"])
 def test_parse_date_value(x: str) -> None:
-    parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
+    parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value["date"]
     assert actual == x.strip("'\"")
@@ -152,7 +152,7 @@ def test_parse_date_value(x: str) -> None:
 def test_parse_datetime_value(x: str, expected: dict) -> None:
     """Test parsing of datetime values with ISO 8601 format."""
     try:
-        parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("publishedAt", {x})'))
+        parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("publishedAt", {x})'))
         actual = parsed.value
         assert actual == expected, f"Expected {expected}, got {actual}"
     except ValueError as e:

View File

@@ -122,7 +122,7 @@ class GenericFakeChatModel(BaseChatModel):
         # Use a regular expression to split on whitespace with a capture group
         # so that we can preserve the whitespace in the output.
         assert isinstance(content, str)
-        content_chunks = cast(list[str], re.split(r"(\s)", content))
+        content_chunks = cast("list[str]", re.split(r"(\s)", content))

         for token in content_chunks:
             chunk = ChatGenerationChunk(
@@ -140,7 +140,7 @@ class GenericFakeChatModel(BaseChatModel):
                 for fkey, fvalue in value.items():
                     if isinstance(fvalue, str):
                         # Break function call by `,`
-                        fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
+                        fvalue_chunks = cast("list[str]", re.split(r"(,)", fvalue))
                         for fvalue_chunk in fvalue_chunks:
                             chunk = ChatGenerationChunk(
                                 message=AIMessageChunk(

View File

@@ -53,7 +53,7 @@ class FakeLLM(LLM):

     @property
     def _get_next_response_in_sequence(self) -> str:
-        queries = cast(Mapping, self.queries)
+        queries = cast("Mapping", self.queries)
         response = queries[list(queries.keys())[self.response_index]]
         self.response_index = self.response_index + 1
         return response

View File

@@ -22,7 +22,7 @@ def test_create_lc_store(file_store: LocalFileStore) -> None:
     """Test that a docstore is created from a base store."""
     docstore = create_lc_store(file_store)
     docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
-    fetched_doc = cast(Document, docstore.mget(["key1"])[0])
+    fetched_doc = cast("Document", docstore.mget(["key1"])[0])
     assert fetched_doc.page_content == "hello"
     assert fetched_doc.metadata == {"key": "value"}