chore(langchain): add mypy pydantic plugin (#32610)

This commit is contained in:
Christophe Bornet 2025-08-19 22:59:59 +02:00 committed by GitHub
parent 73a7de63aa
commit f896bcdb1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 101 additions and 100 deletions

View File

@ -277,7 +277,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
extra_prompt_messages=extra_prompt_messages, extra_prompt_messages=extra_prompt_messages,
system_message=system_message_, system_message=system_message_,
) )
return cls( # type: ignore[call-arg] return cls(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
tools=tools, tools=tools,

View File

@ -328,7 +328,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
extra_prompt_messages=extra_prompt_messages, extra_prompt_messages=extra_prompt_messages,
system_message=system_message_, system_message=system_message_,
) )
return cls( # type: ignore[call-arg] return cls(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
tools=tools, tools=tools,

View File

@ -56,7 +56,7 @@ def _load_question_to_checked_assertions_chain(
revised_answer_chain, revised_answer_chain,
] ]
return SequentialChain( return SequentialChain(
chains=chains, # type: ignore[arg-type] chains=chains,
input_variables=["question"], input_variables=["question"],
output_variables=["revised_statement"], output_variables=["revised_statement"],
verbose=True, verbose=True,

View File

@ -169,7 +169,7 @@ def _load_map_reduce_documents_chain(
return MapReduceDocumentsChain( return MapReduceDocumentsChain(
llm_chain=llm_chain, llm_chain=llm_chain,
reduce_documents_chain=reduce_documents_chain, # type: ignore[arg-type] reduce_documents_chain=reduce_documents_chain,
**config, **config,
) )
@ -293,10 +293,10 @@ def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path")) revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path"))
return LLMCheckerChain( return LLMCheckerChain(
llm=llm, llm=llm,
create_draft_answer_prompt=create_draft_answer_prompt, # type: ignore[arg-type] create_draft_answer_prompt=create_draft_answer_prompt,
list_assertions_prompt=list_assertions_prompt, # type: ignore[arg-type] list_assertions_prompt=list_assertions_prompt,
check_assertions_prompt=check_assertions_prompt, # type: ignore[arg-type] check_assertions_prompt=check_assertions_prompt,
revised_answer_prompt=revised_answer_prompt, # type: ignore[arg-type] revised_answer_prompt=revised_answer_prompt,
**config, **config,
) )
@ -325,7 +325,7 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
elif "prompt_path" in config: elif "prompt_path" in config:
prompt = load_prompt(config.pop("prompt_path")) prompt = load_prompt(config.pop("prompt_path"))
if llm_chain: if llm_chain:
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type] return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config)
return LLMMathChain(llm=llm, prompt=prompt, **config) return LLMMathChain(llm=llm, prompt=prompt, **config)
@ -341,7 +341,7 @@ def _load_map_rerank_documents_chain(
else: else:
msg = "One of `llm_chain` or `llm_chain_path` must be present." msg = "One of `llm_chain` or `llm_chain_path` must be present."
raise ValueError(msg) raise ValueError(msg)
return MapRerankDocumentsChain(llm_chain=llm_chain, **config) # type: ignore[arg-type] return MapRerankDocumentsChain(llm_chain=llm_chain, **config)
def _load_pal_chain(config: dict, **kwargs: Any) -> Any: def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
@ -377,8 +377,8 @@ def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocuments
elif "document_prompt_path" in config: elif "document_prompt_path" in config:
document_prompt = load_prompt(config.pop("document_prompt_path")) document_prompt = load_prompt(config.pop("document_prompt_path"))
return RefineDocumentsChain( return RefineDocumentsChain(
initial_llm_chain=initial_llm_chain, # type: ignore[arg-type] initial_llm_chain=initial_llm_chain,
refine_llm_chain=refine_llm_chain, # type: ignore[arg-type] refine_llm_chain=refine_llm_chain,
document_prompt=document_prompt, document_prompt=document_prompt,
**config, **config,
) )
@ -402,7 +402,7 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha
"`combine_documents_chain_path` must be present." "`combine_documents_chain_path` must be present."
) )
raise ValueError(msg) raise ValueError(msg)
return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config) # type: ignore[arg-type] return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config)
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any: def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
@ -445,7 +445,7 @@ def _load_vector_db_qa_with_sources_chain(
) )
raise ValueError(msg) raise ValueError(msg)
return VectorDBQAWithSourcesChain( return VectorDBQAWithSourcesChain(
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type] combine_documents_chain=combine_documents_chain,
vectorstore=vectorstore, vectorstore=vectorstore,
**config, **config,
) )
@ -475,7 +475,7 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
) )
raise ValueError(msg) raise ValueError(msg)
return RetrievalQA( return RetrievalQA(
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type] combine_documents_chain=combine_documents_chain,
retriever=retriever, retriever=retriever,
**config, **config,
) )
@ -508,7 +508,7 @@ def _load_retrieval_qa_with_sources_chain(
) )
raise ValueError(msg) raise ValueError(msg)
return RetrievalQAWithSourcesChain( return RetrievalQAWithSourcesChain(
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type] combine_documents_chain=combine_documents_chain,
retriever=retriever, retriever=retriever,
**config, **config,
) )
@ -538,7 +538,7 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
) )
raise ValueError(msg) raise ValueError(msg)
return VectorDBQA( return VectorDBQA(
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type] combine_documents_chain=combine_documents_chain,
vectorstore=vectorstore, vectorstore=vectorstore,
**config, **config,
) )
@ -606,8 +606,8 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
msg = "`requests_wrapper` must be present." msg = "`requests_wrapper` must be present."
raise ValueError(msg) raise ValueError(msg)
return APIChain( return APIChain(
api_request_chain=api_request_chain, # type: ignore[arg-type] api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain, # type: ignore[arg-type] api_answer_chain=api_answer_chain,
requests_wrapper=requests_wrapper, requests_wrapper=requests_wrapper,
**config, **config,
) )

