mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-22 19:08:40 +00:00
chore(langchain): add mypy pydantic plugin (#32610)
This commit is contained in:
parent
73a7de63aa
commit
f896bcdb1d
@ -277,7 +277,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message_,
|
||||
)
|
||||
return cls( # type: ignore[call-arg]
|
||||
return cls(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
|
@ -328,7 +328,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message_,
|
||||
)
|
||||
return cls( # type: ignore[call-arg]
|
||||
return cls(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
|
@ -56,7 +56,7 @@ def _load_question_to_checked_assertions_chain(
|
||||
revised_answer_chain,
|
||||
]
|
||||
return SequentialChain(
|
||||
chains=chains, # type: ignore[arg-type]
|
||||
chains=chains,
|
||||
input_variables=["question"],
|
||||
output_variables=["revised_statement"],
|
||||
verbose=True,
|
||||
|
@ -169,7 +169,7 @@ def _load_map_reduce_documents_chain(
|
||||
|
||||
return MapReduceDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
reduce_documents_chain=reduce_documents_chain, # type: ignore[arg-type]
|
||||
reduce_documents_chain=reduce_documents_chain,
|
||||
**config,
|
||||
)
|
||||
|
||||
@ -293,10 +293,10 @@ def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
|
||||
revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path"))
|
||||
return LLMCheckerChain(
|
||||
llm=llm,
|
||||
create_draft_answer_prompt=create_draft_answer_prompt, # type: ignore[arg-type]
|
||||
list_assertions_prompt=list_assertions_prompt, # type: ignore[arg-type]
|
||||
check_assertions_prompt=check_assertions_prompt, # type: ignore[arg-type]
|
||||
revised_answer_prompt=revised_answer_prompt, # type: ignore[arg-type]
|
||||
create_draft_answer_prompt=create_draft_answer_prompt,
|
||||
list_assertions_prompt=list_assertions_prompt,
|
||||
check_assertions_prompt=check_assertions_prompt,
|
||||
revised_answer_prompt=revised_answer_prompt,
|
||||
**config,
|
||||
)
|
||||
|
||||
@ -325,7 +325,7 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
|
||||
elif "prompt_path" in config:
|
||||
prompt = load_prompt(config.pop("prompt_path"))
|
||||
if llm_chain:
|
||||
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
|
||||
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config)
|
||||
return LLMMathChain(llm=llm, prompt=prompt, **config)
|
||||
|
||||
|
||||
@ -341,7 +341,7 @@ def _load_map_rerank_documents_chain(
|
||||
else:
|
||||
msg = "One of `llm_chain` or `llm_chain_path` must be present."
|
||||
raise ValueError(msg)
|
||||
return MapRerankDocumentsChain(llm_chain=llm_chain, **config) # type: ignore[arg-type]
|
||||
return MapRerankDocumentsChain(llm_chain=llm_chain, **config)
|
||||
|
||||
|
||||
def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
|
||||
@ -377,8 +377,8 @@ def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocuments
|
||||
elif "document_prompt_path" in config:
|
||||
document_prompt = load_prompt(config.pop("document_prompt_path"))
|
||||
return RefineDocumentsChain(
|
||||
initial_llm_chain=initial_llm_chain, # type: ignore[arg-type]
|
||||
refine_llm_chain=refine_llm_chain, # type: ignore[arg-type]
|
||||
initial_llm_chain=initial_llm_chain,
|
||||
refine_llm_chain=refine_llm_chain,
|
||||
document_prompt=document_prompt,
|
||||
**config,
|
||||
)
|
||||
@ -402,7 +402,7 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha
|
||||
"`combine_documents_chain_path` must be present."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config) # type: ignore[arg-type]
|
||||
return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config)
|
||||
|
||||
|
||||
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
|
||||
@ -445,7 +445,7 @@ def _load_vector_db_qa_with_sources_chain(
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return VectorDBQAWithSourcesChain(
|
||||
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
vectorstore=vectorstore,
|
||||
**config,
|
||||
)
|
||||
@ -475,7 +475,7 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return RetrievalQA(
|
||||
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
retriever=retriever,
|
||||
**config,
|
||||
)
|
||||
@ -508,7 +508,7 @@ def _load_retrieval_qa_with_sources_chain(
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return RetrievalQAWithSourcesChain(
|
||||
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
retriever=retriever,
|
||||
**config,
|
||||
)
|
||||
@ -538,7 +538,7 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return VectorDBQA(
|
||||
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
vectorstore=vectorstore,
|
||||
**config,
|
||||
)
|
||||
@ -606,8 +606,8 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
|
||||
msg = "`requests_wrapper` must be present."
|
||||
raise ValueError(msg)
|
||||
return APIChain(
|
||||
api_request_chain=api_request_chain, # type: ignore[arg-type]
|
||||
api_answer_chain=api_answer_chain, # type: ignore[arg-type]
|
||||
api_request_chain=api_request_chain,
|
||||
api_answer_chain=api_answer_chain,
|
||||
requests_wrapper=requests_wrapper,
|
||||
**config,
|
||||
)
|
||||
|
@ -66,12 +66,12 @@ def _load_stuff_chain(
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # type: ignore[arg-type]
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||
return StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
document_prompt=document_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -91,14 +91,14 @@ def _load_map_reduce_chain(
|
||||
token_max: int = 3000,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
|
||||
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||
_reduce_llm = reduce_llm or llm
|
||||
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) # type: ignore[arg-type]
|
||||
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=reduce_chain,
|
||||
document_variable_name=combine_document_variable_name,
|
||||
document_prompt=document_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
)
|
||||
if collapse_prompt is None:
|
||||
collapse_chain = None
|
||||
@ -114,7 +114,7 @@ def _load_map_reduce_chain(
|
||||
llm_chain=LLMChain(
|
||||
llm=_collapse_llm,
|
||||
prompt=collapse_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
),
|
||||
document_variable_name=combine_document_variable_name,
|
||||
document_prompt=document_prompt,
|
||||
@ -123,13 +123,13 @@ def _load_map_reduce_chain(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
collapse_documents_chain=collapse_chain,
|
||||
token_max=token_max,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
)
|
||||
return MapReduceDocumentsChain(
|
||||
llm_chain=map_chain,
|
||||
reduce_documents_chain=reduce_documents_chain,
|
||||
document_variable_name=map_reduce_document_variable_name,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -146,16 +146,16 @@ def _load_refine_chain(
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
|
||||
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||
_refine_llm = refine_llm or llm
|
||||
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) # type: ignore[arg-type]
|
||||
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
|
||||
return RefineDocumentsChain(
|
||||
initial_llm_chain=initial_chain,
|
||||
refine_llm_chain=refine_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
initial_response_name=initial_response_name,
|
||||
document_prompt=document_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -80,7 +80,7 @@ def _load_stuff_chain(
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
@ -88,7 +88,7 @@ def _load_stuff_chain(
|
||||
return StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
@ -120,7 +120,7 @@ def _load_map_reduce_chain(
|
||||
map_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=_question_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
@ -128,7 +128,7 @@ def _load_map_reduce_chain(
|
||||
reduce_chain = LLMChain(
|
||||
llm=_reduce_llm,
|
||||
prompt=_combine_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
@ -136,7 +136,7 @@ def _load_map_reduce_chain(
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=reduce_chain,
|
||||
document_variable_name=combine_document_variable_name,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
@ -154,12 +154,12 @@ def _load_map_reduce_chain(
|
||||
llm_chain=LLMChain(
|
||||
llm=_collapse_llm,
|
||||
prompt=collapse_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
document_variable_name=combine_document_variable_name,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
reduce_documents_chain = ReduceDocumentsChain(
|
||||
@ -172,7 +172,7 @@ def _load_map_reduce_chain(
|
||||
llm_chain=map_chain,
|
||||
document_variable_name=map_reduce_document_variable_name,
|
||||
reduce_documents_chain=reduce_documents_chain,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
@ -201,7 +201,7 @@ def _load_refine_chain(
|
||||
initial_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=_question_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
@ -209,7 +209,7 @@ def _load_refine_chain(
|
||||
refine_chain = LLMChain(
|
||||
llm=_refine_llm,
|
||||
prompt=_refine_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
@ -218,7 +218,7 @@ def _load_refine_chain(
|
||||
refine_llm_chain=refine_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
initial_response_name=initial_response_name,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
|
@ -35,12 +35,12 @@ def _load_stuff_chain(
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # type: ignore[arg-type]
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||
# TODO: document prompt
|
||||
return StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -64,21 +64,21 @@ def _load_map_reduce_chain(
|
||||
map_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=map_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
_reduce_llm = reduce_llm or llm
|
||||
reduce_chain = LLMChain(
|
||||
llm=_reduce_llm,
|
||||
prompt=combine_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# TODO: document prompt
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=reduce_chain,
|
||||
document_variable_name=combine_document_variable_name,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
if collapse_prompt is None:
|
||||
@ -95,7 +95,7 @@ def _load_map_reduce_chain(
|
||||
llm_chain=LLMChain(
|
||||
llm=_collapse_llm,
|
||||
prompt=collapse_prompt,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
document_variable_name=combine_document_variable_name,
|
||||
@ -104,7 +104,7 @@ def _load_map_reduce_chain(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
collapse_documents_chain=collapse_chain,
|
||||
token_max=token_max,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
collapse_max_retries=collapse_max_retries,
|
||||
)
|
||||
@ -112,7 +112,7 @@ def _load_map_reduce_chain(
|
||||
llm_chain=map_chain,
|
||||
reduce_documents_chain=reduce_documents_chain,
|
||||
document_variable_name=map_reduce_document_variable_name,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
@ -129,15 +129,15 @@ def _load_refine_chain(
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
|
||||
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||
_refine_llm = refine_llm or llm
|
||||
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) # type: ignore[arg-type]
|
||||
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
|
||||
return RefineDocumentsChain(
|
||||
initial_llm_chain=initial_chain,
|
||||
refine_llm_chain=refine_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
initial_response_name=initial_response_name,
|
||||
verbose=verbose, # type: ignore[arg-type]
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -250,7 +250,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
|
||||
eval_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(
|
||||
agent_tools=agent_tools, # type: ignore[arg-type]
|
||||
agent_tools=agent_tools,
|
||||
eval_chain=eval_chain,
|
||||
output_parser=output_parser or TrajectoryOutputParser(),
|
||||
**kwargs,
|
||||
|
@ -120,4 +120,4 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
||||
else:
|
||||
parser = StrOutputParser()
|
||||
llm_chain = _prompt | llm | parser
|
||||
return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type]
|
||||
return cls(llm_chain=llm_chain, get_input=_get_input)
|
||||
|
@ -407,7 +407,7 @@ class SelfQueryRetriever(BaseRetriever):
|
||||
query_constructor = query_constructor.with_config(
|
||||
run_name=QUERY_CONSTRUCTOR_RUN_NAME,
|
||||
)
|
||||
return cls( # type: ignore[call-arg]
|
||||
return cls(
|
||||
query_constructor=query_constructor,
|
||||
vectorstore=vectorstore,
|
||||
use_original_query=use_original_query,
|
||||
|
@ -4,7 +4,7 @@ from langchain_core.runnables.base import RunnableBindingBase
|
||||
from langchain_core.runnables.utils import Input, Output
|
||||
|
||||
|
||||
class HubRunnable(RunnableBindingBase[Input, Output]):
|
||||
class HubRunnable(RunnableBindingBase[Input, Output]): # type: ignore[no-redef]
|
||||
"""
|
||||
An instance of a runnable stored in the LangChain Hub.
|
||||
"""
|
||||
|
@ -20,7 +20,7 @@ class OpenAIFunction(TypedDict):
|
||||
"""The parameters to the function."""
|
||||
|
||||
|
||||
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):
|
||||
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]): # type: ignore[no-redef]
|
||||
"""A runnable that routes to the selected function."""
|
||||
|
||||
functions: Optional[list[OpenAIFunction]]
|
||||
|
@ -124,6 +124,7 @@ target-version = "py39"
|
||||
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ["pydantic.mypy"]
|
||||
strict = "True"
|
||||
strict_bytes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
@ -15,7 +15,7 @@ async def test_simplea() -> None:
|
||||
answer = "I know the answer!"
|
||||
llm = FakeListLLM(responses=[answer])
|
||||
retriever = SequentialRetriever(sequential_responses=[[]])
|
||||
memory = ConversationBufferMemory( # type: ignore[call-arg]
|
||||
memory = ConversationBufferMemory(
|
||||
k=1,
|
||||
output_key="answer",
|
||||
memory_key="chat_history",
|
||||
@ -42,7 +42,7 @@ async def test_fixed_message_response_when_docs_founda() -> None:
|
||||
retriever = SequentialRetriever(
|
||||
sequential_responses=[[Document(page_content=answer)]],
|
||||
)
|
||||
memory = ConversationBufferMemory( # type: ignore[call-arg]
|
||||
memory = ConversationBufferMemory(
|
||||
k=1,
|
||||
output_key="answer",
|
||||
memory_key="chat_history",
|
||||
@ -67,7 +67,7 @@ def test_fixed_message_response_when_no_docs_found() -> None:
|
||||
answer = "I know the answer!"
|
||||
llm = FakeListLLM(responses=[answer])
|
||||
retriever = SequentialRetriever(sequential_responses=[[]])
|
||||
memory = ConversationBufferMemory( # type: ignore[call-arg]
|
||||
memory = ConversationBufferMemory(
|
||||
k=1,
|
||||
output_key="answer",
|
||||
memory_key="chat_history",
|
||||
@ -94,7 +94,7 @@ def test_fixed_message_response_when_docs_found() -> None:
|
||||
retriever = SequentialRetriever(
|
||||
sequential_responses=[[Document(page_content=answer)]],
|
||||
)
|
||||
memory = ConversationBufferMemory( # type: ignore[call-arg]
|
||||
memory = ConversationBufferMemory(
|
||||
k=1,
|
||||
output_key="answer",
|
||||
memory_key="chat_history",
|
||||
|
@ -16,7 +16,7 @@ def dummy_transform(inputs: dict[str, str]) -> dict[str, str]:
|
||||
|
||||
def test_transform_chain() -> None:
|
||||
"""Test basic transform chain."""
|
||||
transform_chain = TransformChain( # type: ignore[call-arg]
|
||||
transform_chain = TransformChain(
|
||||
input_variables=["first_name", "last_name"],
|
||||
output_variables=["greeting"],
|
||||
transform=dummy_transform,
|
||||
@ -29,7 +29,7 @@ def test_transform_chain() -> None:
|
||||
|
||||
def test_transform_chain_bad_inputs() -> None:
|
||||
"""Test basic transform chain."""
|
||||
transform_chain = TransformChain( # type: ignore[call-arg]
|
||||
transform_chain = TransformChain(
|
||||
input_variables=["first_name", "last_name"],
|
||||
output_variables=["greeting"],
|
||||
transform=dummy_transform,
|
||||
|
@ -116,9 +116,9 @@ class TestClass(Serializable):
|
||||
|
||||
def test_aliases_hidden() -> None:
|
||||
test_class = TestClass(
|
||||
my_favorite_secret="hello", # noqa: S106 # type: ignore[call-arg]
|
||||
my_favorite_secret="hello", # noqa: S106
|
||||
my_other_secret="world", # noqa: S106
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
dumped = json.loads(dumps(test_class, pretty=True))
|
||||
expected_dump = {
|
||||
"lc": 1,
|
||||
@ -143,7 +143,7 @@ def test_aliases_hidden() -> None:
|
||||
dumped = json.loads(dumps(test_class, pretty=True))
|
||||
|
||||
# Check by alias
|
||||
test_class = TestClass(
|
||||
test_class = TestClass( # type: ignore[call-arg]
|
||||
my_favorite_secret_alias="hello", # noqa: S106
|
||||
my_other_secret="parrot party", # noqa: S106
|
||||
)
|
||||
|
@ -162,7 +162,7 @@ def test_load_llmchain_with_non_serializable_arg() -> None:
|
||||
import httpx
|
||||
from langchain_openai import OpenAI
|
||||
|
||||
llm = OpenAI( # type: ignore[call-arg]
|
||||
llm = OpenAI(
|
||||
model="davinci",
|
||||
temperature=0.5,
|
||||
openai_api_key="hello",
|
||||
|
@ -37,7 +37,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
||||
def test_retry_output_parser_parse_with_prompt() -> None:
|
||||
n: int = 5 # Success on the (n+1)-th attempt
|
||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||
parser = RetryOutputParser(
|
||||
parser = RetryOutputParser[SuccessfulParseAfterRetries](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||
@ -51,7 +51,7 @@ def test_retry_output_parser_parse_with_prompt() -> None:
|
||||
def test_retry_output_parser_parse_with_prompt_fail() -> None:
|
||||
n: int = 5 # Success on the (n+1)-th attempt
|
||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||
parser = RetryOutputParser(
|
||||
parser = RetryOutputParser[SuccessfulParseAfterRetries](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
||||
@ -65,7 +65,7 @@ def test_retry_output_parser_parse_with_prompt_fail() -> None:
|
||||
async def test_retry_output_parser_aparse_with_prompt() -> None:
|
||||
n: int = 5 # Success on the (n+1)-th attempt
|
||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||
parser = RetryOutputParser(
|
||||
parser = RetryOutputParser[SuccessfulParseAfterRetries](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||
@ -82,7 +82,7 @@ async def test_retry_output_parser_aparse_with_prompt() -> None:
|
||||
async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
|
||||
n: int = 5 # Success on the (n+1)-th attempt
|
||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||
parser = RetryOutputParser(
|
||||
parser = RetryOutputParser[SuccessfulParseAfterRetries](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
||||
@ -101,7 +101,7 @@ async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
|
||||
],
|
||||
)
|
||||
def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
|
||||
parser = RetryOutputParser(
|
||||
parser = RetryOutputParser[BaseOutputParser](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
legacy=False,
|
||||
@ -110,7 +110,7 @@ def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
|
||||
|
||||
|
||||
def test_retry_output_parser_parse_is_not_implemented() -> None:
|
||||
parser = RetryOutputParser(
|
||||
parser = RetryOutputParser[BooleanOutputParser](
|
||||
parser=BooleanOutputParser(),
|
||||
retry_chain=RunnablePassthrough(),
|
||||
legacy=False,
|
||||
@ -122,7 +122,7 @@ def test_retry_output_parser_parse_is_not_implemented() -> None:
|
||||
def test_retry_with_error_output_parser_parse_with_prompt() -> None:
|
||||
n: int = 5 # Success on the (n+1)-th attempt
|
||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||
@ -136,7 +136,7 @@ def test_retry_with_error_output_parser_parse_with_prompt() -> None:
|
||||
def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
|
||||
n: int = 5 # Success on the (n+1)-th attempt
|
||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
||||
@ -150,7 +150,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
|
||||
async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
|
||||
n: int = 5 # Success on the (n+1)-th attempt
|
||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||
@ -167,7 +167,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
|
||||
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
|
||||
n: int = 5 # Success on the (n+1)-th attempt
|
||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
||||
@ -188,7 +188,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
|
||||
def test_retry_with_error_output_parser_output_type(
|
||||
base_parser: BaseOutputParser,
|
||||
) -> None:
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser = RetryWithErrorOutputParser[BaseOutputParser](
|
||||
parser=base_parser,
|
||||
retry_chain=RunnablePassthrough(),
|
||||
legacy=False,
|
||||
@ -197,7 +197,7 @@ def test_retry_with_error_output_parser_output_type(
|
||||
|
||||
|
||||
def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser = RetryWithErrorOutputParser[BooleanOutputParser](
|
||||
parser=BooleanOutputParser(),
|
||||
retry_chain=RunnablePassthrough(),
|
||||
legacy=False,
|
||||
@ -222,11 +222,11 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
|
||||
def test_retry_output_parser_parse_with_prompt_with_retry_chain(
|
||||
completion: str,
|
||||
prompt: PromptValue,
|
||||
base_parser: BaseOutputParser[T],
|
||||
base_parser: DatetimeOutputParser,
|
||||
retry_chain: Runnable[dict[str, Any], str],
|
||||
expected: T,
|
||||
expected: dt,
|
||||
) -> None:
|
||||
parser = RetryOutputParser(
|
||||
parser = RetryOutputParser[DatetimeOutputParser](
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
@ -250,12 +250,12 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(
|
||||
async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
|
||||
completion: str,
|
||||
prompt: PromptValue,
|
||||
base_parser: BaseOutputParser[T],
|
||||
base_parser: DatetimeOutputParser,
|
||||
retry_chain: Runnable[dict[str, Any], str],
|
||||
expected: T,
|
||||
expected: dt,
|
||||
) -> None:
|
||||
# test
|
||||
parser = RetryOutputParser(
|
||||
parser = RetryOutputParser[DatetimeOutputParser](
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
@ -279,12 +279,12 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
|
||||
def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
|
||||
completion: str,
|
||||
prompt: PromptValue,
|
||||
base_parser: BaseOutputParser[T],
|
||||
base_parser: DatetimeOutputParser,
|
||||
retry_chain: Runnable[dict[str, Any], str],
|
||||
expected: T,
|
||||
expected: dt,
|
||||
) -> None:
|
||||
# test
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser = RetryWithErrorOutputParser[DatetimeOutputParser](
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
@ -308,11 +308,11 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
|
||||
async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
|
||||
completion: str,
|
||||
prompt: PromptValue,
|
||||
base_parser: BaseOutputParser[T],
|
||||
base_parser: DatetimeOutputParser,
|
||||
retry_chain: Runnable[dict[str, Any], str],
|
||||
expected: T,
|
||||
expected: dt,
|
||||
) -> None:
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser = RetryWithErrorOutputParser[DatetimeOutputParser](
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
|
@ -95,5 +95,5 @@ def test_yaml_output_parser_fail() -> None:
|
||||
|
||||
def test_yaml_output_parser_output_type() -> None:
|
||||
"""Test YamlOutputParser OutputType."""
|
||||
yaml_parser = YamlOutputParser(pydantic_object=TestModel)
|
||||
yaml_parser = YamlOutputParser[TestModel](pydantic_object=TestModel)
|
||||
assert yaml_parser.OutputType is TestModel
|
||||
|
@ -8,6 +8,6 @@ def test__list_rerank_init() -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
LLMListwiseRerank.from_llm(
|
||||
llm=ChatOpenAI(api_key="foo"), # type: ignore[arg-type]
|
||||
llm=ChatOpenAI(api_key="foo"),
|
||||
top_n=10,
|
||||
)
|
||||
|
@ -43,7 +43,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
|
||||
def test_multi_vector_retriever_initialization() -> None:
|
||||
vectorstore = InMemoryVectorstoreWithSearch()
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
@ -58,7 +58,7 @@ def test_multi_vector_retriever_initialization() -> None:
|
||||
|
||||
async def test_multi_vector_retriever_initialization_async() -> None:
|
||||
vectorstore = InMemoryVectorstoreWithSearch()
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
@ -77,7 +77,7 @@ def test_multi_vector_retriever_similarity_search_with_score() -> None:
|
||||
vectorstore.add_documents(documents, ids=["1"])
|
||||
|
||||
# score_threshold = 0.5
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
@ -90,7 +90,7 @@ def test_multi_vector_retriever_similarity_search_with_score() -> None:
|
||||
assert results[0].page_content == "test document"
|
||||
|
||||
# score_threshold = 0.9
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
@ -108,7 +108,7 @@ async def test_multi_vector_retriever_similarity_search_with_score_async() -> No
|
||||
await vectorstore.aadd_documents(documents, ids=["1"])
|
||||
|
||||
# score_threshold = 0.5
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
@ -121,7 +121,7 @@ async def test_multi_vector_retriever_similarity_search_with_score_async() -> No
|
||||
assert results[0].page_content == "test document"
|
||||
|
||||
# score_threshold = 0.9
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
|
@ -171,7 +171,7 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
|
||||
assert "the right input" in inputs
|
||||
return {"output": "2"}
|
||||
|
||||
mock_chain = TransformChain( # type: ignore[call-arg]
|
||||
mock_chain = TransformChain(
|
||||
input_variables=["the right input"],
|
||||
output_variables=["output"],
|
||||
transform=run_val,
|
||||
|
Loading…
Reference in New Issue
Block a user