From a2ad5aca411a0b1cb8e9fb8ba15013060b4edb38 Mon Sep 17 00:00:00 2001
From: Christophe Bornet
Date: Sun, 27 Jul 2025 00:27:26 +0200
Subject: [PATCH] chore(langchain): add ruff rules TC (#31921)

Enable the flake8-type-checking (TC) rule group and apply its fixes:
quote the type expression passed to `typing.cast()` so it is not
evaluated at runtime (TC006), and drop an empty `if TYPE_CHECKING:`
block (TC005). TC001, TC002 and TC003 are left ignored because moving
imports under `TYPE_CHECKING` doesn't play well with Pydantic.

See https://docs.astral.sh/ruff/rules/#flake8-type-checking-tc
---
 libs/langchain/langchain/agents/agent.py | 2 +-
 .../langchain/callbacks/streaming_aiter.py | 2 +-
 libs/langchain/langchain/chains/base.py | 4 ++--
 .../chains/combine_documents/map_rerank.py | 2 +-
 libs/langchain/langchain/chains/llm.py | 4 ++--
 .../openai_functions/qa_with_structure.py | 6 +++---
 .../langchain/chains/query_constructor/base.py | 4 ++--
 .../langchain/chains/router/llm_router.py | 4 ++--
 libs/langchain/langchain/chat_models/base.py | 8 ++++----
 .../document_loaders/blob_loaders/schema.py | 5 +----
 libs/langchain/langchain/embeddings/cache.py | 6 +++---
 .../evaluation/agents/trajectory_eval_chain.py | 4 ++--
 .../langchain/evaluation/parsing/base.py | 2 +-
 .../document_compressors/chain_extract.py | 2 +-
 .../langchain/langchain/retrievers/ensemble.py | 2 +-
 .../langchain/smith/evaluation/runner_utils.py | 18 +++++++++---------
 libs/langchain/pyproject.toml | 8 ++++++--
 .../integration_tests/chat_models/test_base.py | 2 +-
 .../tests/unit_tests/agents/test_agent.py | 4 ++--
 .../chains/query_constructor/test_parser.py | 10 +++++-----
 .../tests/unit_tests/llms/fake_chat_model.py | 4 ++--
 .../tests/unit_tests/llms/fake_llm.py | 2 +-
 .../tests/unit_tests/storage/test_lc_store.py | 2 +-
 23 files changed, 54 insertions(+), 53 deletions(-)

diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py
index 323afb33427..dab08813249 100644
--- a/libs/langchain/langchain/agents/agent.py
+++ b/libs/langchain/langchain/agents/agent.py
@@ -1185,7 +1185,7 @@ class AgentExecutor(Chain):
         to reflect the changes made in the root_validator.
         """
         if isinstance(self.agent, Runnable):
-            return cast(RunnableAgentType, self.agent)
+            return cast("RunnableAgentType", self.agent)
         return self.agent
 
     def save(self, file_path: Union[Path, str]) -> None:
diff --git a/libs/langchain/langchain/callbacks/streaming_aiter.py b/libs/langchain/langchain/callbacks/streaming_aiter.py
index f00e4f3e094..96cf78fd83a 100644
--- a/libs/langchain/langchain/callbacks/streaming_aiter.py
+++ b/libs/langchain/langchain/callbacks/streaming_aiter.py
@@ -73,7 +73,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
                 other.pop().cancel()
 
             # Extract the value of the first completed task
-            token_or_done = cast(Union[str, Literal[True]], done.pop().result())
+            token_or_done = cast("Union[str, Literal[True]]", done.pop().result())
 
             # If the extracted value is the boolean True, the done event was set
             if token_or_done is True:
diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py
index 11d7b7fb979..6d187bc43d9 100644
--- a/libs/langchain/langchain/chains/base.py
+++ b/libs/langchain/langchain/chains/base.py
@@ -411,7 +411,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
 
         return self.invoke(
             inputs,
-            cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
+            cast("RunnableConfig", {k: v for k, v in config.items() if v is not None}),
             return_only_outputs=return_only_outputs,
             include_run_info=include_run_info,
         )
@@ -461,7 +461,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
         }
         return await self.ainvoke(
             inputs,
-            cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
+            cast("RunnableConfig", {k: v for k, v in config.items() if k is not None}),
             return_only_outputs=return_only_outputs,
             include_run_info=include_run_info,
         )
diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py
index e2ed619a355..406bbebb3bb 100644
--- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py
+++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py
@@ -229,7 +229,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         docs: list[Document],
         results: Sequence[Union[str, list[str], dict[str, str]]],
     ) -> tuple[str, dict]:
-        typed_results = cast(list[dict], results)
+        typed_results = cast("list[dict]", results)
         sorted_res = sorted(
             zip(typed_results, docs),
             key=lambda x: -int(x[0][self.rank_key]),
diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py
index ca43d7a9e84..de03197b91f 100644
--- a/libs/langchain/langchain/chains/llm.py
+++ b/libs/langchain/langchain/chains/llm.py
@@ -145,7 +145,7 @@ class LLMChain(Chain):
                 **self.llm_kwargs,
             )
         results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
-            cast(list, prompts),
+            cast("list", prompts),
             {"callbacks": callbacks},
         )
         generations: list[list[Generation]] = []
@@ -172,7 +172,7 @@ class LLMChain(Chain):
                 **self.llm_kwargs,
             )
         results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
-            cast(list, prompts),
+            cast("list", prompts),
             {"callbacks": callbacks},
         )
         generations: list[list[Generation]] = []
diff --git a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py
index 1c4f698f23a..b704eed49da 100644
--- a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py
+++ b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py
@@ -76,11 +76,11 @@ def create_qa_with_structure_chain(
         raise ValueError(msg)
     if isinstance(schema, type) and is_basemodel_subclass(schema):
         if hasattr(schema, "model_json_schema"):
-            schema_dict = cast(dict, schema.model_json_schema())
+            schema_dict = cast("dict", schema.model_json_schema())
         else:
-            schema_dict = cast(dict, schema.schema())
+            schema_dict = cast("dict", schema.schema())
     else:
-        schema_dict = cast(dict, schema)
+        schema_dict = cast("dict", schema)
     function = {
         "name": schema_dict["title"],
         "description": schema_dict["description"],
diff --git a/libs/langchain/langchain/chains/query_constructor/base.py b/libs/langchain/langchain/chains/query_constructor/base.py
index 7d27376fc13..b684953d83b 100644
--- a/libs/langchain/langchain/chains/query_constructor/base.py
+++ b/libs/langchain/langchain/chains/query_constructor/base.py
@@ -91,7 +91,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
 
         def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
             filter_directive = cast(
-                Optional[FilterDirective],
+                "Optional[FilterDirective]",
                 get_parser().parse(raw_filter),
             )
             return fix_filter_directive(
@@ -144,7 +144,7 @@ def fix_filter_directive(
         return None
     args = [
         cast(
-            FilterDirective,
+            "FilterDirective",
            fix_filter_directive(
                 arg,
                 allowed_comparators=allowed_comparators,
diff --git a/libs/langchain/langchain/chains/router/llm_router.py b/libs/langchain/langchain/chains/router/llm_router.py
index fe9a2554c0a..1abbdbe79a3 100644
--- a/libs/langchain/langchain/chains/router/llm_router.py
+++ b/libs/langchain/langchain/chains/router/llm_router.py
@@ -137,7 +137,7 @@ class LLMRouterChain(RouterChain):
 
         prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
         return cast(
-            dict[str, Any],
+            "dict[str, Any]",
             self.llm_chain.prompt.output_parser.parse(prediction),
         )
 
@@ -149,7 +149,7 @@ class LLMRouterChain(RouterChain):
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         return cast(
-            dict[str, Any],
+            "dict[str, Any]",
             await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
         )
 
diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index 01e51c5b845..dda881bf1f4 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -322,7 +322,7 @@ def init_chat_model(
 
     if not configurable_fields:
         return _init_chat_model_helper(
-            cast(str, model),
+            cast("str", model),
             model_provider=model_provider,
             **kwargs,
         )
@@ -632,7 +632,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         **kwargs: Any,
     ) -> _ConfigurableModel:
         """Bind config to a Runnable, returning a new Runnable."""
-        config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs))
+        config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
         model_params = self._model_params(config)
         remaining_config = {k: v for k, v in config.items() if k != "configurable"}
         remaining_config["configurable"] = {
@@ -781,7 +781,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         if config is None or isinstance(config, dict) or len(config) <= 1:
             if isinstance(config, list):
                 config = config[0]
-            yield from self._model(cast(RunnableConfig, config)).batch_as_completed(  # type: ignore[call-overload]
+            yield from self._model(cast("RunnableConfig", config)).batch_as_completed(  # type: ignore[call-overload]
                 inputs,
                 config=config,
                 return_exceptions=return_exceptions,
@@ -811,7 +811,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             if isinstance(config, list):
                 config = config[0]
             async for x in self._model(
-                cast(RunnableConfig, config),
+                cast("RunnableConfig", config),
             ).abatch_as_completed(  # type: ignore[call-overload]
                 inputs,
                 config=config,
diff --git a/libs/langchain/langchain/document_loaders/blob_loaders/schema.py b/libs/langchain/langchain/document_loaders/blob_loaders/schema.py
index 677b9cfd98d..3d64013f72d 100644
--- a/libs/langchain/langchain/document_loaders/blob_loaders/schema.py
+++ b/libs/langchain/langchain/document_loaders/blob_loaders/schema.py
@@ -1,12 +1,9 @@
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 from langchain_core.document_loaders import Blob, BlobLoader
 
 from langchain._api import create_importer
 
-if TYPE_CHECKING:
-    pass
-
 # Create a way to dynamically look up deprecated imports.
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
diff --git a/libs/langchain/langchain/embeddings/cache.py b/libs/langchain/langchain/embeddings/cache.py
index fb200c76847..6be687f6196 100644
--- a/libs/langchain/langchain/embeddings/cache.py
+++ b/libs/langchain/langchain/embeddings/cache.py
@@ -80,7 +80,7 @@ def _value_serializer(value: Sequence[float]) -> bytes:
 
 def _value_deserializer(serialized_value: bytes) -> list[float]:
     """Deserialize a value."""
-    return cast(list[float], json.loads(serialized_value.decode()))
+    return cast("list[float]", json.loads(serialized_value.decode()))
 
 
 # The warning is global; track emission, so it appears only once.
@@ -192,7 +192,7 @@ class CacheBackedEmbeddings(Embeddings):
             vectors[index] = updated_vector
 
         return cast(
-            list[list[float]],
+            "list[list[float]]",
             vectors,
         )  # Nones should have been resolved by now
 
@@ -230,7 +230,7 @@ class CacheBackedEmbeddings(Embeddings):
             vectors[index] = updated_vector
 
         return cast(
-            list[list[float]],
+            "list[list[float]]",
             vectors,
         )  # Nones should have been resolved by now
 
diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py
index f80ab15b973..3b043f97f2c 100644
--- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py
+++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py
@@ -301,7 +301,7 @@ The following is the expected answer. Use this to measure correctness:
             chain_input,
             callbacks=_run_manager.get_child(),
         )
-        return cast(dict, self.output_parser.parse(raw_output))
+        return cast("dict", self.output_parser.parse(raw_output))
 
     async def _acall(
         self,
@@ -326,7 +326,7 @@ The following is the expected answer. Use this to measure correctness:
             chain_input,
             callbacks=_run_manager.get_child(),
         )
-        return cast(dict, self.output_parser.parse(raw_output))
+        return cast("dict", self.output_parser.parse(raw_output))
 
     @override
     def _evaluate_agent_trajectory(
diff --git a/libs/langchain/langchain/evaluation/parsing/base.py b/libs/langchain/langchain/evaluation/parsing/base.py
index aeeda41a053..bcd73477b4f 100644
--- a/libs/langchain/langchain/evaluation/parsing/base.py
+++ b/libs/langchain/langchain/evaluation/parsing/base.py
@@ -166,7 +166,7 @@ class JsonEqualityEvaluator(StringEvaluator):
             dict: A dictionary containing the evaluation score.
         """
         parsed = self._parse_json(prediction)
-        label = self._parse_json(cast(str, reference))
+        label = self._parse_json(cast("str", reference))
         if isinstance(label, list):
             if not isinstance(parsed, list):
                 return {"score": 0}
diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py
index 14fe23cd49e..c769d88bbaf 100644
--- a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py
+++ b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py
@@ -82,7 +82,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
             if len(output) == 0:
                 continue
             compressed_docs.append(
-                Document(page_content=cast(str, output), metadata=doc.metadata),
+                Document(page_content=cast("str", output), metadata=doc.metadata),
             )
         return compressed_docs
 
diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py
index 8392680284f..6679d1cd4cc 100644
--- a/libs/langchain/langchain/retrievers/ensemble.py
+++ b/libs/langchain/langchain/retrievers/ensemble.py
@@ -236,7 +236,7 @@ class EnsembleRetriever(BaseRetriever):
         # Enforce that retrieved docs are Documents for each list in retriever_docs
         for i in range(len(retriever_docs)):
             retriever_docs[i] = [
-                Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc
+                Document(page_content=cast("str", doc)) if isinstance(doc, str) else doc
                 for doc in retriever_docs[i]
             ]
 
diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py
index 89096c01794..ffbdcbf9111 100644
--- a/libs/langchain/langchain/smith/evaluation/runner_utils.py
+++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -214,24 +214,24 @@ def _wrap_in_chain_factory(
         return lambda: lcf
     if callable(llm_or_chain_factory):
         if is_traceable_function(llm_or_chain_factory):
-            runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
+            runnable_ = as_runnable(cast("Callable", llm_or_chain_factory))
             return lambda: runnable_
         try:
             _model = llm_or_chain_factory()  # type: ignore[call-arg]
         except TypeError:
             # It's an arbitrary function, wrap it in a RunnableLambda
-            user_func = cast(Callable, llm_or_chain_factory)
+            user_func = cast("Callable", llm_or_chain_factory)
             sig = inspect.signature(user_func)
             logger.info("Wrapping function %s as RunnableLambda.", sig)
             wrapped = RunnableLambda(user_func)
             return lambda: wrapped
-        constructor = cast(Callable, llm_or_chain_factory)
+        constructor = cast("Callable", llm_or_chain_factory)
         if isinstance(_model, BaseLanguageModel):
             # It's not uncommon to do an LLM constructor instead of raw LLM,
             # so we'll unpack it for the user.
             return _model
-        if is_traceable_function(cast(Callable, _model)):
-            runnable_ = as_runnable(cast(Callable, _model))
+        if is_traceable_function(cast("Callable", _model)):
+            runnable_ = as_runnable(cast("Callable", _model))
             return lambda: runnable_
         if not isinstance(_model, Runnable):
             # This is unlikely to happen - a constructor for a model function
@@ -1104,7 +1104,7 @@ class _DatasetRunContainer:
     ) -> dict:
         results: dict = {}
         for example, output in zip(self.examples, batch_results):
-            row_result = cast(_RowResult, all_eval_results.get(str(example.id), {}))
+            row_result = cast("_RowResult", all_eval_results.get(str(example.id), {}))
             results[str(example.id)] = {
                 "input": example.inputs,
                 "feedback": row_result.get("feedback", []),
@@ -1131,7 +1131,7 @@ class _DatasetRunContainer:
                 result = evaluator(runs_list, self.examples)
                 if isinstance(result, EvaluationResult):
                     result = result.dict()
-                aggregate_feedback.append(cast(dict, result))
+                aggregate_feedback.append(cast("dict", result))
                 executor.submit(
                     self.client.create_feedback,
                     **result,
@@ -1148,7 +1148,7 @@ class _DatasetRunContainer:
         all_eval_results: dict = {}
         all_runs: dict = {}
         for c in self.configs:
-            for callback in cast(list, c["callbacks"]):
+            for callback in cast("list", c["callbacks"]):
                 if isinstance(callback, EvaluatorCallbackHandler):
                     eval_results = callback.logged_eval_results
                     for (_, example_id), v in eval_results.items():
@@ -1171,7 +1171,7 @@ class _DatasetRunContainer:
                 },
             )
             all_runs[str(callback.example_id)] = run
-        return cast(dict[str, _RowResult], all_eval_results), all_runs
+        return cast("dict[str, _RowResult]", all_eval_results), all_runs
 
     def _collect_test_results(
         self,
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index f558488669d..c48fe751533 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -145,8 +145,8 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
 [tool.ruff.lint]
 select = [
     "A",  # flake8-builtins
-    "B",  # flake8-bugbear
     "ASYNC",  # flake8-async
+    "B",  # flake8-bugbear
     "C4",  # flake8-comprehensions
     "COM",  # flake8-commas
     "D1",  # pydocstyle: missing docstring
@@ -174,11 +174,12 @@ select = [
     "RSE",  # flake8-rst-docstrings
     "RUF",  # ruff
     "S",  # flake8-bandit
-    "SLOT",  # flake8-slots
     "SIM",  # flake8-simplify
     "SLF",  # flake8-self
+    "SLOT",  # flake8-slots
     "T10",  # flake8-debugger
     "T20",  # flake8-print
+    "TC",  # flake8-type-checking
     "TID",  # flake8-tidy-imports
     "TRY",  # tryceratops
     "UP",  # pyupgrade
@@ -192,6 +193,9 @@ ignore = [
     "PLR09",  # Too many something (args, statements, etc)
     "S112",  # Rarely useful
     "RUF012",  # Doesn't play well with Pydantic
+    "TC001",  # Doesn't play well with Pydantic
+    "TC002",  # Doesn't play well with Pydantic
+    "TC003",  # Doesn't play well with Pydantic
     "UP007",  # pyupgrade: non-pep604-annotation-union
 
     # TODO rules
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_base.py b/libs/langchain/tests/integration_tests/chat_models/test_base.py
index 1c915f7f588..4d87b98dfaf 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_base.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_base.py
@@ -40,7 +40,7 @@ async def test_init_chat_model_chain() -> None:
 class TestStandard(ChatModelIntegrationTests):
     @property
     def chat_model_class(self) -> type[BaseChatModel]:
-        return cast(type[BaseChatModel], init_chat_model)
+        return cast("type[BaseChatModel]", init_chat_model)
 
     @property
     def chat_model_params(self) -> dict:
diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py
index f77125a793d..6442e8e3dd8 100644
--- a/libs/langchain/tests/unit_tests/agents/test_agent.py
+++ b/libs/langchain/tests/unit_tests/agents/test_agent.py
@@ -571,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None:
 
     def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
         """A parser."""
-        return cast(Union[AgentFinish, AgentAction], next(parser_responses))
+        return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
 
     @tool
     def find_pet(pet: str) -> str:
@@ -683,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None:
 
     def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
         """A parser."""
-        return cast(Union[AgentFinish, AgentAction], next(parser_responses))
+        return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
 
     @tool
     def find_pet(pet: str) -> str:
diff --git a/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py b/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py
index 8090c7922e0..836c759d413 100644
--- a/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py
+++ b/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py
@@ -82,7 +82,7 @@ def test_parse_disallowed_operator() -> None:
 
 
 def _test_parse_value(x: Any) -> None:
-    parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})')))
+    parsed = cast("Comparison", (DEFAULT_PARSER.parse(f'eq("x", {x})')))
     actual = parsed.value
     assert actual == x
 
@@ -104,14 +104,14 @@ def test_parse_list_value(x: list) -> None:
 
 @pytest.mark.parametrize("x", ['""', '" "', '"foo"', "'foo'"])
 def test_parse_string_value(x: str) -> None:
-    parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
+    parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
     assert actual == x[1:-1]
 
 
 @pytest.mark.parametrize("x", ["true", "True", "TRUE", "false", "False", "FALSE"])
 def test_parse_bool_value(x: str) -> None:
-    parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
+    parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
     expected = x.lower() == "true"
     assert actual == expected
@@ -127,7 +127,7 @@ def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None:
 
 @pytest.mark.parametrize("x", ['"2022-10-20"', "'2022-10-20'", "2022-10-20"])
 def test_parse_date_value(x: str) -> None:
-    parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
+    parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value["date"]
     assert actual == x.strip("'\"")
 
@@ -152,7 +152,7 @@ def test_parse_date_value(x: str) -> None:
 def test_parse_datetime_value(x: str, expected: dict) -> None:
     """Test parsing of datetime values with ISO 8601 format."""
     try:
-        parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("publishedAt", {x})'))
+        parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("publishedAt", {x})'))
         actual = parsed.value
         assert actual == expected, f"Expected {expected}, got {actual}"
     except ValueError as e:
diff --git a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py
index 677fbc44684..aa59640829d 100644
--- a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py
+++ b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py
@@ -122,7 +122,7 @@ class GenericFakeChatModel(BaseChatModel):
            # Use a regular expression to split on whitespace with a capture group
            # so that we can preserve the whitespace in the output.
            assert isinstance(content, str)
-            content_chunks = cast(list[str], re.split(r"(\s)", content))
+            content_chunks = cast("list[str]", re.split(r"(\s)", content))
 
             for token in content_chunks:
                 chunk = ChatGenerationChunk(
@@ -140,7 +140,7 @@ class GenericFakeChatModel(BaseChatModel):
                 for fkey, fvalue in value.items():
                     if isinstance(fvalue, str):
                         # Break function call by `,`
-                        fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
+                        fvalue_chunks = cast("list[str]", re.split(r"(,)", fvalue))
                         for fvalue_chunk in fvalue_chunks:
                             chunk = ChatGenerationChunk(
                                 message=AIMessageChunk(
diff --git a/libs/langchain/tests/unit_tests/llms/fake_llm.py b/libs/langchain/tests/unit_tests/llms/fake_llm.py
index a558d0faa69..61efe09cc2e 100644
--- a/libs/langchain/tests/unit_tests/llms/fake_llm.py
+++ b/libs/langchain/tests/unit_tests/llms/fake_llm.py
@@ -53,7 +53,7 @@ class FakeLLM(LLM):
 
     @property
     def _get_next_response_in_sequence(self) -> str:
-        queries = cast(Mapping, self.queries)
+        queries = cast("Mapping", self.queries)
         response = queries[list(queries.keys())[self.response_index]]
         self.response_index = self.response_index + 1
         return response
diff --git a/libs/langchain/tests/unit_tests/storage/test_lc_store.py b/libs/langchain/tests/unit_tests/storage/test_lc_store.py
index b884f55748c..06a9f977f0a 100644
--- a/libs/langchain/tests/unit_tests/storage/test_lc_store.py
+++ b/libs/langchain/tests/unit_tests/storage/test_lc_store.py
@@ -22,7 +22,7 @@ def test_create_lc_store(file_store: LocalFileStore) -> None:
     """Test that a docstore is created from a base store."""
     docstore = create_lc_store(file_store)
     docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
-    fetched_doc = cast(Document, docstore.mget(["key1"])[0])
+    fetched_doc = cast("Document", docstore.mget(["key1"])[0])
     assert fetched_doc.page_content == "hello"
     assert fetched_doc.metadata == {"key": "value"}
 
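
For reviewers unfamiliar with the flake8-type-checking rules, here is a minimal
illustrative sketch of the pattern the fixes above follow (names are hypothetical
and not taken from this patch). TC006 quotes the type expression given to
`typing.cast()`, so it stays a plain string at runtime instead of being evaluated;
that also keeps the call working when the referenced type is only imported under
`TYPE_CHECKING` (the TC001-TC003 rules, which are left ignored here because of
Pydantic):

    from __future__ import annotations

    from typing import TYPE_CHECKING, Any, cast

    if TYPE_CHECKING:
        # Import needed only by the type checker; never executed at runtime.
        from collections.abc import Mapping


    def read_config(raw: Any) -> Mapping[str, Any]:
        # TC006: the quoted type is not evaluated at runtime, so the
        # TYPE_CHECKING-only import above is sufficient.
        return cast("Mapping[str, Any]", raw)


    print(read_config({"temperature": 0}))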