chore(langchain): add mypy pydantic plugin (#32610)

Christophe Bornet, 2025-08-19 22:59:59 +02:00 (committed by GitHub)
parent 73a7de63aa
commit f896bcdb1d
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
22 changed files with 101 additions and 100 deletions

View File

@@ -277,7 +277,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
             extra_prompt_messages=extra_prompt_messages,
             system_message=system_message_,
         )
-        return cls(  # type: ignore[call-arg]
+        return cls(
             llm=llm,
             prompt=prompt,
             tools=tools,

View File

@@ -328,7 +328,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
             extra_prompt_messages=extra_prompt_messages,
             system_message=system_message_,
         )
-        return cls(  # type: ignore[call-arg]
+        return cls(
             llm=llm,
             prompt=prompt,
             tools=tools,

View File

@@ -56,7 +56,7 @@ def _load_question_to_checked_assertions_chain(
         revised_answer_chain,
     ]
     return SequentialChain(
-        chains=chains,  # type: ignore[arg-type]
+        chains=chains,
         input_variables=["question"],
         output_variables=["revised_statement"],
         verbose=True,

View File

@@ -169,7 +169,7 @@ def _load_map_reduce_documents_chain(
     return MapReduceDocumentsChain(
         llm_chain=llm_chain,
-        reduce_documents_chain=reduce_documents_chain,  # type: ignore[arg-type]
+        reduce_documents_chain=reduce_documents_chain,
         **config,
     )
@@ -293,10 +293,10 @@ def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
     revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path"))
     return LLMCheckerChain(
         llm=llm,
-        create_draft_answer_prompt=create_draft_answer_prompt,  # type: ignore[arg-type]
-        list_assertions_prompt=list_assertions_prompt,  # type: ignore[arg-type]
-        check_assertions_prompt=check_assertions_prompt,  # type: ignore[arg-type]
-        revised_answer_prompt=revised_answer_prompt,  # type: ignore[arg-type]
+        create_draft_answer_prompt=create_draft_answer_prompt,
+        list_assertions_prompt=list_assertions_prompt,
+        check_assertions_prompt=check_assertions_prompt,
+        revised_answer_prompt=revised_answer_prompt,
         **config,
     )
@@ -325,7 +325,7 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
     elif "prompt_path" in config:
         prompt = load_prompt(config.pop("prompt_path"))
     if llm_chain:
-        return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config)  # type: ignore[arg-type]
+        return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config)
     return LLMMathChain(llm=llm, prompt=prompt, **config)
@@ -341,7 +341,7 @@ def _load_map_rerank_documents_chain(
     else:
         msg = "One of `llm_chain` or `llm_chain_path` must be present."
         raise ValueError(msg)
-    return MapRerankDocumentsChain(llm_chain=llm_chain, **config)  # type: ignore[arg-type]
+    return MapRerankDocumentsChain(llm_chain=llm_chain, **config)


 def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
@@ -377,8 +377,8 @@ def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain:
     elif "document_prompt_path" in config:
         document_prompt = load_prompt(config.pop("document_prompt_path"))
     return RefineDocumentsChain(
-        initial_llm_chain=initial_llm_chain,  # type: ignore[arg-type]
-        refine_llm_chain=refine_llm_chain,  # type: ignore[arg-type]
+        initial_llm_chain=initial_llm_chain,
+        refine_llm_chain=refine_llm_chain,
         document_prompt=document_prompt,
         **config,
     )
@@ -402,7 +402,7 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesChain:
             "`combine_documents_chain_path` must be present."
         )
         raise ValueError(msg)
-    return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config)  # type: ignore[arg-type]
+    return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config)


 def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
@@ -445,7 +445,7 @@ def _load_vector_db_qa_with_sources_chain(
         )
         raise ValueError(msg)
     return VectorDBQAWithSourcesChain(
-        combine_documents_chain=combine_documents_chain,  # type: ignore[arg-type]
+        combine_documents_chain=combine_documents_chain,
         vectorstore=vectorstore,
         **config,
     )
@@ -475,7 +475,7 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
         )
         raise ValueError(msg)
     return RetrievalQA(
-        combine_documents_chain=combine_documents_chain,  # type: ignore[arg-type]
+        combine_documents_chain=combine_documents_chain,
         retriever=retriever,
         **config,
     )
@@ -508,7 +508,7 @@ def _load_retrieval_qa_with_sources_chain(
         )
         raise ValueError(msg)
     return RetrievalQAWithSourcesChain(
-        combine_documents_chain=combine_documents_chain,  # type: ignore[arg-type]
+        combine_documents_chain=combine_documents_chain,
         retriever=retriever,
         **config,
     )
@@ -538,7 +538,7 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
         )
         raise ValueError(msg)
     return VectorDBQA(
-        combine_documents_chain=combine_documents_chain,  # type: ignore[arg-type]
+        combine_documents_chain=combine_documents_chain,
         vectorstore=vectorstore,
         **config,
     )
@@ -606,8 +606,8 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
         msg = "`requests_wrapper` must be present."
         raise ValueError(msg)
     return APIChain(
-        api_request_chain=api_request_chain,  # type: ignore[arg-type]
-        api_answer_chain=api_answer_chain,  # type: ignore[arg-type]
+        api_request_chain=api_request_chain,
+        api_answer_chain=api_answer_chain,
         requests_wrapper=requests_wrapper,
         **config,
     )

View File

@@ -66,12 +66,12 @@ def _load_stuff_chain(
     verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)  # type: ignore[arg-type]
+    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
     return StuffDocumentsChain(
         llm_chain=llm_chain,
         document_variable_name=document_variable_name,
         document_prompt=document_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         **kwargs,
     )
@@ -91,14 +91,14 @@ def _load_map_reduce_chain(
     token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
-    map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
+    map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _reduce_llm = reduce_llm or llm
-    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)  # type: ignore[arg-type]
+    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
     combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         document_prompt=document_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
     )
     if collapse_prompt is None:
         collapse_chain = None
@@ -114,7 +114,7 @@ def _load_map_reduce_chain(
             llm_chain=LLMChain(
                 llm=_collapse_llm,
                 prompt=collapse_prompt,
-                verbose=verbose,  # type: ignore[arg-type]
+                verbose=verbose,
             ),
             document_variable_name=combine_document_variable_name,
             document_prompt=document_prompt,
@@ -123,13 +123,13 @@ def _load_map_reduce_chain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
         token_max=token_max,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
         reduce_documents_chain=reduce_documents_chain,
         document_variable_name=map_reduce_document_variable_name,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         **kwargs,
     )
@@ -146,16 +146,16 @@ def _load_refine_chain(
     verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
-    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
+    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _refine_llm = refine_llm or llm
-    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)  # type: ignore[arg-type]
+    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
     return RefineDocumentsChain(
         initial_llm_chain=initial_chain,
         refine_llm_chain=refine_chain,
         document_variable_name=document_variable_name,
         initial_response_name=initial_response_name,
         document_prompt=document_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         **kwargs,
     )

View File

@@ -80,7 +80,7 @@ def _load_stuff_chain(
     llm_chain = LLMChain(
         llm=llm,
         prompt=_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
     )
@@ -88,7 +88,7 @@ def _load_stuff_chain(
     return StuffDocumentsChain(
         llm_chain=llm_chain,
         document_variable_name=document_variable_name,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
         **kwargs,
@@ -120,7 +120,7 @@ def _load_map_reduce_chain(
     map_chain = LLMChain(
         llm=llm,
         prompt=_question_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
     )
@@ -128,7 +128,7 @@ def _load_map_reduce_chain(
     reduce_chain = LLMChain(
         llm=_reduce_llm,
         prompt=_combine_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
     )
@@ -136,7 +136,7 @@ def _load_map_reduce_chain(
     combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
     )
@@ -154,12 +154,12 @@ def _load_map_reduce_chain(
             llm_chain=LLMChain(
                 llm=_collapse_llm,
                 prompt=collapse_prompt,
-                verbose=verbose,  # type: ignore[arg-type]
+                verbose=verbose,
                 callback_manager=callback_manager,
                 callbacks=callbacks,
             ),
             document_variable_name=combine_document_variable_name,
-            verbose=verbose,  # type: ignore[arg-type]
+            verbose=verbose,
             callback_manager=callback_manager,
         )
     reduce_documents_chain = ReduceDocumentsChain(
@@ -172,7 +172,7 @@ def _load_map_reduce_chain(
         llm_chain=map_chain,
         document_variable_name=map_reduce_document_variable_name,
         reduce_documents_chain=reduce_documents_chain,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
         **kwargs,
@@ -201,7 +201,7 @@ def _load_refine_chain(
     initial_chain = LLMChain(
         llm=llm,
         prompt=_question_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
     )
@@ -209,7 +209,7 @@ def _load_refine_chain(
     refine_chain = LLMChain(
         llm=_refine_llm,
         prompt=_refine_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
     )
@@ -218,7 +218,7 @@ def _load_refine_chain(
         refine_llm_chain=refine_chain,
         document_variable_name=document_variable_name,
         initial_response_name=initial_response_name,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
         **kwargs,

View File

@@ -35,12 +35,12 @@ def _load_stuff_chain(
     verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)  # type: ignore[arg-type]
+    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
     # TODO: document prompt
     return StuffDocumentsChain(
         llm_chain=llm_chain,
         document_variable_name=document_variable_name,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         **kwargs,
     )
@@ -64,21 +64,21 @@ def _load_map_reduce_chain(
     map_chain = LLMChain(
         llm=llm,
         prompt=map_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callbacks=callbacks,
     )
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(
         llm=_reduce_llm,
         prompt=combine_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callbacks=callbacks,
     )
     # TODO: document prompt
     combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callbacks=callbacks,
     )
     if collapse_prompt is None:
@@ -95,7 +95,7 @@ def _load_map_reduce_chain(
             llm_chain=LLMChain(
                 llm=_collapse_llm,
                 prompt=collapse_prompt,
-                verbose=verbose,  # type: ignore[arg-type]
+                verbose=verbose,
                 callbacks=callbacks,
             ),
             document_variable_name=combine_document_variable_name,
@@ -104,7 +104,7 @@ def _load_map_reduce_chain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
         token_max=token_max,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callbacks=callbacks,
         collapse_max_retries=collapse_max_retries,
     )
@@ -112,7 +112,7 @@ def _load_map_reduce_chain(
         llm_chain=map_chain,
         reduce_documents_chain=reduce_documents_chain,
         document_variable_name=map_reduce_document_variable_name,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         callbacks=callbacks,
         **kwargs,
     )
@@ -129,15 +129,15 @@ def _load_refine_chain(
     verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
-    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
+    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _refine_llm = refine_llm or llm
-    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)  # type: ignore[arg-type]
+    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
     return RefineDocumentsChain(
         initial_llm_chain=initial_chain,
         refine_llm_chain=refine_chain,
         document_variable_name=document_variable_name,
         initial_response_name=initial_response_name,
-        verbose=verbose,  # type: ignore[arg-type]
+        verbose=verbose,
         **kwargs,
     )

View File

@@ -250,7 +250,7 @@ The following is the expected answer. Use this to measure correctness:
         prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
         eval_chain = LLMChain(llm=llm, prompt=prompt)
         return cls(
-            agent_tools=agent_tools,  # type: ignore[arg-type]
+            agent_tools=agent_tools,
            eval_chain=eval_chain,
             output_parser=output_parser or TrajectoryOutputParser(),
             **kwargs,

View File

@@ -120,4 +120,4 @@ class LLMChainExtractor(BaseDocumentCompressor):
         else:
             parser = StrOutputParser()
         llm_chain = _prompt | llm | parser
-        return cls(llm_chain=llm_chain, get_input=_get_input)  # type: ignore[arg-type]
+        return cls(llm_chain=llm_chain, get_input=_get_input)

View File

@@ -407,7 +407,7 @@ class SelfQueryRetriever(BaseRetriever):
         query_constructor = query_constructor.with_config(
             run_name=QUERY_CONSTRUCTOR_RUN_NAME,
         )
-        return cls(  # type: ignore[call-arg]
+        return cls(
             query_constructor=query_constructor,
             vectorstore=vectorstore,
             use_original_query=use_original_query,

View File

@@ -4,7 +4,7 @@ from langchain_core.runnables.base import RunnableBindingBase
 from langchain_core.runnables.utils import Input, Output


-class HubRunnable(RunnableBindingBase[Input, Output]):
+class HubRunnable(RunnableBindingBase[Input, Output]):  # type: ignore[no-redef]
     """
     An instance of a runnable stored in the LangChain Hub.
     """

View File

@@ -20,7 +20,7 @@ class OpenAIFunction(TypedDict):
     """The parameters to the function."""


-class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):
+class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):  # type: ignore[no-redef]
     """A runnable that routes to the selected function."""

     functions: Optional[list[OpenAIFunction]]

View File

@@ -124,6 +124,7 @@ target-version = "py39"
 exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]

 [tool.mypy]
+plugins = ["pydantic.mypy"]
 strict = "True"
 strict_bytes = "True"
 ignore_missing_imports = "True"
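
This `plugins = ["pydantic.mypy"]` line is the heart of the commit: the plugin teaches mypy the `__init__` that pydantic synthesizes from field declarations, which is why the `# type: ignore[call-arg]` and `# type: ignore[arg-type]` comments elsewhere in this diff can be dropped. A minimal sketch of the effect, using a hypothetical model rather than code from this repository:

from pydantic import BaseModel


class FakeChain(BaseModel):
    """Hypothetical pydantic model, for illustration only."""

    verbose: bool = False
    output_key: str = "text"


# Without plugins = ["pydantic.mypy"], strict mypy cannot see the
# synthesized __init__, so keyword calls like this one needed a blanket
# `# type: ignore[call-arg]`. With the plugin, the call is checked
# against the declared fields instead:
chain = FakeChain(verbose=True, output_key="answer")

# ...and genuine mistakes that the ignores used to mask now surface:
# FakeChain(verbos=True)  # mypy error: unexpected keyword argument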

View File

@@ -15,7 +15,7 @@ async def test_simplea() -> None:
     answer = "I know the answer!"
     llm = FakeListLLM(responses=[answer])
     retriever = SequentialRetriever(sequential_responses=[[]])
-    memory = ConversationBufferMemory(  # type: ignore[call-arg]
+    memory = ConversationBufferMemory(
         k=1,
         output_key="answer",
         memory_key="chat_history",
@@ -42,7 +42,7 @@ async def test_fixed_message_response_when_docs_founda() -> None:
     retriever = SequentialRetriever(
         sequential_responses=[[Document(page_content=answer)]],
     )
-    memory = ConversationBufferMemory(  # type: ignore[call-arg]
+    memory = ConversationBufferMemory(
         k=1,
         output_key="answer",
         memory_key="chat_history",
@@ -67,7 +67,7 @@ def test_fixed_message_response_when_no_docs_found() -> None:
     answer = "I know the answer!"
     llm = FakeListLLM(responses=[answer])
     retriever = SequentialRetriever(sequential_responses=[[]])
-    memory = ConversationBufferMemory(  # type: ignore[call-arg]
+    memory = ConversationBufferMemory(
         k=1,
         output_key="answer",
         memory_key="chat_history",
@@ -94,7 +94,7 @@ def test_fixed_message_response_when_docs_found() -> None:
     retriever = SequentialRetriever(
         sequential_responses=[[Document(page_content=answer)]],
     )
-    memory = ConversationBufferMemory(  # type: ignore[call-arg]
+    memory = ConversationBufferMemory(
         k=1,
         output_key="answer",
         memory_key="chat_history",

View File

@@ -16,7 +16,7 @@ def dummy_transform(inputs: dict[str, str]) -> dict[str, str]:

 def test_transform_chain() -> None:
     """Test basic transform chain."""
-    transform_chain = TransformChain(  # type: ignore[call-arg]
+    transform_chain = TransformChain(
         input_variables=["first_name", "last_name"],
         output_variables=["greeting"],
         transform=dummy_transform,
@@ -29,7 +29,7 @@ def test_transform_chain() -> None:

 def test_transform_chain_bad_inputs() -> None:
     """Test basic transform chain."""
-    transform_chain = TransformChain(  # type: ignore[call-arg]
+    transform_chain = TransformChain(
         input_variables=["first_name", "last_name"],
         output_variables=["greeting"],
         transform=dummy_transform,

View File

@@ -116,9 +116,9 @@ class TestClass(Serializable):

 def test_aliases_hidden() -> None:
     test_class = TestClass(
-        my_favorite_secret="hello",  # noqa: S106 # type: ignore[call-arg]
+        my_favorite_secret="hello",  # noqa: S106
         my_other_secret="world",  # noqa: S106
-    )  # type: ignore[call-arg]
+    )
     dumped = json.loads(dumps(test_class, pretty=True))
     expected_dump = {
         "lc": 1,
@@ -143,7 +143,7 @@ def test_aliases_hidden() -> None:
     dumped = json.loads(dumps(test_class, pretty=True))

     # Check by alias
-    test_class = TestClass(
+    test_class = TestClass(  # type: ignore[call-arg]
         my_favorite_secret_alias="hello",  # noqa: S106
         my_other_secret="parrot party",  # noqa: S106
     )
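
The alias hunk cuts the other way: the plugin-synthesized `__init__` declares each field under a single keyword, so the by-name call no longer needs an ignore while the by-alias call now does. A rough sketch of that behavior with a hypothetical `AliasedModel` (the real `TestClass` is defined in the test module):

from pydantic import BaseModel, ConfigDict, Field


class AliasedModel(BaseModel):
    """Hypothetical stand-in for the test's TestClass."""

    model_config = ConfigDict(populate_by_name=True)

    my_favorite_secret: str = Field(alias="my_favorite_secret_alias")


# With populate_by_name=True, both spellings are accepted at runtime,
# but the plugin's synthesized signature only knows the field name, so
# only the alias call needs a targeted ignore:
AliasedModel(my_favorite_secret="hello")  # checked by the plugin
AliasedModel(my_favorite_secret_alias="hello")  # type: ignore[call-arg]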

View File

@@ -162,7 +162,7 @@ def test_load_llmchain_with_non_serializable_arg() -> None:
     import httpx
     from langchain_openai import OpenAI

-    llm = OpenAI(  # type: ignore[call-arg]
+    llm = OpenAI(
         model="davinci",
         temperature=0.5,
         openai_api_key="hello",

View File

@@ -37,7 +37,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
 def test_retry_output_parser_parse_with_prompt() -> None:
     n: int = 5  # Success on the (n+1)-th attempt
     base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-    parser = RetryOutputParser(
+    parser = RetryOutputParser[SuccessfulParseAfterRetries](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         max_retries=n,  # n times to retry, that is, (n+1) times call
@@ -51,7 +51,7 @@ def test_retry_output_parser_parse_with_prompt() -> None:
 def test_retry_output_parser_parse_with_prompt_fail() -> None:
     n: int = 5  # Success on the (n+1)-th attempt
     base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-    parser = RetryOutputParser(
+    parser = RetryOutputParser[SuccessfulParseAfterRetries](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         max_retries=n - 1,  # n-1 times to retry, that is, n times call
@@ -65,7 +65,7 @@ def test_retry_output_parser_parse_with_prompt_fail() -> None:
 async def test_retry_output_parser_aparse_with_prompt() -> None:
     n: int = 5  # Success on the (n+1)-th attempt
     base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-    parser = RetryOutputParser(
+    parser = RetryOutputParser[SuccessfulParseAfterRetries](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         max_retries=n,  # n times to retry, that is, (n+1) times call
@@ -82,7 +82,7 @@ async def test_retry_output_parser_aparse_with_prompt() -> None:
 async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
     n: int = 5  # Success on the (n+1)-th attempt
     base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-    parser = RetryOutputParser(
+    parser = RetryOutputParser[SuccessfulParseAfterRetries](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         max_retries=n - 1,  # n-1 times to retry, that is, n times call
@@ -101,7 +101,7 @@ async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
     ],
 )
 def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
-    parser = RetryOutputParser(
+    parser = RetryOutputParser[BaseOutputParser](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         legacy=False,
@@ -110,7 +110,7 @@ def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
 def test_retry_output_parser_parse_is_not_implemented() -> None:
-    parser = RetryOutputParser(
+    parser = RetryOutputParser[BooleanOutputParser](
         parser=BooleanOutputParser(),
         retry_chain=RunnablePassthrough(),
         legacy=False,
@@ -122,7 +122,7 @@ def test_retry_output_parser_parse_is_not_implemented() -> None:
 def test_retry_with_error_output_parser_parse_with_prompt() -> None:
     n: int = 5  # Success on the (n+1)-th attempt
     base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-    parser = RetryWithErrorOutputParser(
+    parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         max_retries=n,  # n times to retry, that is, (n+1) times call
@@ -136,7 +136,7 @@ def test_retry_with_error_output_parser_parse_with_prompt() -> None:
 def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
     n: int = 5  # Success on the (n+1)-th attempt
     base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-    parser = RetryWithErrorOutputParser(
+    parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         max_retries=n - 1,  # n-1 times to retry, that is, n times call
@@ -150,7 +150,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
 async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
     n: int = 5  # Success on the (n+1)-th attempt
     base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-    parser = RetryWithErrorOutputParser(
+    parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         max_retries=n,  # n times to retry, that is, (n+1) times call
@@ -167,7 +167,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
 async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
     n: int = 5  # Success on the (n+1)-th attempt
     base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-    parser = RetryWithErrorOutputParser(
+    parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         max_retries=n - 1,  # n-1 times to retry, that is, n times call
@@ -188,7 +188,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
 def test_retry_with_error_output_parser_output_type(
     base_parser: BaseOutputParser,
 ) -> None:
-    parser = RetryWithErrorOutputParser(
+    parser = RetryWithErrorOutputParser[BaseOutputParser](
         parser=base_parser,
         retry_chain=RunnablePassthrough(),
         legacy=False,
@@ -197,7 +197,7 @@ def test_retry_with_error_output_parser_output_type(
 def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
-    parser = RetryWithErrorOutputParser(
+    parser = RetryWithErrorOutputParser[BooleanOutputParser](
         parser=BooleanOutputParser(),
         retry_chain=RunnablePassthrough(),
         legacy=False,
@@ -222,11 +222,11 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
 def test_retry_output_parser_parse_with_prompt_with_retry_chain(
     completion: str,
     prompt: PromptValue,
-    base_parser: BaseOutputParser[T],
+    base_parser: DatetimeOutputParser,
     retry_chain: Runnable[dict[str, Any], str],
-    expected: T,
+    expected: dt,
 ) -> None:
-    parser = RetryOutputParser(
+    parser = RetryOutputParser[DatetimeOutputParser](
         parser=base_parser,
         retry_chain=retry_chain,
         legacy=False,
@@ -250,12 +250,12 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(
 async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
     completion: str,
     prompt: PromptValue,
-    base_parser: BaseOutputParser[T],
+    base_parser: DatetimeOutputParser,
     retry_chain: Runnable[dict[str, Any], str],
-    expected: T,
+    expected: dt,
 ) -> None:
     # test
-    parser = RetryOutputParser(
+    parser = RetryOutputParser[DatetimeOutputParser](
         parser=base_parser,
         retry_chain=retry_chain,
         legacy=False,
@@ -279,12 +279,12 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
 def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
     completion: str,
     prompt: PromptValue,
-    base_parser: BaseOutputParser[T],
+    base_parser: DatetimeOutputParser,
     retry_chain: Runnable[dict[str, Any], str],
-    expected: T,
+    expected: dt,
 ) -> None:
     # test
-    parser = RetryWithErrorOutputParser(
+    parser = RetryWithErrorOutputParser[DatetimeOutputParser](
         parser=base_parser,
         retry_chain=retry_chain,
         legacy=False,
@@ -308,11 +308,11 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
 async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
     completion: str,
     prompt: PromptValue,
-    base_parser: BaseOutputParser[T],
+    base_parser: DatetimeOutputParser,
     retry_chain: Runnable[dict[str, Any], str],
-    expected: T,
+    expected: dt,
 ) -> None:
-    parser = RetryWithErrorOutputParser(
+    parser = RetryWithErrorOutputParser[DatetimeOutputParser](
         parser=base_parser,
         retry_chain=retry_chain,
         legacy=False,

View File

@@ -95,5 +95,5 @@ def test_yaml_output_parser_fail() -> None:

 def test_yaml_output_parser_output_type() -> None:
     """Test YamlOutputParser OutputType."""
-    yaml_parser = YamlOutputParser(pydantic_object=TestModel)
+    yaml_parser = YamlOutputParser[TestModel](pydantic_object=TestModel)
     assert yaml_parser.OutputType is TestModel
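
The bracketed type arguments these test hunks add (`RetryOutputParser[SuccessfulParseAfterRetries]`, `YamlOutputParser[TestModel]`, and so on) pin the generic parameter explicitly, since strict mypy may not infer it through the `__init__` the pydantic plugin synthesizes. A minimal sketch of the pattern, using a hypothetical `Wrapper`/`Payload` pair rather than the real parser classes:

from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T")


class Wrapper(BaseModel, Generic[T]):
    """Hypothetical generic pydantic model, for illustration only."""

    inner: T


class Payload(BaseModel):
    answer: str


# Naming the parameter makes the declared type explicit, so attributes
# like `wrapper.inner` are typed as Payload instead of being left to
# (possibly failing) inference:
wrapper = Wrapper[Payload](inner=Payload(answer="42"))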

View File

@@ -8,6 +8,6 @@ def test__list_rerank_init() -> None:
     from langchain_openai import ChatOpenAI

     LLMListwiseRerank.from_llm(
-        llm=ChatOpenAI(api_key="foo"),  # type: ignore[arg-type]
+        llm=ChatOpenAI(api_key="foo"),
         top_n=10,
     )

View File

@@ -43,7 +43,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):

 def test_multi_vector_retriever_initialization() -> None:
     vectorstore = InMemoryVectorstoreWithSearch()
-    retriever = MultiVectorRetriever(  # type: ignore[call-arg]
+    retriever = MultiVectorRetriever(
         vectorstore=vectorstore,
         docstore=InMemoryStore(),
         doc_id="doc_id",
@@ -58,7 +58,7 @@ def test_multi_vector_retriever_initialization() -> None:

 async def test_multi_vector_retriever_initialization_async() -> None:
     vectorstore = InMemoryVectorstoreWithSearch()
-    retriever = MultiVectorRetriever(  # type: ignore[call-arg]
+    retriever = MultiVectorRetriever(
         vectorstore=vectorstore,
         docstore=InMemoryStore(),
         doc_id="doc_id",
@@ -77,7 +77,7 @@ def test_multi_vector_retriever_similarity_search_with_score() -> None:
     vectorstore.add_documents(documents, ids=["1"])

     # score_threshold = 0.5
-    retriever = MultiVectorRetriever(  # type: ignore[call-arg]
+    retriever = MultiVectorRetriever(
         vectorstore=vectorstore,
         docstore=InMemoryStore(),
         doc_id="doc_id",
@@ -90,7 +90,7 @@ def test_multi_vector_retriever_similarity_search_with_score() -> None:
     assert results[0].page_content == "test document"

     # score_threshold = 0.9
-    retriever = MultiVectorRetriever(  # type: ignore[call-arg]
+    retriever = MultiVectorRetriever(
         vectorstore=vectorstore,
         docstore=InMemoryStore(),
         doc_id="doc_id",
@@ -108,7 +108,7 @@ async def test_multi_vector_retriever_similarity_search_with_score_async() -> None:
     await vectorstore.aadd_documents(documents, ids=["1"])

     # score_threshold = 0.5
-    retriever = MultiVectorRetriever(  # type: ignore[call-arg]
+    retriever = MultiVectorRetriever(
         vectorstore=vectorstore,
         docstore=InMemoryStore(),
         doc_id="doc_id",
@@ -121,7 +121,7 @@ async def test_multi_vector_retriever_similarity_search_with_score_async() -> None:
     assert results[0].page_content == "test document"

     # score_threshold = 0.9
-    retriever = MultiVectorRetriever(  # type: ignore[call-arg]
+    retriever = MultiVectorRetriever(
         vectorstore=vectorstore,
         docstore=InMemoryStore(),
         doc_id="doc_id",

View File

@@ -171,7 +171,7 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
         assert "the right input" in inputs
         return {"output": "2"}

-    mock_chain = TransformChain(  # type: ignore[call-arg]
+    mock_chain = TransformChain(
         input_variables=["the right input"],
         output_variables=["output"],
         transform=run_val,