View File

@ -66,12 +66,12 @@ def _load_stuff_chain(
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> StuffDocumentsChain: ) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # type: ignore[arg-type] llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
return StuffDocumentsChain( return StuffDocumentsChain(
llm_chain=llm_chain, llm_chain=llm_chain,
document_variable_name=document_variable_name, document_variable_name=document_variable_name,
document_prompt=document_prompt, document_prompt=document_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
**kwargs, **kwargs,
) )
@ -91,14 +91,14 @@ def _load_map_reduce_chain(
token_max: int = 3000, token_max: int = 3000,
**kwargs: Any, **kwargs: Any,
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type] map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
_reduce_llm = reduce_llm or llm _reduce_llm = reduce_llm or llm
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) # type: ignore[arg-type] reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
combine_documents_chain = StuffDocumentsChain( combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain, llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name, document_variable_name=combine_document_variable_name,
document_prompt=document_prompt, document_prompt=document_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
) )
if collapse_prompt is None: if collapse_prompt is None:
collapse_chain = None collapse_chain = None
@ -114,7 +114,7 @@ def _load_map_reduce_chain(
llm_chain=LLMChain( llm_chain=LLMChain(
llm=_collapse_llm, llm=_collapse_llm,
prompt=collapse_prompt, prompt=collapse_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
), ),
document_variable_name=combine_document_variable_name, document_variable_name=combine_document_variable_name,
document_prompt=document_prompt, document_prompt=document_prompt,
@ -123,13 +123,13 @@ def _load_map_reduce_chain(
combine_documents_chain=combine_documents_chain, combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain, collapse_documents_chain=collapse_chain,
token_max=token_max, token_max=token_max,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
) )
return MapReduceDocumentsChain( return MapReduceDocumentsChain(
llm_chain=map_chain, llm_chain=map_chain,
reduce_documents_chain=reduce_documents_chain, reduce_documents_chain=reduce_documents_chain,
document_variable_name=map_reduce_document_variable_name, document_variable_name=map_reduce_document_variable_name,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
**kwargs, **kwargs,
) )
@ -146,16 +146,16 @@ def _load_refine_chain(
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> RefineDocumentsChain: ) -> RefineDocumentsChain:
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type] initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
_refine_llm = refine_llm or llm _refine_llm = refine_llm or llm
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) # type: ignore[arg-type] refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
return RefineDocumentsChain( return RefineDocumentsChain(
initial_llm_chain=initial_chain, initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain, refine_llm_chain=refine_chain,
document_variable_name=document_variable_name, document_variable_name=document_variable_name,
initial_response_name=initial_response_name, initial_response_name=initial_response_name,
document_prompt=document_prompt, document_prompt=document_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
**kwargs, **kwargs,
) )

