mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00
chore(langchain): add ruff rules TC (#31921)
See https://docs.astral.sh/ruff/rules/#flake8-type-checking-tc
This commit is contained in:
parent
5ecbb5f277
commit
a2ad5aca41
@ -1185,7 +1185,7 @@ class AgentExecutor(Chain):
|
||||
to reflect the changes made in the root_validator.
|
||||
"""
|
||||
if isinstance(self.agent, Runnable):
|
||||
return cast(RunnableAgentType, self.agent)
|
||||
return cast("RunnableAgentType", self.agent)
|
||||
return self.agent
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
|
@ -73,7 +73,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
||||
other.pop().cancel()
|
||||
|
||||
# Extract the value of the first completed task
|
||||
token_or_done = cast(Union[str, Literal[True]], done.pop().result())
|
||||
token_or_done = cast("Union[str, Literal[True]]", done.pop().result())
|
||||
|
||||
# If the extracted value is the boolean True, the done event was set
|
||||
if token_or_done is True:
|
||||
|
@ -411,7 +411,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
||||
|
||||
return self.invoke(
|
||||
inputs,
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
|
||||
cast("RunnableConfig", {k: v for k, v in config.items() if v is not None}),
|
||||
return_only_outputs=return_only_outputs,
|
||||
include_run_info=include_run_info,
|
||||
)
|
||||
@ -461,7 +461,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
||||
}
|
||||
return await self.ainvoke(
|
||||
inputs,
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
|
||||
cast("RunnableConfig", {k: v for k, v in config.items() if k is not None}),
|
||||
return_only_outputs=return_only_outputs,
|
||||
include_run_info=include_run_info,
|
||||
)
|
||||
|
@ -229,7 +229,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
docs: list[Document],
|
||||
results: Sequence[Union[str, list[str], dict[str, str]]],
|
||||
) -> tuple[str, dict]:
|
||||
typed_results = cast(list[dict], results)
|
||||
typed_results = cast("list[dict]", results)
|
||||
sorted_res = sorted(
|
||||
zip(typed_results, docs),
|
||||
key=lambda x: -int(x[0][self.rank_key]),
|
||||
|
@ -145,7 +145,7 @@ class LLMChain(Chain):
|
||||
**self.llm_kwargs,
|
||||
)
|
||||
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
|
||||
cast(list, prompts),
|
||||
cast("list", prompts),
|
||||
{"callbacks": callbacks},
|
||||
)
|
||||
generations: list[list[Generation]] = []
|
||||
@ -172,7 +172,7 @@ class LLMChain(Chain):
|
||||
**self.llm_kwargs,
|
||||
)
|
||||
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
|
||||
cast(list, prompts),
|
||||
cast("list", prompts),
|
||||
{"callbacks": callbacks},
|
||||
)
|
||||
generations: list[list[Generation]] = []
|
||||
|
@ -76,11 +76,11 @@ def create_qa_with_structure_chain(
|
||||
raise ValueError(msg)
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
if hasattr(schema, "model_json_schema"):
|
||||
schema_dict = cast(dict, schema.model_json_schema())
|
||||
schema_dict = cast("dict", schema.model_json_schema())
|
||||
else:
|
||||
schema_dict = cast(dict, schema.schema())
|
||||
schema_dict = cast("dict", schema.schema())
|
||||
else:
|
||||
schema_dict = cast(dict, schema)
|
||||
schema_dict = cast("dict", schema)
|
||||
function = {
|
||||
"name": schema_dict["title"],
|
||||
"description": schema_dict["description"],
|
||||
|
@ -91,7 +91,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
|
||||
def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
|
||||
filter_directive = cast(
|
||||
Optional[FilterDirective],
|
||||
"Optional[FilterDirective]",
|
||||
get_parser().parse(raw_filter),
|
||||
)
|
||||
return fix_filter_directive(
|
||||
@ -144,7 +144,7 @@ def fix_filter_directive(
|
||||
return None
|
||||
args = [
|
||||
cast(
|
||||
FilterDirective,
|
||||
"FilterDirective",
|
||||
fix_filter_directive(
|
||||
arg,
|
||||
allowed_comparators=allowed_comparators,
|
||||
|
@ -137,7 +137,7 @@ class LLMRouterChain(RouterChain):
|
||||
|
||||
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
|
||||
return cast(
|
||||
dict[str, Any],
|
||||
"dict[str, Any]",
|
||||
self.llm_chain.prompt.output_parser.parse(prediction),
|
||||
)
|
||||
|
||||
@ -149,7 +149,7 @@ class LLMRouterChain(RouterChain):
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
return cast(
|
||||
dict[str, Any],
|
||||
"dict[str, Any]",
|
||||
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
|
||||
)
|
||||
|
||||
|
@ -322,7 +322,7 @@ def init_chat_model(
|
||||
|
||||
if not configurable_fields:
|
||||
return _init_chat_model_helper(
|
||||
cast(str, model),
|
||||
cast("str", model),
|
||||
model_provider=model_provider,
|
||||
**kwargs,
|
||||
)
|
||||
@ -632,7 +632,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel:
|
||||
"""Bind config to a Runnable, returning a new Runnable."""
|
||||
config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs))
|
||||
config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
|
||||
model_params = self._model_params(config)
|
||||
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
|
||||
remaining_config["configurable"] = {
|
||||
@ -781,7 +781,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
yield from self._model(cast(RunnableConfig, config)).batch_as_completed( # type: ignore[call-overload]
|
||||
yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
|
||||
inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
@ -811,7 +811,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
async for x in self._model(
|
||||
cast(RunnableConfig, config),
|
||||
cast("RunnableConfig", config),
|
||||
).abatch_as_completed( # type: ignore[call-overload]
|
||||
inputs,
|
||||
config=config,
|
||||
|
@ -1,12 +1,9 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.document_loaders import Blob, BlobLoader
|
||||
|
||||
from langchain._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
|
@ -80,7 +80,7 @@ def _value_serializer(value: Sequence[float]) -> bytes:
|
||||
|
||||
def _value_deserializer(serialized_value: bytes) -> list[float]:
|
||||
"""Deserialize a value."""
|
||||
return cast(list[float], json.loads(serialized_value.decode()))
|
||||
return cast("list[float]", json.loads(serialized_value.decode()))
|
||||
|
||||
|
||||
# The warning is global; track emission, so it appears only once.
|
||||
@ -192,7 +192,7 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
vectors[index] = updated_vector
|
||||
|
||||
return cast(
|
||||
list[list[float]],
|
||||
"list[list[float]]",
|
||||
vectors,
|
||||
) # Nones should have been resolved by now
|
||||
|
||||
@ -230,7 +230,7 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
vectors[index] = updated_vector
|
||||
|
||||
return cast(
|
||||
list[list[float]],
|
||||
"list[list[float]]",
|
||||
vectors,
|
||||
) # Nones should have been resolved by now
|
||||
|
||||
|
@ -301,7 +301,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
chain_input,
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return cast(dict, self.output_parser.parse(raw_output))
|
||||
return cast("dict", self.output_parser.parse(raw_output))
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
@ -326,7 +326,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
chain_input,
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return cast(dict, self.output_parser.parse(raw_output))
|
||||
return cast("dict", self.output_parser.parse(raw_output))
|
||||
|
||||
@override
|
||||
def _evaluate_agent_trajectory(
|
||||
|
@ -166,7 +166,7 @@ class JsonEqualityEvaluator(StringEvaluator):
|
||||
dict: A dictionary containing the evaluation score.
|
||||
"""
|
||||
parsed = self._parse_json(prediction)
|
||||
label = self._parse_json(cast(str, reference))
|
||||
label = self._parse_json(cast("str", reference))
|
||||
if isinstance(label, list):
|
||||
if not isinstance(parsed, list):
|
||||
return {"score": 0}
|
||||
|
@ -82,7 +82,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
||||
if len(output) == 0:
|
||||
continue
|
||||
compressed_docs.append(
|
||||
Document(page_content=cast(str, output), metadata=doc.metadata),
|
||||
Document(page_content=cast("str", output), metadata=doc.metadata),
|
||||
)
|
||||
return compressed_docs
|
||||
|
||||
|
@ -236,7 +236,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
# Enforce that retrieved docs are Documents for each list in retriever_docs
|
||||
for i in range(len(retriever_docs)):
|
||||
retriever_docs[i] = [
|
||||
Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc
|
||||
Document(page_content=cast("str", doc)) if isinstance(doc, str) else doc
|
||||
for doc in retriever_docs[i]
|
||||
]
|
||||
|
||||
|
@ -214,24 +214,24 @@ def _wrap_in_chain_factory(
|
||||
return lambda: lcf
|
||||
if callable(llm_or_chain_factory):
|
||||
if is_traceable_function(llm_or_chain_factory):
|
||||
runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
|
||||
runnable_ = as_runnable(cast("Callable", llm_or_chain_factory))
|
||||
return lambda: runnable_
|
||||
try:
|
||||
_model = llm_or_chain_factory() # type: ignore[call-arg]
|
||||
except TypeError:
|
||||
# It's an arbitrary function, wrap it in a RunnableLambda
|
||||
user_func = cast(Callable, llm_or_chain_factory)
|
||||
user_func = cast("Callable", llm_or_chain_factory)
|
||||
sig = inspect.signature(user_func)
|
||||
logger.info("Wrapping function %s as RunnableLambda.", sig)
|
||||
wrapped = RunnableLambda(user_func)
|
||||
return lambda: wrapped
|
||||
constructor = cast(Callable, llm_or_chain_factory)
|
||||
constructor = cast("Callable", llm_or_chain_factory)
|
||||
if isinstance(_model, BaseLanguageModel):
|
||||
# It's not uncommon to do an LLM constructor instead of raw LLM,
|
||||
# so we'll unpack it for the user.
|
||||
return _model
|
||||
if is_traceable_function(cast(Callable, _model)):
|
||||
runnable_ = as_runnable(cast(Callable, _model))
|
||||
if is_traceable_function(cast("Callable", _model)):
|
||||
runnable_ = as_runnable(cast("Callable", _model))
|
||||
return lambda: runnable_
|
||||
if not isinstance(_model, Runnable):
|
||||
# This is unlikely to happen - a constructor for a model function
|
||||
@ -1104,7 +1104,7 @@ class _DatasetRunContainer:
|
||||
) -> dict:
|
||||
results: dict = {}
|
||||
for example, output in zip(self.examples, batch_results):
|
||||
row_result = cast(_RowResult, all_eval_results.get(str(example.id), {}))
|
||||
row_result = cast("_RowResult", all_eval_results.get(str(example.id), {}))
|
||||
results[str(example.id)] = {
|
||||
"input": example.inputs,
|
||||
"feedback": row_result.get("feedback", []),
|
||||
@ -1131,7 +1131,7 @@ class _DatasetRunContainer:
|
||||
result = evaluator(runs_list, self.examples)
|
||||
if isinstance(result, EvaluationResult):
|
||||
result = result.dict()
|
||||
aggregate_feedback.append(cast(dict, result))
|
||||
aggregate_feedback.append(cast("dict", result))
|
||||
executor.submit(
|
||||
self.client.create_feedback,
|
||||
**result,
|
||||
@ -1148,7 +1148,7 @@ class _DatasetRunContainer:
|
||||
all_eval_results: dict = {}
|
||||
all_runs: dict = {}
|
||||
for c in self.configs:
|
||||
for callback in cast(list, c["callbacks"]):
|
||||
for callback in cast("list", c["callbacks"]):
|
||||
if isinstance(callback, EvaluatorCallbackHandler):
|
||||
eval_results = callback.logged_eval_results
|
||||
for (_, example_id), v in eval_results.items():
|
||||
@ -1171,7 +1171,7 @@ class _DatasetRunContainer:
|
||||
},
|
||||
)
|
||||
all_runs[str(callback.example_id)] = run
|
||||
return cast(dict[str, _RowResult], all_eval_results), all_runs
|
||||
return cast("dict[str, _RowResult]", all_eval_results), all_runs
|
||||
|
||||
def _collect_test_results(
|
||||
self,
|
||||
|
@ -145,8 +145,8 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"A", # flake8-builtins
|
||||
"B", # flake8-bugbear
|
||||
"ASYNC", # flake8-async
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"COM", # flake8-commas
|
||||
"D1", # pydocstyle: missing docstring
|
||||
@ -174,11 +174,12 @@ select = [
|
||||
"RSE", # flake8-rst-docstrings
|
||||
"RUF", # ruff
|
||||
"S", # flake8-bandit
|
||||
"SLOT", # flake8-slots
|
||||
"SIM", # flake8-simplify
|
||||
"SLF", # flake8-self
|
||||
"SLOT", # flake8-slots
|
||||
"T10", # flake8-debugger
|
||||
"T20", # flake8-print
|
||||
"TC", # flake8-type-checking
|
||||
"TID", # flake8-tidy-imports
|
||||
"TRY", # tryceratops
|
||||
"UP", # pyupgrade
|
||||
@ -192,6 +193,9 @@ ignore = [
|
||||
"PLR09", # Too many something (args, statements, etc)
|
||||
"S112", # Rarely useful
|
||||
"RUF012", # Doesn't play well with Pydantic
|
||||
"TC001", # Doesn't play well with Pydantic
|
||||
"TC002", # Doesn't play well with Pydantic
|
||||
"TC003", # Doesn't play well with Pydantic
|
||||
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||
|
||||
# TODO rules
|
||||
|
@ -40,7 +40,7 @@ async def test_init_chat_model_chain() -> None:
|
||||
class TestStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return cast(type[BaseChatModel], init_chat_model)
|
||||
return cast("type[BaseChatModel]", init_chat_model)
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
|
@ -571,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None:
|
||||
|
||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
||||
"""A parser."""
|
||||
return cast(Union[AgentFinish, AgentAction], next(parser_responses))
|
||||
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
|
||||
|
||||
@tool
|
||||
def find_pet(pet: str) -> str:
|
||||
@ -683,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None:
|
||||
|
||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
||||
"""A parser."""
|
||||
return cast(Union[AgentFinish, AgentAction], next(parser_responses))
|
||||
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
|
||||
|
||||
@tool
|
||||
def find_pet(pet: str) -> str:
|
||||
|
@ -82,7 +82,7 @@ def test_parse_disallowed_operator() -> None:
|
||||
|
||||
|
||||
def _test_parse_value(x: Any) -> None:
|
||||
parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})')))
|
||||
parsed = cast("Comparison", (DEFAULT_PARSER.parse(f'eq("x", {x})')))
|
||||
actual = parsed.value
|
||||
assert actual == x
|
||||
|
||||
@ -104,14 +104,14 @@ def test_parse_list_value(x: list) -> None:
|
||||
|
||||
@pytest.mark.parametrize("x", ['""', '" "', '"foo"', "'foo'"])
|
||||
def test_parse_string_value(x: str) -> None:
|
||||
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
actual = parsed.value
|
||||
assert actual == x[1:-1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", ["true", "True", "TRUE", "false", "False", "FALSE"])
|
||||
def test_parse_bool_value(x: str) -> None:
|
||||
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
actual = parsed.value
|
||||
expected = x.lower() == "true"
|
||||
assert actual == expected
|
||||
@ -127,7 +127,7 @@ def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None:
|
||||
|
||||
@pytest.mark.parametrize("x", ['"2022-10-20"', "'2022-10-20'", "2022-10-20"])
|
||||
def test_parse_date_value(x: str) -> None:
|
||||
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
actual = parsed.value["date"]
|
||||
assert actual == x.strip("'\"")
|
||||
|
||||
@ -152,7 +152,7 @@ def test_parse_date_value(x: str) -> None:
|
||||
def test_parse_datetime_value(x: str, expected: dict) -> None:
|
||||
"""Test parsing of datetime values with ISO 8601 format."""
|
||||
try:
|
||||
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("publishedAt", {x})'))
|
||||
parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("publishedAt", {x})'))
|
||||
actual = parsed.value
|
||||
assert actual == expected, f"Expected {expected}, got {actual}"
|
||||
except ValueError as e:
|
||||
|
@ -122,7 +122,7 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
# Use a regular expression to split on whitespace with a capture group
|
||||
# so that we can preserve the whitespace in the output.
|
||||
assert isinstance(content, str)
|
||||
content_chunks = cast(list[str], re.split(r"(\s)", content))
|
||||
content_chunks = cast("list[str]", re.split(r"(\s)", content))
|
||||
|
||||
for token in content_chunks:
|
||||
chunk = ChatGenerationChunk(
|
||||
@ -140,7 +140,7 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
for fkey, fvalue in value.items():
|
||||
if isinstance(fvalue, str):
|
||||
# Break function call by `,`
|
||||
fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
|
||||
fvalue_chunks = cast("list[str]", re.split(r"(,)", fvalue))
|
||||
for fvalue_chunk in fvalue_chunks:
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
|
@ -53,7 +53,7 @@ class FakeLLM(LLM):
|
||||
|
||||
@property
|
||||
def _get_next_response_in_sequence(self) -> str:
|
||||
queries = cast(Mapping, self.queries)
|
||||
queries = cast("Mapping", self.queries)
|
||||
response = queries[list(queries.keys())[self.response_index]]
|
||||
self.response_index = self.response_index + 1
|
||||
return response
|
||||
|
@ -22,7 +22,7 @@ def test_create_lc_store(file_store: LocalFileStore) -> None:
|
||||
"""Test that a docstore is created from a base store."""
|
||||
docstore = create_lc_store(file_store)
|
||||
docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
|
||||
fetched_doc = cast(Document, docstore.mget(["key1"])[0])
|
||||
fetched_doc = cast("Document", docstore.mget(["key1"])[0])
|
||||
assert fetched_doc.page_content == "hello"
|
||||
assert fetched_doc.metadata == {"key": "value"}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user