chore(langchain): add ruff rules TC (#31921)

See https://docs.astral.sh/ruff/rules/#flake8-type-checking-tc
This commit is contained in:
Christophe Bornet 2025-07-27 00:27:26 +02:00 committed by GitHub
parent 5ecbb5f277
commit a2ad5aca41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 54 additions and 53 deletions

View File

@ -1185,7 +1185,7 @@ class AgentExecutor(Chain):
to reflect the changes made in the root_validator. to reflect the changes made in the root_validator.
""" """
if isinstance(self.agent, Runnable): if isinstance(self.agent, Runnable):
return cast(RunnableAgentType, self.agent) return cast("RunnableAgentType", self.agent)
return self.agent return self.agent
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:

View File

@ -73,7 +73,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
other.pop().cancel() other.pop().cancel()
# Extract the value of the first completed task # Extract the value of the first completed task
token_or_done = cast(Union[str, Literal[True]], done.pop().result()) token_or_done = cast("Union[str, Literal[True]]", done.pop().result())
# If the extracted value is the boolean True, the done event was set # If the extracted value is the boolean True, the done event was set
if token_or_done is True: if token_or_done is True:

View File

@ -411,7 +411,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
return self.invoke( return self.invoke(
inputs, inputs,
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}), cast("RunnableConfig", {k: v for k, v in config.items() if v is not None}),
return_only_outputs=return_only_outputs, return_only_outputs=return_only_outputs,
include_run_info=include_run_info, include_run_info=include_run_info,
) )
@ -461,7 +461,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
} }
return await self.ainvoke( return await self.ainvoke(
inputs, inputs,
cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}), cast("RunnableConfig", {k: v for k, v in config.items() if k is not None}),
return_only_outputs=return_only_outputs, return_only_outputs=return_only_outputs,
include_run_info=include_run_info, include_run_info=include_run_info,
) )

View File

@ -229,7 +229,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
docs: list[Document], docs: list[Document],
results: Sequence[Union[str, list[str], dict[str, str]]], results: Sequence[Union[str, list[str], dict[str, str]]],
) -> tuple[str, dict]: ) -> tuple[str, dict]:
typed_results = cast(list[dict], results) typed_results = cast("list[dict]", results)
sorted_res = sorted( sorted_res = sorted(
zip(typed_results, docs), zip(typed_results, docs),
key=lambda x: -int(x[0][self.rank_key]), key=lambda x: -int(x[0][self.rank_key]),

View File

@ -145,7 +145,7 @@ class LLMChain(Chain):
**self.llm_kwargs, **self.llm_kwargs,
) )
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch( results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
cast(list, prompts), cast("list", prompts),
{"callbacks": callbacks}, {"callbacks": callbacks},
) )
generations: list[list[Generation]] = [] generations: list[list[Generation]] = []
@ -172,7 +172,7 @@ class LLMChain(Chain):
**self.llm_kwargs, **self.llm_kwargs,
) )
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch( results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
cast(list, prompts), cast("list", prompts),
{"callbacks": callbacks}, {"callbacks": callbacks},
) )
generations: list[list[Generation]] = [] generations: list[list[Generation]] = []

View File