View File

@ -80,7 +80,7 @@ def _load_stuff_chain(
llm_chain = LLMChain( llm_chain = LLMChain(
llm=llm, llm=llm,
prompt=_prompt, prompt=_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
) )
@ -88,7 +88,7 @@ def _load_stuff_chain(
return StuffDocumentsChain( return StuffDocumentsChain(
llm_chain=llm_chain, llm_chain=llm_chain,
document_variable_name=document_variable_name, document_variable_name=document_variable_name,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
**kwargs, **kwargs,
@ -120,7 +120,7 @@ def _load_map_reduce_chain(
map_chain = LLMChain( map_chain = LLMChain(
llm=llm, llm=llm,
prompt=_question_prompt, prompt=_question_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
) )
@ -128,7 +128,7 @@ def _load_map_reduce_chain(
reduce_chain = LLMChain( reduce_chain = LLMChain(
llm=_reduce_llm, llm=_reduce_llm,
prompt=_combine_prompt, prompt=_combine_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
) )
@ -136,7 +136,7 @@ def _load_map_reduce_chain(
combine_documents_chain = StuffDocumentsChain( combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain, llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name, document_variable_name=combine_document_variable_name,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
) )
@ -154,12 +154,12 @@ def _load_map_reduce_chain(
llm_chain=LLMChain( llm_chain=LLMChain(
llm=_collapse_llm, llm=_collapse_llm,
prompt=collapse_prompt, prompt=collapse_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
), ),
document_variable_name=combine_document_variable_name, document_variable_name=combine_document_variable_name,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
) )
reduce_documents_chain = ReduceDocumentsChain( reduce_documents_chain = ReduceDocumentsChain(
@ -172,7 +172,7 @@ def _load_map_reduce_chain(
llm_chain=map_chain, llm_chain=map_chain,
document_variable_name=map_reduce_document_variable_name, document_variable_name=map_reduce_document_variable_name,
reduce_documents_chain=reduce_documents_chain, reduce_documents_chain=reduce_documents_chain,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
**kwargs, **kwargs,
@ -201,7 +201,7 @@ def _load_refine_chain(
initial_chain = LLMChain( initial_chain = LLMChain(
llm=llm, llm=llm,
prompt=_question_prompt, prompt=_question_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
) )
@ -209,7 +209,7 @@ def _load_refine_chain(
refine_chain = LLMChain( refine_chain = LLMChain(
llm=_refine_llm, llm=_refine_llm,
prompt=_refine_prompt, prompt=_refine_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
) )
@ -218,7 +218,7 @@ def _load_refine_chain(
refine_llm_chain=refine_chain, refine_llm_chain=refine_chain,
document_variable_name=document_variable_name, document_variable_name=document_variable_name,
initial_response_name=initial_response_name, initial_response_name=initial_response_name,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callback_manager=callback_manager, callback_manager=callback_manager,
callbacks=callbacks, callbacks=callbacks,
**kwargs, **kwargs,

View File

@ -35,12 +35,12 @@ def _load_stuff_chain(
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> StuffDocumentsChain: ) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # type: ignore[arg-type] llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
# TODO: document prompt # TODO: document prompt
return StuffDocumentsChain( return StuffDocumentsChain(
llm_chain=llm_chain, llm_chain=llm_chain,
document_variable_name=document_variable_name, document_variable_name=document_variable_name,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
**kwargs, **kwargs,
) )
@ -64,21 +64,21 @@ def _load_map_reduce_chain(
map_chain = LLMChain( map_chain = LLMChain(
llm=llm, llm=llm,
prompt=map_prompt, prompt=map_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callbacks=callbacks, callbacks=callbacks,
) )
_reduce_llm = reduce_llm or llm _reduce_llm = reduce_llm or llm
reduce_chain = LLMChain( reduce_chain = LLMChain(
llm=_reduce_llm, llm=_reduce_llm,
prompt=combine_prompt, prompt=combine_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callbacks=callbacks, callbacks=callbacks,
) )
# TODO: document prompt # TODO: document prompt
combine_documents_chain = StuffDocumentsChain( combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain, llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name, document_variable_name=combine_document_variable_name,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callbacks=callbacks, callbacks=callbacks,
) )
if collapse_prompt is None: if collapse_prompt is None:
@ -95,7 +95,7 @@ def _load_map_reduce_chain(
llm_chain=LLMChain( llm_chain=LLMChain(
llm=_collapse_llm, llm=_collapse_llm,
prompt=collapse_prompt, prompt=collapse_prompt,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callbacks=callbacks, callbacks=callbacks,
), ),
document_variable_name=combine_document_variable_name, document_variable_name=combine_document_variable_name,
@ -104,7 +104,7 @@ def _load_map_reduce_chain(
combine_documents_chain=combine_documents_chain, combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain, collapse_documents_chain=collapse_chain,
token_max=token_max, token_max=token_max,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callbacks=callbacks, callbacks=callbacks,
collapse_max_retries=collapse_max_retries, collapse_max_retries=collapse_max_retries,
) )
@ -112,7 +112,7 @@ def _load_map_reduce_chain(
llm_chain=map_chain, llm_chain=map_chain,
reduce_documents_chain=reduce_documents_chain, reduce_documents_chain=reduce_documents_chain,
document_variable_name=map_reduce_document_variable_name, document_variable_name=map_reduce_document_variable_name,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
callbacks=callbacks, callbacks=callbacks,
**kwargs, **kwargs,
) )
@ -129,15 +129,15 @@ def _load_refine_chain(
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> RefineDocumentsChain: ) -> RefineDocumentsChain:
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type] initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
_refine_llm = refine_llm or llm _refine_llm = refine_llm or llm
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) # type: ignore[arg-type] refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
return RefineDocumentsChain( return RefineDocumentsChain(
initial_llm_chain=initial_chain, initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain, refine_llm_chain=refine_chain,
document_variable_name=document_variable_name, document_variable_name=document_variable_name,
initial_response_name=initial_response_name, initial_response_name=initial_response_name,
verbose=verbose, # type: ignore[arg-type] verbose=verbose,
**kwargs, **kwargs,
) )

View File

@ -250,7 +250,7 @@ The following is the expected answer. Use this to measure correctness:
prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
eval_chain = LLMChain(llm=llm, prompt=prompt) eval_chain = LLMChain(llm=llm, prompt=prompt)
return cls( return cls(
agent_tools=agent_tools, # type: ignore[arg-type] agent_tools=agent_tools,
eval_chain=eval_chain, eval_chain=eval_chain,
output_parser=output_parser or TrajectoryOutputParser(), output_parser=output_parser or TrajectoryOutputParser(),
**kwargs, **kwargs,

View File

@ -120,4 +120,4 @@ class LLMChainExtractor(BaseDocumentCompressor):
else: else:
parser = StrOutputParser() parser = StrOutputParser()
llm_chain = _prompt | llm | parser llm_chain = _prompt | llm | parser
return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type] return cls(llm_chain=llm_chain, get_input=_get_input)

View File

@ -407,7 +407,7 @@ class SelfQueryRetriever(BaseRetriever):
query_constructor = query_constructor.with_config( query_constructor = query_constructor.with_config(
run_name=QUERY_CONSTRUCTOR_RUN_NAME, run_name=QUERY_CONSTRUCTOR_RUN_NAME,
) )
return cls( # type: ignore[call-arg] return cls(
query_constructor=query_constructor, query_constructor=query_constructor,
vectorstore=vectorstore, vectorstore=vectorstore,
use_original_query=use_original_query, use_original_query=use_original_query,

View File

@ -4,7 +4,7 @@ from langchain_core.runnables.base import RunnableBindingBase
from langchain_core.runnables.utils import Input, Output from langchain_core.runnables.utils import Input, Output
class HubRunnable(RunnableBindingBase[Input, Output]): class HubRunnable(RunnableBindingBase[Input, Output]): # type: ignore[no-redef]
""" """
An instance of a runnable stored in the LangChain Hub. An instance of a runnable stored in the LangChain Hub.
""" """

View File

@ -20,7 +20,7 @@ class OpenAIFunction(TypedDict):
"""The parameters to the function.""" """The parameters to the function."""
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]): class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]): # type: ignore[no-redef]
"""A runnable that routes to the selected function.""" """A runnable that routes to the selected function."""
functions: Optional[list[OpenAIFunction]] functions: Optional[list[OpenAIFunction]]

View File

@ -124,6 +124,7 @@ target-version = "py39"
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"] exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
[tool.mypy] [tool.mypy]
plugins = ["pydantic.mypy"]
strict = "True" strict = "True"
strict_bytes = "True" strict_bytes = "True"
ignore_missing_imports = "True" ignore_missing_imports = "True"

View File

@ -15,7 +15,7 @@ async def test_simplea() -> None:
answer = "I know the answer!" answer = "I know the answer!"
llm = FakeListLLM(responses=[answer]) llm = FakeListLLM(responses=[answer])
retriever = SequentialRetriever(sequential_responses=[[]]) retriever = SequentialRetriever(sequential_responses=[[]])
memory = ConversationBufferMemory( # type: ignore[call-arg] memory = ConversationBufferMemory(
k=1, k=1,
output_key="answer", output_key="answer",
memory_key="chat_history", memory_key="chat_history",
@ -42,7 +42,7 @@ async def test_fixed_message_response_when_docs_founda() -> None:
retriever = SequentialRetriever( retriever = SequentialRetriever(
sequential_responses=[[Document(page_content=answer)]], sequential_responses=[[Document(page_content=answer)]],
) )
memory = ConversationBufferMemory( # type: ignore[call-arg] memory = ConversationBufferMemory(
k=1, k=1,
output_key="answer", output_key="answer",
memory_key="chat_history", memory_key="chat_history",
@ -67,7 +67,7 @@ def test_fixed_message_response_when_no_docs_found() -> None:
answer = "I know the answer!" answer = "I know the answer!"
llm = FakeListLLM(responses=[answer]) llm = FakeListLLM(responses=[answer])
retriever = SequentialRetriever(sequential_responses=[[]]) retriever = SequentialRetriever(sequential_responses=[[]])
memory = ConversationBufferMemory( # type: ignore[call-arg] memory = ConversationBufferMemory(
k=1, k=1,
output_key="answer", output_key="answer",
memory_key="chat_history", memory_key="chat_history",
@ -94,7 +94,7 @@ def test_fixed_message_response_when_docs_found() -> None:
retriever = SequentialRetriever( retriever = SequentialRetriever(
sequential_responses=[[Document(page_content=answer)]], sequential_responses=[[Document(page_content=answer)]],
) )
memory = ConversationBufferMemory( # type: ignore[call-arg] memory = ConversationBufferMemory(
k=1, k=1,
output_key="answer", output_key="answer",
memory_key="chat_history", memory_key="chat_history",

View File

@ -16,7 +16,7 @@ def dummy_transform(inputs: dict[str, str]) -> dict[str, str]:
def test_transform_chain() -> None: def test_transform_chain() -> None:
"""Test basic transform chain.""" """Test basic transform chain."""
transform_chain = TransformChain( # type: ignore[call-arg] transform_chain = TransformChain(
input_variables=["first_name", "last_name"], input_variables=["first_name", "last_name"],
output_variables=["greeting"], output_variables=["greeting"],
transform=dummy_transform, transform=dummy_transform,
@ -29,7 +29,7 @@ def test_transform_chain() -> None:
def test_transform_chain_bad_inputs() -> None: def test_transform_chain_bad_inputs() -> None:
"""Test basic transform chain.""" """Test basic transform chain."""
transform_chain = TransformChain( # type: ignore[call-arg] transform_chain = TransformChain(
input_variables=["first_name", "last_name"], input_variables=["first_name", "last_name"],
output_variables=["greeting"], output_variables=["greeting"],
transform=dummy_transform, transform=dummy_transform,

View File

@ -116,9 +116,9 @@ class TestClass(Serializable):
def test_aliases_hidden() -> None: def test_aliases_hidden() -> None:
test_class = TestClass( test_class = TestClass(
my_favorite_secret="hello", # noqa: S106 # type: ignore[call-arg] my_favorite_secret="hello", # noqa: S106
my_other_secret="world", # noqa: S106 my_other_secret="world", # noqa: S106
) # type: ignore[call-arg] )
dumped = json.loads(dumps(test_class, pretty=True)) dumped = json.loads(dumps(test_class, pretty=True))
expected_dump = { expected_dump = {
"lc": 1, "lc": 1,
@ -143,7 +143,7 @@ def test_aliases_hidden() -> None:
dumped = json.loads(dumps(test_class, pretty=True)) dumped = json.loads(dumps(test_class, pretty=True))
# Check by alias # Check by alias
test_class = TestClass( test_class = TestClass( # type: ignore[call-arg]
my_favorite_secret_alias="hello", # noqa: S106 my_favorite_secret_alias="hello", # noqa: S106
my_other_secret="parrot party", # noqa: S106 my_other_secret="parrot party", # noqa: S106
) )

View File

@ -162,7 +162,7 @@ def test_load_llmchain_with_non_serializable_arg() -> None:
import httpx import httpx
from langchain_openai import OpenAI from langchain_openai import OpenAI
llm = OpenAI( # type: ignore[call-arg] llm = OpenAI(
model="davinci", model="davinci",
temperature=0.5, temperature=0.5,
openai_api_key="hello", openai_api_key="hello",

View File

@ -37,7 +37,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
def test_retry_output_parser_parse_with_prompt() -> None: def test_retry_output_parser_parse_with_prompt() -> None:
n: int = 5 # Success on the (n+1)-th attempt n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n) base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryOutputParser( parser = RetryOutputParser[SuccessfulParseAfterRetries](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
max_retries=n, # n times to retry, that is, (n+1) times call max_retries=n, # n times to retry, that is, (n+1) times call
@ -51,7 +51,7 @@ def test_retry_output_parser_parse_with_prompt() -> None:
def test_retry_output_parser_parse_with_prompt_fail() -> None: def test_retry_output_parser_parse_with_prompt_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n) base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryOutputParser( parser = RetryOutputParser[SuccessfulParseAfterRetries](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
max_retries=n - 1, # n-1 times to retry, that is, n times call max_retries=n - 1, # n-1 times to retry, that is, n times call
@ -65,7 +65,7 @@ def test_retry_output_parser_parse_with_prompt_fail() -> None:
async def test_retry_output_parser_aparse_with_prompt() -> None: async def test_retry_output_parser_aparse_with_prompt() -> None:
n: int = 5 # Success on the (n+1)-th attempt n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n) base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryOutputParser( parser = RetryOutputParser[SuccessfulParseAfterRetries](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
max_retries=n, # n times to retry, that is, (n+1) times call max_retries=n, # n times to retry, that is, (n+1) times call
@ -82,7 +82,7 @@ async def test_retry_output_parser_aparse_with_prompt() -> None:
async def test_retry_output_parser_aparse_with_prompt_fail() -> None: async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n) base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryOutputParser( parser = RetryOutputParser[SuccessfulParseAfterRetries](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
max_retries=n - 1, # n-1 times to retry, that is, n times call max_retries=n - 1, # n-1 times to retry, that is, n times call
@ -101,7 +101,7 @@ async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
], ],
) )
def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None: def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
parser = RetryOutputParser( parser = RetryOutputParser[BaseOutputParser](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
legacy=False, legacy=False,
@ -110,7 +110,7 @@ def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
def test_retry_output_parser_parse_is_not_implemented() -> None: def test_retry_output_parser_parse_is_not_implemented() -> None:
parser = RetryOutputParser( parser = RetryOutputParser[BooleanOutputParser](
parser=BooleanOutputParser(), parser=BooleanOutputParser(),
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
legacy=False, legacy=False,
@ -122,7 +122,7 @@ def test_retry_output_parser_parse_is_not_implemented() -> None:
def test_retry_with_error_output_parser_parse_with_prompt() -> None: def test_retry_with_error_output_parser_parse_with_prompt() -> None:
n: int = 5 # Success on the (n+1)-th attempt n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n) base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryWithErrorOutputParser( parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
max_retries=n, # n times to retry, that is, (n+1) times call max_retries=n, # n times to retry, that is, (n+1) times call
@ -136,7 +136,7 @@ def test_retry_with_error_output_parser_parse_with_prompt() -> None:
def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None: def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n) base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryWithErrorOutputParser( parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
max_retries=n - 1, # n-1 times to retry, that is, n times call max_retries=n - 1, # n-1 times to retry, that is, n times call
@ -150,7 +150,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
async def test_retry_with_error_output_parser_aparse_with_prompt() -> None: async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
n: int = 5 # Success on the (n+1)-th attempt n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n) base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryWithErrorOutputParser( parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
max_retries=n, # n times to retry, that is, (n+1) times call max_retries=n, # n times to retry, that is, (n+1) times call
@ -167,7 +167,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None: async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n) base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryWithErrorOutputParser( parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
max_retries=n - 1, # n-1 times to retry, that is, n times call max_retries=n - 1, # n-1 times to retry, that is, n times call
@ -188,7 +188,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
def test_retry_with_error_output_parser_output_type( def test_retry_with_error_output_parser_output_type(
base_parser: BaseOutputParser, base_parser: BaseOutputParser,
) -> None: ) -> None:
parser = RetryWithErrorOutputParser( parser = RetryWithErrorOutputParser[BaseOutputParser](
parser=base_parser, parser=base_parser,
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
legacy=False, legacy=False,
@ -197,7 +197,7 @@ def test_retry_with_error_output_parser_output_type(
def test_retry_with_error_output_parser_parse_is_not_implemented() -> None: def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
parser = RetryWithErrorOutputParser( parser = RetryWithErrorOutputParser[BooleanOutputParser](
parser=BooleanOutputParser(), parser=BooleanOutputParser(),
retry_chain=RunnablePassthrough(), retry_chain=RunnablePassthrough(),
legacy=False, legacy=False,
@ -222,11 +222,11 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
def test_retry_output_parser_parse_with_prompt_with_retry_chain( def test_retry_output_parser_parse_with_prompt_with_retry_chain(
completion: str, completion: str,
prompt: PromptValue, prompt: PromptValue,
base_parser: BaseOutputParser[T], base_parser: DatetimeOutputParser,
retry_chain: Runnable[dict[str, Any], str], retry_chain: Runnable[dict[str, Any], str],
expected: T, expected: dt,
) -> None: ) -> None:
parser = RetryOutputParser( parser = RetryOutputParser[DatetimeOutputParser](
parser=base_parser, parser=base_parser,
retry_chain=retry_chain, retry_chain=retry_chain,
legacy=False, legacy=False,
@ -250,12 +250,12 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(
async def test_retry_output_parser_aparse_with_prompt_with_retry_chain( async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
completion: str, completion: str,
prompt: PromptValue, prompt: PromptValue,
base_parser: BaseOutputParser[T], base_parser: DatetimeOutputParser,
retry_chain: Runnable[dict[str, Any], str], retry_chain: Runnable[dict[str, Any], str],
expected: T, expected: dt,
) -> None: ) -> None:
# test # test
parser = RetryOutputParser( parser = RetryOutputParser[DatetimeOutputParser](
parser=base_parser, parser=base_parser,
retry_chain=retry_chain, retry_chain=retry_chain,
legacy=False, legacy=False,
@ -279,12 +279,12 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain( def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
completion: str, completion: str,
prompt: PromptValue, prompt: PromptValue,
base_parser: BaseOutputParser[T], base_parser: DatetimeOutputParser,
retry_chain: Runnable[dict[str, Any], str], retry_chain: Runnable[dict[str, Any], str],
expected: T, expected: dt,
) -> None: ) -> None:
# test # test
parser = RetryWithErrorOutputParser( parser = RetryWithErrorOutputParser[DatetimeOutputParser](
parser=base_parser, parser=base_parser,
retry_chain=retry_chain, retry_chain=retry_chain,
legacy=False, legacy=False,
@ -308,11 +308,11 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain( async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
completion: str, completion: str,
prompt: PromptValue, prompt: PromptValue,
base_parser: BaseOutputParser[T], base_parser: DatetimeOutputParser,
retry_chain: Runnable[dict[str, Any], str], retry_chain: Runnable[dict[str, Any], str],
expected: T, expected: dt,
) -> None: ) -> None:
parser = RetryWithErrorOutputParser( parser = RetryWithErrorOutputParser[DatetimeOutputParser](
parser=base_parser, parser=base_parser,
retry_chain=retry_chain, retry_chain=retry_chain,
legacy=False, legacy=False,

View File

@ -95,5 +95,5 @@ def test_yaml_output_parser_fail() -> None:
def test_yaml_output_parser_output_type() -> None: def test_yaml_output_parser_output_type() -> None:
"""Test YamlOutputParser OutputType.""" """Test YamlOutputParser OutputType."""
yaml_parser = YamlOutputParser(pydantic_object=TestModel) yaml_parser = YamlOutputParser[TestModel](pydantic_object=TestModel)
assert yaml_parser.OutputType is TestModel assert yaml_parser.OutputType is TestModel

View File

@ -8,6 +8,6 @@ def test__list_rerank_init() -> None:
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
LLMListwiseRerank.from_llm( LLMListwiseRerank.from_llm(
llm=ChatOpenAI(api_key="foo"), # type: ignore[arg-type] llm=ChatOpenAI(api_key="foo"),
top_n=10, top_n=10,
) )

View File

@ -43,7 +43,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
def test_multi_vector_retriever_initialization() -> None: def test_multi_vector_retriever_initialization() -> None:
vectorstore = InMemoryVectorstoreWithSearch() vectorstore = InMemoryVectorstoreWithSearch()
retriever = MultiVectorRetriever( # type: ignore[call-arg] retriever = MultiVectorRetriever(
vectorstore=vectorstore, vectorstore=vectorstore,
docstore=InMemoryStore(), docstore=InMemoryStore(),
doc_id="doc_id", doc_id="doc_id",
@ -58,7 +58,7 @@ def test_multi_vector_retriever_initialization() -> None:
async def test_multi_vector_retriever_initialization_async() -> None: async def test_multi_vector_retriever_initialization_async() -> None:
vectorstore = InMemoryVectorstoreWithSearch() vectorstore = InMemoryVectorstoreWithSearch()
retriever = MultiVectorRetriever( # type: ignore[call-arg] retriever = MultiVectorRetriever(
vectorstore=vectorstore, vectorstore=vectorstore,
docstore=InMemoryStore(), docstore=InMemoryStore(),
doc_id="doc_id", doc_id="doc_id",
@ -77,7 +77,7 @@ def test_multi_vector_retriever_similarity_search_with_score() -> None:
vectorstore.add_documents(documents, ids=["1"]) vectorstore.add_documents(documents, ids=["1"])
# score_threshold = 0.5 # score_threshold = 0.5
retriever = MultiVectorRetriever( # type: ignore[call-arg] retriever = MultiVectorRetriever(
vectorstore=vectorstore, vectorstore=vectorstore,
docstore=InMemoryStore(), docstore=InMemoryStore(),
doc_id="doc_id", doc_id="doc_id",
@ -90,7 +90,7 @@ def test_multi_vector_retriever_similarity_search_with_score() -> None:
assert results[0].page_content == "test document" assert results[0].page_content == "test document"
# score_threshold = 0.9 # score_threshold = 0.9
retriever = MultiVectorRetriever( # type: ignore[call-arg] retriever = MultiVectorRetriever(
vectorstore=vectorstore, vectorstore=vectorstore,
docstore=InMemoryStore(), docstore=InMemoryStore(),
doc_id="doc_id", doc_id="doc_id",
@ -108,7 +108,7 @@ async def test_multi_vector_retriever_similarity_search_with_score_async() -> No
await vectorstore.aadd_documents(documents, ids=["1"]) await vectorstore.aadd_documents(documents, ids=["1"])
# score_threshold = 0.5 # score_threshold = 0.5
retriever = MultiVectorRetriever( # type: ignore[call-arg] retriever = MultiVectorRetriever(
vectorstore=vectorstore, vectorstore=vectorstore,
docstore=InMemoryStore(), docstore=InMemoryStore(),
doc_id="doc_id", doc_id="doc_id",
@ -121,7 +121,7 @@ async def test_multi_vector_retriever_similarity_search_with_score_async() -> No
assert results[0].page_content == "test document" assert results[0].page_content == "test document"
# score_threshold = 0.9 # score_threshold = 0.9
retriever = MultiVectorRetriever( # type: ignore[call-arg] retriever = MultiVectorRetriever(
vectorstore=vectorstore, vectorstore=vectorstore,
docstore=InMemoryStore(), docstore=InMemoryStore(),
doc_id="doc_id", doc_id="doc_id",

View File

@ -171,7 +171,7 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
assert "the right input" in inputs assert "the right input" in inputs
return {"output": "2"} return {"output": "2"}
mock_chain = TransformChain( # type: ignore[call-arg] mock_chain = TransformChain(
input_variables=["the right input"], input_variables=["the right input"],
output_variables=["output"], output_variables=["output"],
transform=run_val, transform=run_val,