@ -76,11 +76,11 @@ def create_qa_with_structure_chain(
raise ValueError(msg) raise ValueError(msg)
if isinstance(schema, type) and is_basemodel_subclass(schema): if isinstance(schema, type) and is_basemodel_subclass(schema):
if hasattr(schema, "model_json_schema"): if hasattr(schema, "model_json_schema"):
schema_dict = cast(dict, schema.model_json_schema()) schema_dict = cast("dict", schema.model_json_schema())
else: else:
schema_dict = cast(dict, schema.schema()) schema_dict = cast("dict", schema.schema())
else: else:
schema_dict = cast(dict, schema) schema_dict = cast("dict", schema)
function = { function = {
"name": schema_dict["title"], "name": schema_dict["title"],
"description": schema_dict["description"], "description": schema_dict["description"],

View File

@ -91,7 +91,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
def ast_parse(raw_filter: str) -> Optional[FilterDirective]: def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
filter_directive = cast( filter_directive = cast(
Optional[FilterDirective], "Optional[FilterDirective]",
get_parser().parse(raw_filter), get_parser().parse(raw_filter),
) )
return fix_filter_directive( return fix_filter_directive(
@ -144,7 +144,7 @@ def fix_filter_directive(
return None return None
args = [ args = [
cast( cast(
FilterDirective, "FilterDirective",
fix_filter_directive( fix_filter_directive(
arg, arg,
allowed_comparators=allowed_comparators, allowed_comparators=allowed_comparators,

View File

@ -137,7 +137,7 @@ class LLMRouterChain(RouterChain):
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs) prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
return cast( return cast(
dict[str, Any], "dict[str, Any]",
self.llm_chain.prompt.output_parser.parse(prediction), self.llm_chain.prompt.output_parser.parse(prediction),
) )
@ -149,7 +149,7 @@ class LLMRouterChain(RouterChain):
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child() callbacks = _run_manager.get_child()
return cast( return cast(
dict[str, Any], "dict[str, Any]",
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs), await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
) )

View File

@ -322,7 +322,7 @@ def init_chat_model(
if not configurable_fields: if not configurable_fields:
return _init_chat_model_helper( return _init_chat_model_helper(
cast(str, model), cast("str", model),
model_provider=model_provider, model_provider=model_provider,
**kwargs, **kwargs,
) )
@ -632,7 +632,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
**kwargs: Any, **kwargs: Any,
) -> _ConfigurableModel: ) -> _ConfigurableModel:
"""Bind config to a Runnable, returning a new Runnable.""" """Bind config to a Runnable, returning a new Runnable."""
config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs)) config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
model_params = self._model_params(config) model_params = self._model_params(config)
remaining_config = {k: v for k, v in config.items() if k != "configurable"} remaining_config = {k: v for k, v in config.items() if k != "configurable"}
remaining_config["configurable"] = { remaining_config["configurable"] = {
@ -781,7 +781,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
if config is None or isinstance(config, dict) or len(config) <= 1: if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list): if isinstance(config, list):
config = config[0] config = config[0]
yield from self._model(cast(RunnableConfig, config)).batch_as_completed( # type: ignore[call-overload] yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
inputs, inputs,
config=config, config=config,
return_exceptions=return_exceptions, return_exceptions=return_exceptions,
@ -811,7 +811,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
if isinstance(config, list): if isinstance(config, list):
config = config[0] config = config[0]
async for x in self._model( async for x in self._model(
cast(RunnableConfig, config), cast("RunnableConfig", config),
).abatch_as_completed( # type: ignore[call-overload] ).abatch_as_completed( # type: ignore[call-overload]
inputs, inputs,
config=config, config=config,

View File

@ -1,12 +1,9 @@
from typing import TYPE_CHECKING, Any from typing import Any
from langchain_core.document_loaders import Blob, BlobLoader from langchain_core.document_loaders import Blob, BlobLoader
from langchain._api import create_importer from langchain._api import create_importer
if TYPE_CHECKING:
pass
# Create a way to dynamically look up deprecated imports. # Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and # Used to consolidate logic for raising deprecation warnings and
# handling optional imports. # handling optional imports.

View File

@ -80,7 +80,7 @@ def _value_serializer(value: Sequence[float]) -> bytes:
def _value_deserializer(serialized_value: bytes) -> list[float]: def _value_deserializer(serialized_value: bytes) -> list[float]:
"""Deserialize a value.""" """Deserialize a value."""
return cast(list[float], json.loads(serialized_value.decode())) return cast("list[float]", json.loads(serialized_value.decode()))
# The warning is global; track emission, so it appears only once. # The warning is global; track emission, so it appears only once.
@ -192,7 +192,7 @@ class CacheBackedEmbeddings(Embeddings):
vectors[index] = updated_vector vectors[index] = updated_vector
return cast( return cast(
list[list[float]], "list[list[float]]",
vectors, vectors,
) # Nones should have been resolved by now ) # Nones should have been resolved by now
@ -230,7 +230,7 @@ class CacheBackedEmbeddings(Embeddings):
vectors[index] = updated_vector vectors[index] = updated_vector
return cast( return cast(
list[list[float]], "list[list[float]]",
vectors, vectors,
) # Nones should have been resolved by now ) # Nones should have been resolved by now

View File

@ -301,7 +301,7 @@ The following is the expected answer. Use this to measure correctness:
chain_input, chain_input,
callbacks=_run_manager.get_child(), callbacks=_run_manager.get_child(),
) )
return cast(dict, self.output_parser.parse(raw_output)) return cast("dict", self.output_parser.parse(raw_output))
async def _acall( async def _acall(
self, self,
@ -326,7 +326,7 @@ The following is the expected answer. Use this to measure correctness:
chain_input, chain_input,
callbacks=_run_manager.get_child(), callbacks=_run_manager.get_child(),
) )
return cast(dict, self.output_parser.parse(raw_output)) return cast("dict", self.output_parser.parse(raw_output))
@override @override
def _evaluate_agent_trajectory( def _evaluate_agent_trajectory(

View File

@ -166,7 +166,7 @@ class JsonEqualityEvaluator(StringEvaluator):
dict: A dictionary containing the evaluation score. dict: A dictionary containing the evaluation score.
""" """
parsed = self._parse_json(prediction) parsed = self._parse_json(prediction)
label = self._parse_json(cast(str, reference)) label = self._parse_json(cast("str", reference))
if isinstance(label, list): if isinstance(label, list):
if not isinstance(parsed, list): if not isinstance(parsed, list):
return {"score": 0} return {"score": 0}

View File

@ -82,7 +82,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
if len(output) == 0: if len(output) == 0:
continue continue
compressed_docs.append( compressed_docs.append(
Document(page_content=cast(str, output), metadata=doc.metadata), Document(page_content=cast("str", output), metadata=doc.metadata),
) )
return compressed_docs return compressed_docs

View File

@ -236,7 +236,7 @@ class EnsembleRetriever(BaseRetriever):
# Enforce that retrieved docs are Documents for each list in retriever_docs # Enforce that retrieved docs are Documents for each list in retriever_docs
for i in range(len(retriever_docs)): for i in range(len(retriever_docs)):
retriever_docs[i] = [ retriever_docs[i] = [
Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc Document(page_content=cast("str", doc)) if isinstance(doc, str) else doc
for doc in retriever_docs[i] for doc in retriever_docs[i]
] ]

View File

@ -214,24 +214,24 @@ def _wrap_in_chain_factory(
return lambda: lcf return lambda: lcf
if callable(llm_or_chain_factory): if callable(llm_or_chain_factory):
if is_traceable_function(llm_or_chain_factory): if is_traceable_function(llm_or_chain_factory):
runnable_ = as_runnable(cast(Callable, llm_or_chain_factory)) runnable_ = as_runnable(cast("Callable", llm_or_chain_factory))
return lambda: runnable_ return lambda: runnable_
try: try:
_model = llm_or_chain_factory() # type: ignore[call-arg] _model = llm_or_chain_factory() # type: ignore[call-arg]
except TypeError: except TypeError:
# It's an arbitrary function, wrap it in a RunnableLambda # It's an arbitrary function, wrap it in a RunnableLambda
user_func = cast(Callable, llm_or_chain_factory) user_func = cast("Callable", llm_or_chain_factory)
sig = inspect.signature(user_func) sig = inspect.signature(user_func)
logger.info("Wrapping function %s as RunnableLambda.", sig) logger.info("Wrapping function %s as RunnableLambda.", sig)
wrapped = RunnableLambda(user_func) wrapped = RunnableLambda(user_func)
return lambda: wrapped return lambda: wrapped
constructor = cast(Callable, llm_or_chain_factory) constructor = cast("Callable", llm_or_chain_factory)
if isinstance(_model, BaseLanguageModel): if isinstance(_model, BaseLanguageModel):
# It's not uncommon to do an LLM constructor instead of raw LLM, # It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user. # so we'll unpack it for the user.
return _model return _model
if is_traceable_function(cast(Callable, _model)): if is_traceable_function(cast("Callable", _model)):
runnable_ = as_runnable(cast(Callable, _model)) runnable_ = as_runnable(cast("Callable", _model))
return lambda: runnable_ return lambda: runnable_
if not isinstance(_model, Runnable): if not isinstance(_model, Runnable):
# This is unlikely to happen - a constructor for a model function # This is unlikely to happen - a constructor for a model function
@ -1104,7 +1104,7 @@ class _DatasetRunContainer:
) -> dict: ) -> dict:
results: dict = {} results: dict = {}
for example, output in zip(self.examples, batch_results): for example, output in zip(self.examples, batch_results):
row_result = cast(_RowResult, all_eval_results.get(str(example.id), {})) row_result = cast("_RowResult", all_eval_results.get(str(example.id), {}))
results[str(example.id)] = { results[str(example.id)] = {
"input": example.inputs, "input": example.inputs,
"feedback": row_result.get("feedback", []), "feedback": row_result.get("feedback", []),
@ -1131,7 +1131,7 @@ class _DatasetRunContainer:
result = evaluator(runs_list, self.examples) result = evaluator(runs_list, self.examples)
if isinstance(result, EvaluationResult): if isinstance(result, EvaluationResult):
result = result.dict() result = result.dict()
aggregate_feedback.append(cast(dict, result)) aggregate_feedback.append(cast("dict", result))
executor.submit( executor.submit(
self.client.create_feedback, self.client.create_feedback,
**result, **result,
@ -1148,7 +1148,7 @@ class _DatasetRunContainer:
all_eval_results: dict = {} all_eval_results: dict = {}
all_runs: dict = {} all_runs: dict = {}
for c in self.configs: for c in self.configs:
for callback in cast(list, c["callbacks"]): for callback in cast("list", c["callbacks"]):
if isinstance(callback, EvaluatorCallbackHandler): if isinstance(callback, EvaluatorCallbackHandler):
eval_results = callback.logged_eval_results eval_results = callback.logged_eval_results
for (_, example_id), v in eval_results.items(): for (_, example_id), v in eval_results.items():
@ -1171,7 +1171,7 @@ class _DatasetRunContainer:
}, },
) )
all_runs[str(callback.example_id)] = run all_runs[str(callback.example_id)] = run
return cast(dict[str, _RowResult], all_eval_results), all_runs return cast("dict[str, _RowResult]", all_eval_results), all_runs
def _collect_test_results( def _collect_test_results(
self, self,

View File

@ -145,8 +145,8 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
[tool.ruff.lint] [tool.ruff.lint]
select = [ select = [
"A", # flake8-builtins "A", # flake8-builtins
"B", # flake8-bugbear
"ASYNC", # flake8-async "ASYNC", # flake8-async
"B", # flake8-bugbear
"C4", # flake8-comprehensions "C4", # flake8-comprehensions
"COM", # flake8-commas "COM", # flake8-commas
"D1", # pydocstyle: missing docstring "D1", # pydocstyle: missing docstring
@ -174,11 +174,12 @@ select = [
"RSE", # flake8-rst-docstrings "RSE", # flake8-rst-docstrings
"RUF", # ruff "RUF", # ruff
"S", # flake8-bandit "S", # flake8-bandit
"SLOT", # flake8-slots
"SIM", # flake8-simplify "SIM", # flake8-simplify
"SLF", # flake8-self "SLF", # flake8-self
"SLOT", # flake8-slots
"T10", # flake8-debugger "T10", # flake8-debugger
"T20", # flake8-print "T20", # flake8-print
"TC", # flake8-type-checking
"TID", # flake8-tidy-imports "TID", # flake8-tidy-imports
"TRY", # tryceratops "TRY", # tryceratops
"UP", # pyupgrade "UP", # pyupgrade
@ -192,6 +193,9 @@ ignore = [
"PLR09", # Too many something (args, statements, etc) "PLR09", # Too many something (args, statements, etc)
"S112", # Rarely useful "S112", # Rarely useful
"RUF012", # Doesn't play well with Pydantic "RUF012", # Doesn't play well with Pydantic
"TC001", # Doesn't play well with Pydantic
"TC002", # Doesn't play well with Pydantic
"TC003", # Doesn't play well with Pydantic
"UP007", # pyupgrade: non-pep604-annotation-union "UP007", # pyupgrade: non-pep604-annotation-union
# TODO rules # TODO rules

View File

@ -40,7 +40,7 @@ async def test_init_chat_model_chain() -> None:
class TestStandard(ChatModelIntegrationTests): class TestStandard(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return cast(type[BaseChatModel], init_chat_model) return cast("type[BaseChatModel]", init_chat_model)
@property @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:

View File

@ -571,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None:
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
"""A parser.""" """A parser."""
return cast(Union[AgentFinish, AgentAction], next(parser_responses)) return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
@tool @tool
def find_pet(pet: str) -> str: def find_pet(pet: str) -> str:
@ -683,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None:
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
"""A parser.""" """A parser."""
return cast(Union[AgentFinish, AgentAction], next(parser_responses)) return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
@tool @tool
def find_pet(pet: str) -> str: def find_pet(pet: str) -> str:

View File

@ -82,7 +82,7 @@ def test_parse_disallowed_operator() -> None:
def _test_parse_value(x: Any) -> None: def _test_parse_value(x: Any) -> None:
parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})'))) parsed = cast("Comparison", (DEFAULT_PARSER.parse(f'eq("x", {x})')))
actual = parsed.value actual = parsed.value
assert actual == x assert actual == x
@ -104,14 +104,14 @@ def test_parse_list_value(x: list) -> None:
@pytest.mark.parametrize("x", ['""', '" "', '"foo"', "'foo'"]) @pytest.mark.parametrize("x", ['""', '" "', '"foo"', "'foo'"])
def test_parse_string_value(x: str) -> None: def test_parse_string_value(x: str) -> None:
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})')) parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value actual = parsed.value
assert actual == x[1:-1] assert actual == x[1:-1]
@pytest.mark.parametrize("x", ["true", "True", "TRUE", "false", "False", "FALSE"]) @pytest.mark.parametrize("x", ["true", "True", "TRUE", "false", "False", "FALSE"])
def test_parse_bool_value(x: str) -> None: def test_parse_bool_value(x: str) -> None:
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})')) parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value actual = parsed.value
expected = x.lower() == "true" expected = x.lower() == "true"
assert actual == expected assert actual == expected
@ -127,7 +127,7 @@ def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None:
@pytest.mark.parametrize("x", ['"2022-10-20"', "'2022-10-20'", "2022-10-20"]) @pytest.mark.parametrize("x", ['"2022-10-20"', "'2022-10-20'", "2022-10-20"])
def test_parse_date_value(x: str) -> None: def test_parse_date_value(x: str) -> None:
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})')) parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value["date"] actual = parsed.value["date"]
assert actual == x.strip("'\"") assert actual == x.strip("'\"")
@ -152,7 +152,7 @@ def test_parse_date_value(x: str) -> None:
def test_parse_datetime_value(x: str, expected: dict) -> None: def test_parse_datetime_value(x: str, expected: dict) -> None:
"""Test parsing of datetime values with ISO 8601 format.""" """Test parsing of datetime values with ISO 8601 format."""
try: try:
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("publishedAt", {x})')) parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("publishedAt", {x})'))
actual = parsed.value actual = parsed.value
assert actual == expected, f"Expected {expected}, got {actual}" assert actual == expected, f"Expected {expected}, got {actual}"
except ValueError as e: except ValueError as e:

View File

@ -122,7 +122,7 @@ class GenericFakeChatModel(BaseChatModel):
# Use a regular expression to split on whitespace with a capture group # Use a regular expression to split on whitespace with a capture group
# so that we can preserve the whitespace in the output. # so that we can preserve the whitespace in the output.
assert isinstance(content, str) assert isinstance(content, str)
content_chunks = cast(list[str], re.split(r"(\s)", content)) content_chunks = cast("list[str]", re.split(r"(\s)", content))
for token in content_chunks: for token in content_chunks:
chunk = ChatGenerationChunk( chunk = ChatGenerationChunk(
@ -140,7 +140,7 @@ class GenericFakeChatModel(BaseChatModel):
for fkey, fvalue in value.items(): for fkey, fvalue in value.items():
if isinstance(fvalue, str): if isinstance(fvalue, str):
# Break function call by `,` # Break function call by `,`
fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue)) fvalue_chunks = cast("list[str]", re.split(r"(,)", fvalue))
for fvalue_chunk in fvalue_chunks: for fvalue_chunk in fvalue_chunks:
chunk = ChatGenerationChunk( chunk = ChatGenerationChunk(
message=AIMessageChunk( message=AIMessageChunk(

View File

@ -53,7 +53,7 @@ class FakeLLM(LLM):
@property @property
def _get_next_response_in_sequence(self) -> str: def _get_next_response_in_sequence(self) -> str:
queries = cast(Mapping, self.queries) queries = cast("Mapping", self.queries)
response = queries[list(queries.keys())[self.response_index]] response = queries[list(queries.keys())[self.response_index]]
self.response_index = self.response_index + 1 self.response_index = self.response_index + 1
return response return response

View File

@ -22,7 +22,7 @@ def test_create_lc_store(file_store: LocalFileStore) -> None:
"""Test that a docstore is created from a base store.""" """Test that a docstore is created from a base store."""
docstore = create_lc_store(file_store) docstore = create_lc_store(file_store)
docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))]) docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
fetched_doc = cast(Document, docstore.mget(["key1"])[0]) fetched_doc = cast("Document", docstore.mget(["key1"])[0])
assert fetched_doc.page_content == "hello" assert fetched_doc.page_content == "hello"
assert fetched_doc.metadata == {"key": "value"} assert fetched_doc.metadata == {"key": "value"}