mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-23 11:32:10 +00:00
chore(langchain): add mypy pydantic plugin (#32610)
This commit is contained in:
parent
73a7de63aa
commit
f896bcdb1d
@ -277,7 +277,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
|||||||
extra_prompt_messages=extra_prompt_messages,
|
extra_prompt_messages=extra_prompt_messages,
|
||||||
system_message=system_message_,
|
system_message=system_message_,
|
||||||
)
|
)
|
||||||
return cls( # type: ignore[call-arg]
|
return cls(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
@ -328,7 +328,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
|||||||
extra_prompt_messages=extra_prompt_messages,
|
extra_prompt_messages=extra_prompt_messages,
|
||||||
system_message=system_message_,
|
system_message=system_message_,
|
||||||
)
|
)
|
||||||
return cls( # type: ignore[call-arg]
|
return cls(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
@ -56,7 +56,7 @@ def _load_question_to_checked_assertions_chain(
|
|||||||
revised_answer_chain,
|
revised_answer_chain,
|
||||||
]
|
]
|
||||||
return SequentialChain(
|
return SequentialChain(
|
||||||
chains=chains, # type: ignore[arg-type]
|
chains=chains,
|
||||||
input_variables=["question"],
|
input_variables=["question"],
|
||||||
output_variables=["revised_statement"],
|
output_variables=["revised_statement"],
|
||||||
verbose=True,
|
verbose=True,
|
||||||
|
@ -169,7 +169,7 @@ def _load_map_reduce_documents_chain(
|
|||||||
|
|
||||||
return MapReduceDocumentsChain(
|
return MapReduceDocumentsChain(
|
||||||
llm_chain=llm_chain,
|
llm_chain=llm_chain,
|
||||||
reduce_documents_chain=reduce_documents_chain, # type: ignore[arg-type]
|
reduce_documents_chain=reduce_documents_chain,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -293,10 +293,10 @@ def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
|
|||||||
revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path"))
|
revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path"))
|
||||||
return LLMCheckerChain(
|
return LLMCheckerChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
create_draft_answer_prompt=create_draft_answer_prompt, # type: ignore[arg-type]
|
create_draft_answer_prompt=create_draft_answer_prompt,
|
||||||
list_assertions_prompt=list_assertions_prompt, # type: ignore[arg-type]
|
list_assertions_prompt=list_assertions_prompt,
|
||||||
check_assertions_prompt=check_assertions_prompt, # type: ignore[arg-type]
|
check_assertions_prompt=check_assertions_prompt,
|
||||||
revised_answer_prompt=revised_answer_prompt, # type: ignore[arg-type]
|
revised_answer_prompt=revised_answer_prompt,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -325,7 +325,7 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
|
|||||||
elif "prompt_path" in config:
|
elif "prompt_path" in config:
|
||||||
prompt = load_prompt(config.pop("prompt_path"))
|
prompt = load_prompt(config.pop("prompt_path"))
|
||||||
if llm_chain:
|
if llm_chain:
|
||||||
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
|
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config)
|
||||||
return LLMMathChain(llm=llm, prompt=prompt, **config)
|
return LLMMathChain(llm=llm, prompt=prompt, **config)
|
||||||
|
|
||||||
|
|
||||||
@ -341,7 +341,7 @@ def _load_map_rerank_documents_chain(
|
|||||||
else:
|
else:
|
||||||
msg = "One of `llm_chain` or `llm_chain_path` must be present."
|
msg = "One of `llm_chain` or `llm_chain_path` must be present."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return MapRerankDocumentsChain(llm_chain=llm_chain, **config) # type: ignore[arg-type]
|
return MapRerankDocumentsChain(llm_chain=llm_chain, **config)
|
||||||
|
|
||||||
|
|
||||||
def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
|
def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
|
||||||
@ -377,8 +377,8 @@ def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocuments
|
|||||||
elif "document_prompt_path" in config:
|
elif "document_prompt_path" in config:
|
||||||
document_prompt = load_prompt(config.pop("document_prompt_path"))
|
document_prompt = load_prompt(config.pop("document_prompt_path"))
|
||||||
return RefineDocumentsChain(
|
return RefineDocumentsChain(
|
||||||
initial_llm_chain=initial_llm_chain, # type: ignore[arg-type]
|
initial_llm_chain=initial_llm_chain,
|
||||||
refine_llm_chain=refine_llm_chain, # type: ignore[arg-type]
|
refine_llm_chain=refine_llm_chain,
|
||||||
document_prompt=document_prompt,
|
document_prompt=document_prompt,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
@ -402,7 +402,7 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha
|
|||||||
"`combine_documents_chain_path` must be present."
|
"`combine_documents_chain_path` must be present."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config) # type: ignore[arg-type]
|
return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config)
|
||||||
|
|
||||||
|
|
||||||
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
|
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
|
||||||
@ -445,7 +445,7 @@ def _load_vector_db_qa_with_sources_chain(
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return VectorDBQAWithSourcesChain(
|
return VectorDBQAWithSourcesChain(
|
||||||
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
|
combine_documents_chain=combine_documents_chain,
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
@ -475,7 +475,7 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return RetrievalQA(
|
return RetrievalQA(
|
||||||
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
|
combine_documents_chain=combine_documents_chain,
|
||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
@ -508,7 +508,7 @@ def _load_retrieval_qa_with_sources_chain(
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return RetrievalQAWithSourcesChain(
|
return RetrievalQAWithSourcesChain(
|
||||||
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
|
combine_documents_chain=combine_documents_chain,
|
||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
@ -538,7 +538,7 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return VectorDBQA(
|
return VectorDBQA(
|
||||||
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
|
combine_documents_chain=combine_documents_chain,
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
@ -606,8 +606,8 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
|
|||||||
msg = "`requests_wrapper` must be present."
|
msg = "`requests_wrapper` must be present."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return APIChain(
|
return APIChain(
|
||||||
api_request_chain=api_request_chain, # type: ignore[arg-type]
|
api_request_chain=api_request_chain,
|
||||||
api_answer_chain=api_answer_chain, # type: ignore[arg-type]
|
api_answer_chain=api_answer_chain,
|
||||||
requests_wrapper=requests_wrapper,
|
requests_wrapper=requests_wrapper,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
|
@ -66,12 +66,12 @@ def _load_stuff_chain(
|
|||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> StuffDocumentsChain:
|
) -> StuffDocumentsChain:
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # type: ignore[arg-type]
|
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||||
return StuffDocumentsChain(
|
return StuffDocumentsChain(
|
||||||
llm_chain=llm_chain,
|
llm_chain=llm_chain,
|
||||||
document_variable_name=document_variable_name,
|
document_variable_name=document_variable_name,
|
||||||
document_prompt=document_prompt,
|
document_prompt=document_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -91,14 +91,14 @@ def _load_map_reduce_chain(
|
|||||||
token_max: int = 3000,
|
token_max: int = 3000,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MapReduceDocumentsChain:
|
) -> MapReduceDocumentsChain:
|
||||||
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
|
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||||
_reduce_llm = reduce_llm or llm
|
_reduce_llm = reduce_llm or llm
|
||||||
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) # type: ignore[arg-type]
|
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
|
||||||
combine_documents_chain = StuffDocumentsChain(
|
combine_documents_chain = StuffDocumentsChain(
|
||||||
llm_chain=reduce_chain,
|
llm_chain=reduce_chain,
|
||||||
document_variable_name=combine_document_variable_name,
|
document_variable_name=combine_document_variable_name,
|
||||||
document_prompt=document_prompt,
|
document_prompt=document_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
if collapse_prompt is None:
|
if collapse_prompt is None:
|
||||||
collapse_chain = None
|
collapse_chain = None
|
||||||
@ -114,7 +114,7 @@ def _load_map_reduce_chain(
|
|||||||
llm_chain=LLMChain(
|
llm_chain=LLMChain(
|
||||||
llm=_collapse_llm,
|
llm=_collapse_llm,
|
||||||
prompt=collapse_prompt,
|
prompt=collapse_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
),
|
),
|
||||||
document_variable_name=combine_document_variable_name,
|
document_variable_name=combine_document_variable_name,
|
||||||
document_prompt=document_prompt,
|
document_prompt=document_prompt,
|
||||||
@ -123,13 +123,13 @@ def _load_map_reduce_chain(
|
|||||||
combine_documents_chain=combine_documents_chain,
|
combine_documents_chain=combine_documents_chain,
|
||||||
collapse_documents_chain=collapse_chain,
|
collapse_documents_chain=collapse_chain,
|
||||||
token_max=token_max,
|
token_max=token_max,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
return MapReduceDocumentsChain(
|
return MapReduceDocumentsChain(
|
||||||
llm_chain=map_chain,
|
llm_chain=map_chain,
|
||||||
reduce_documents_chain=reduce_documents_chain,
|
reduce_documents_chain=reduce_documents_chain,
|
||||||
document_variable_name=map_reduce_document_variable_name,
|
document_variable_name=map_reduce_document_variable_name,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -146,16 +146,16 @@ def _load_refine_chain(
|
|||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> RefineDocumentsChain:
|
) -> RefineDocumentsChain:
|
||||||
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
|
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||||
_refine_llm = refine_llm or llm
|
_refine_llm = refine_llm or llm
|
||||||
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) # type: ignore[arg-type]
|
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
|
||||||
return RefineDocumentsChain(
|
return RefineDocumentsChain(
|
||||||
initial_llm_chain=initial_chain,
|
initial_llm_chain=initial_chain,
|
||||||
refine_llm_chain=refine_chain,
|
refine_llm_chain=refine_chain,
|
||||||
document_variable_name=document_variable_name,
|
document_variable_name=document_variable_name,
|
||||||
initial_response_name=initial_response_name,
|
initial_response_name=initial_response_name,
|
||||||
document_prompt=document_prompt,
|
document_prompt=document_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ def _load_stuff_chain(
|
|||||||
llm_chain = LLMChain(
|
llm_chain = LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=_prompt,
|
prompt=_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
@ -88,7 +88,7 @@ def _load_stuff_chain(
|
|||||||
return StuffDocumentsChain(
|
return StuffDocumentsChain(
|
||||||
llm_chain=llm_chain,
|
llm_chain=llm_chain,
|
||||||
document_variable_name=document_variable_name,
|
document_variable_name=document_variable_name,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -120,7 +120,7 @@ def _load_map_reduce_chain(
|
|||||||
map_chain = LLMChain(
|
map_chain = LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=_question_prompt,
|
prompt=_question_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
@ -128,7 +128,7 @@ def _load_map_reduce_chain(
|
|||||||
reduce_chain = LLMChain(
|
reduce_chain = LLMChain(
|
||||||
llm=_reduce_llm,
|
llm=_reduce_llm,
|
||||||
prompt=_combine_prompt,
|
prompt=_combine_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
@ -136,7 +136,7 @@ def _load_map_reduce_chain(
|
|||||||
combine_documents_chain = StuffDocumentsChain(
|
combine_documents_chain = StuffDocumentsChain(
|
||||||
llm_chain=reduce_chain,
|
llm_chain=reduce_chain,
|
||||||
document_variable_name=combine_document_variable_name,
|
document_variable_name=combine_document_variable_name,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
@ -154,12 +154,12 @@ def _load_map_reduce_chain(
|
|||||||
llm_chain=LLMChain(
|
llm_chain=LLMChain(
|
||||||
llm=_collapse_llm,
|
llm=_collapse_llm,
|
||||||
prompt=collapse_prompt,
|
prompt=collapse_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
),
|
),
|
||||||
document_variable_name=combine_document_variable_name,
|
document_variable_name=combine_document_variable_name,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
)
|
)
|
||||||
reduce_documents_chain = ReduceDocumentsChain(
|
reduce_documents_chain = ReduceDocumentsChain(
|
||||||
@ -172,7 +172,7 @@ def _load_map_reduce_chain(
|
|||||||
llm_chain=map_chain,
|
llm_chain=map_chain,
|
||||||
document_variable_name=map_reduce_document_variable_name,
|
document_variable_name=map_reduce_document_variable_name,
|
||||||
reduce_documents_chain=reduce_documents_chain,
|
reduce_documents_chain=reduce_documents_chain,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -201,7 +201,7 @@ def _load_refine_chain(
|
|||||||
initial_chain = LLMChain(
|
initial_chain = LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=_question_prompt,
|
prompt=_question_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
@ -209,7 +209,7 @@ def _load_refine_chain(
|
|||||||
refine_chain = LLMChain(
|
refine_chain = LLMChain(
|
||||||
llm=_refine_llm,
|
llm=_refine_llm,
|
||||||
prompt=_refine_prompt,
|
prompt=_refine_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
@ -218,7 +218,7 @@ def _load_refine_chain(
|
|||||||
refine_llm_chain=refine_chain,
|
refine_llm_chain=refine_chain,
|
||||||
document_variable_name=document_variable_name,
|
document_variable_name=document_variable_name,
|
||||||
initial_response_name=initial_response_name,
|
initial_response_name=initial_response_name,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -35,12 +35,12 @@ def _load_stuff_chain(
|
|||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> StuffDocumentsChain:
|
) -> StuffDocumentsChain:
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # type: ignore[arg-type]
|
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||||
# TODO: document prompt
|
# TODO: document prompt
|
||||||
return StuffDocumentsChain(
|
return StuffDocumentsChain(
|
||||||
llm_chain=llm_chain,
|
llm_chain=llm_chain,
|
||||||
document_variable_name=document_variable_name,
|
document_variable_name=document_variable_name,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -64,21 +64,21 @@ def _load_map_reduce_chain(
|
|||||||
map_chain = LLMChain(
|
map_chain = LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=map_prompt,
|
prompt=map_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
_reduce_llm = reduce_llm or llm
|
_reduce_llm = reduce_llm or llm
|
||||||
reduce_chain = LLMChain(
|
reduce_chain = LLMChain(
|
||||||
llm=_reduce_llm,
|
llm=_reduce_llm,
|
||||||
prompt=combine_prompt,
|
prompt=combine_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
# TODO: document prompt
|
# TODO: document prompt
|
||||||
combine_documents_chain = StuffDocumentsChain(
|
combine_documents_chain = StuffDocumentsChain(
|
||||||
llm_chain=reduce_chain,
|
llm_chain=reduce_chain,
|
||||||
document_variable_name=combine_document_variable_name,
|
document_variable_name=combine_document_variable_name,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
if collapse_prompt is None:
|
if collapse_prompt is None:
|
||||||
@ -95,7 +95,7 @@ def _load_map_reduce_chain(
|
|||||||
llm_chain=LLMChain(
|
llm_chain=LLMChain(
|
||||||
llm=_collapse_llm,
|
llm=_collapse_llm,
|
||||||
prompt=collapse_prompt,
|
prompt=collapse_prompt,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
),
|
),
|
||||||
document_variable_name=combine_document_variable_name,
|
document_variable_name=combine_document_variable_name,
|
||||||
@ -104,7 +104,7 @@ def _load_map_reduce_chain(
|
|||||||
combine_documents_chain=combine_documents_chain,
|
combine_documents_chain=combine_documents_chain,
|
||||||
collapse_documents_chain=collapse_chain,
|
collapse_documents_chain=collapse_chain,
|
||||||
token_max=token_max,
|
token_max=token_max,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
collapse_max_retries=collapse_max_retries,
|
collapse_max_retries=collapse_max_retries,
|
||||||
)
|
)
|
||||||
@ -112,7 +112,7 @@ def _load_map_reduce_chain(
|
|||||||
llm_chain=map_chain,
|
llm_chain=map_chain,
|
||||||
reduce_documents_chain=reduce_documents_chain,
|
reduce_documents_chain=reduce_documents_chain,
|
||||||
document_variable_name=map_reduce_document_variable_name,
|
document_variable_name=map_reduce_document_variable_name,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@ -129,15 +129,15 @@ def _load_refine_chain(
|
|||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> RefineDocumentsChain:
|
) -> RefineDocumentsChain:
|
||||||
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
|
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||||
_refine_llm = refine_llm or llm
|
_refine_llm = refine_llm or llm
|
||||||
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) # type: ignore[arg-type]
|
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
|
||||||
return RefineDocumentsChain(
|
return RefineDocumentsChain(
|
||||||
initial_llm_chain=initial_chain,
|
initial_llm_chain=initial_chain,
|
||||||
refine_llm_chain=refine_chain,
|
refine_llm_chain=refine_chain,
|
||||||
document_variable_name=document_variable_name,
|
document_variable_name=document_variable_name,
|
||||||
initial_response_name=initial_response_name,
|
initial_response_name=initial_response_name,
|
||||||
verbose=verbose, # type: ignore[arg-type]
|
verbose=verbose,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -250,7 +250,7 @@ The following is the expected answer. Use this to measure correctness:
|
|||||||
prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
|
prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
|
||||||
eval_chain = LLMChain(llm=llm, prompt=prompt)
|
eval_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
return cls(
|
return cls(
|
||||||
agent_tools=agent_tools, # type: ignore[arg-type]
|
agent_tools=agent_tools,
|
||||||
eval_chain=eval_chain,
|
eval_chain=eval_chain,
|
||||||
output_parser=output_parser or TrajectoryOutputParser(),
|
output_parser=output_parser or TrajectoryOutputParser(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -120,4 +120,4 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
|||||||
else:
|
else:
|
||||||
parser = StrOutputParser()
|
parser = StrOutputParser()
|
||||||
llm_chain = _prompt | llm | parser
|
llm_chain = _prompt | llm | parser
|
||||||
return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type]
|
return cls(llm_chain=llm_chain, get_input=_get_input)
|
||||||
|
@ -407,7 +407,7 @@ class SelfQueryRetriever(BaseRetriever):
|
|||||||
query_constructor = query_constructor.with_config(
|
query_constructor = query_constructor.with_config(
|
||||||
run_name=QUERY_CONSTRUCTOR_RUN_NAME,
|
run_name=QUERY_CONSTRUCTOR_RUN_NAME,
|
||||||
)
|
)
|
||||||
return cls( # type: ignore[call-arg]
|
return cls(
|
||||||
query_constructor=query_constructor,
|
query_constructor=query_constructor,
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
use_original_query=use_original_query,
|
use_original_query=use_original_query,
|
||||||
|
@ -4,7 +4,7 @@ from langchain_core.runnables.base import RunnableBindingBase
|
|||||||
from langchain_core.runnables.utils import Input, Output
|
from langchain_core.runnables.utils import Input, Output
|
||||||
|
|
||||||
|
|
||||||
class HubRunnable(RunnableBindingBase[Input, Output]):
|
class HubRunnable(RunnableBindingBase[Input, Output]): # type: ignore[no-redef]
|
||||||
"""
|
"""
|
||||||
An instance of a runnable stored in the LangChain Hub.
|
An instance of a runnable stored in the LangChain Hub.
|
||||||
"""
|
"""
|
||||||
|
@ -20,7 +20,7 @@ class OpenAIFunction(TypedDict):
|
|||||||
"""The parameters to the function."""
|
"""The parameters to the function."""
|
||||||
|
|
||||||
|
|
||||||
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):
|
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]): # type: ignore[no-redef]
|
||||||
"""A runnable that routes to the selected function."""
|
"""A runnable that routes to the selected function."""
|
||||||
|
|
||||||
functions: Optional[list[OpenAIFunction]]
|
functions: Optional[list[OpenAIFunction]]
|
||||||
|
@ -124,6 +124,7 @@ target-version = "py39"
|
|||||||
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
|
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
|
plugins = ["pydantic.mypy"]
|
||||||
strict = "True"
|
strict = "True"
|
||||||
strict_bytes = "True"
|
strict_bytes = "True"
|
||||||
ignore_missing_imports = "True"
|
ignore_missing_imports = "True"
|
||||||
|
@ -15,7 +15,7 @@ async def test_simplea() -> None:
|
|||||||
answer = "I know the answer!"
|
answer = "I know the answer!"
|
||||||
llm = FakeListLLM(responses=[answer])
|
llm = FakeListLLM(responses=[answer])
|
||||||
retriever = SequentialRetriever(sequential_responses=[[]])
|
retriever = SequentialRetriever(sequential_responses=[[]])
|
||||||
memory = ConversationBufferMemory( # type: ignore[call-arg]
|
memory = ConversationBufferMemory(
|
||||||
k=1,
|
k=1,
|
||||||
output_key="answer",
|
output_key="answer",
|
||||||
memory_key="chat_history",
|
memory_key="chat_history",
|
||||||
@ -42,7 +42,7 @@ async def test_fixed_message_response_when_docs_founda() -> None:
|
|||||||
retriever = SequentialRetriever(
|
retriever = SequentialRetriever(
|
||||||
sequential_responses=[[Document(page_content=answer)]],
|
sequential_responses=[[Document(page_content=answer)]],
|
||||||
)
|
)
|
||||||
memory = ConversationBufferMemory( # type: ignore[call-arg]
|
memory = ConversationBufferMemory(
|
||||||
k=1,
|
k=1,
|
||||||
output_key="answer",
|
output_key="answer",
|
||||||
memory_key="chat_history",
|
memory_key="chat_history",
|
||||||
@ -67,7 +67,7 @@ def test_fixed_message_response_when_no_docs_found() -> None:
|
|||||||
answer = "I know the answer!"
|
answer = "I know the answer!"
|
||||||
llm = FakeListLLM(responses=[answer])
|
llm = FakeListLLM(responses=[answer])
|
||||||
retriever = SequentialRetriever(sequential_responses=[[]])
|
retriever = SequentialRetriever(sequential_responses=[[]])
|
||||||
memory = ConversationBufferMemory( # type: ignore[call-arg]
|
memory = ConversationBufferMemory(
|
||||||
k=1,
|
k=1,
|
||||||
output_key="answer",
|
output_key="answer",
|
||||||
memory_key="chat_history",
|
memory_key="chat_history",
|
||||||
@ -94,7 +94,7 @@ def test_fixed_message_response_when_docs_found() -> None:
|
|||||||
retriever = SequentialRetriever(
|
retriever = SequentialRetriever(
|
||||||
sequential_responses=[[Document(page_content=answer)]],
|
sequential_responses=[[Document(page_content=answer)]],
|
||||||
)
|
)
|
||||||
memory = ConversationBufferMemory( # type: ignore[call-arg]
|
memory = ConversationBufferMemory(
|
||||||
k=1,
|
k=1,
|
||||||
output_key="answer",
|
output_key="answer",
|
||||||
memory_key="chat_history",
|
memory_key="chat_history",
|
||||||
|
@ -16,7 +16,7 @@ def dummy_transform(inputs: dict[str, str]) -> dict[str, str]:
|
|||||||
|
|
||||||
def test_transform_chain() -> None:
|
def test_transform_chain() -> None:
|
||||||
"""Test basic transform chain."""
|
"""Test basic transform chain."""
|
||||||
transform_chain = TransformChain( # type: ignore[call-arg]
|
transform_chain = TransformChain(
|
||||||
input_variables=["first_name", "last_name"],
|
input_variables=["first_name", "last_name"],
|
||||||
output_variables=["greeting"],
|
output_variables=["greeting"],
|
||||||
transform=dummy_transform,
|
transform=dummy_transform,
|
||||||
@ -29,7 +29,7 @@ def test_transform_chain() -> None:
|
|||||||
|
|
||||||
def test_transform_chain_bad_inputs() -> None:
|
def test_transform_chain_bad_inputs() -> None:
|
||||||
"""Test basic transform chain."""
|
"""Test basic transform chain."""
|
||||||
transform_chain = TransformChain( # type: ignore[call-arg]
|
transform_chain = TransformChain(
|
||||||
input_variables=["first_name", "last_name"],
|
input_variables=["first_name", "last_name"],
|
||||||
output_variables=["greeting"],
|
output_variables=["greeting"],
|
||||||
transform=dummy_transform,
|
transform=dummy_transform,
|
||||||
|
@ -116,9 +116,9 @@ class TestClass(Serializable):
|
|||||||
|
|
||||||
def test_aliases_hidden() -> None:
|
def test_aliases_hidden() -> None:
|
||||||
test_class = TestClass(
|
test_class = TestClass(
|
||||||
my_favorite_secret="hello", # noqa: S106 # type: ignore[call-arg]
|
my_favorite_secret="hello", # noqa: S106
|
||||||
my_other_secret="world", # noqa: S106
|
my_other_secret="world", # noqa: S106
|
||||||
) # type: ignore[call-arg]
|
)
|
||||||
dumped = json.loads(dumps(test_class, pretty=True))
|
dumped = json.loads(dumps(test_class, pretty=True))
|
||||||
expected_dump = {
|
expected_dump = {
|
||||||
"lc": 1,
|
"lc": 1,
|
||||||
@ -143,7 +143,7 @@ def test_aliases_hidden() -> None:
|
|||||||
dumped = json.loads(dumps(test_class, pretty=True))
|
dumped = json.loads(dumps(test_class, pretty=True))
|
||||||
|
|
||||||
# Check by alias
|
# Check by alias
|
||||||
test_class = TestClass(
|
test_class = TestClass( # type: ignore[call-arg]
|
||||||
my_favorite_secret_alias="hello", # noqa: S106
|
my_favorite_secret_alias="hello", # noqa: S106
|
||||||
my_other_secret="parrot party", # noqa: S106
|
my_other_secret="parrot party", # noqa: S106
|
||||||
)
|
)
|
||||||
|
@ -162,7 +162,7 @@ def test_load_llmchain_with_non_serializable_arg() -> None:
|
|||||||
import httpx
|
import httpx
|
||||||
from langchain_openai import OpenAI
|
from langchain_openai import OpenAI
|
||||||
|
|
||||||
llm = OpenAI( # type: ignore[call-arg]
|
llm = OpenAI(
|
||||||
model="davinci",
|
model="davinci",
|
||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
openai_api_key="hello",
|
openai_api_key="hello",
|
||||||
|
@ -37,7 +37,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
|||||||
def test_retry_output_parser_parse_with_prompt() -> None:
|
def test_retry_output_parser_parse_with_prompt() -> None:
|
||||||
n: int = 5 # Success on the (n+1)-th attempt
|
n: int = 5 # Success on the (n+1)-th attempt
|
||||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||||
parser = RetryOutputParser(
|
parser = RetryOutputParser[SuccessfulParseAfterRetries](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||||
@ -51,7 +51,7 @@ def test_retry_output_parser_parse_with_prompt() -> None:
|
|||||||
def test_retry_output_parser_parse_with_prompt_fail() -> None:
|
def test_retry_output_parser_parse_with_prompt_fail() -> None:
|
||||||
n: int = 5 # Success on the (n+1)-th attempt
|
n: int = 5 # Success on the (n+1)-th attempt
|
||||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||||
parser = RetryOutputParser(
|
parser = RetryOutputParser[SuccessfulParseAfterRetries](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
||||||
@ -65,7 +65,7 @@ def test_retry_output_parser_parse_with_prompt_fail() -> None:
|
|||||||
async def test_retry_output_parser_aparse_with_prompt() -> None:
|
async def test_retry_output_parser_aparse_with_prompt() -> None:
|
||||||
n: int = 5 # Success on the (n+1)-th attempt
|
n: int = 5 # Success on the (n+1)-th attempt
|
||||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||||
parser = RetryOutputParser(
|
parser = RetryOutputParser[SuccessfulParseAfterRetries](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||||
@ -82,7 +82,7 @@ async def test_retry_output_parser_aparse_with_prompt() -> None:
|
|||||||
async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
|
async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
|
||||||
n: int = 5 # Success on the (n+1)-th attempt
|
n: int = 5 # Success on the (n+1)-th attempt
|
||||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||||
parser = RetryOutputParser(
|
parser = RetryOutputParser[SuccessfulParseAfterRetries](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
||||||
@ -101,7 +101,7 @@ async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
|
def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
|
||||||
parser = RetryOutputParser(
|
parser = RetryOutputParser[BaseOutputParser](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
legacy=False,
|
legacy=False,
|
||||||
@ -110,7 +110,7 @@ def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_retry_output_parser_parse_is_not_implemented() -> None:
|
def test_retry_output_parser_parse_is_not_implemented() -> None:
|
||||||
parser = RetryOutputParser(
|
parser = RetryOutputParser[BooleanOutputParser](
|
||||||
parser=BooleanOutputParser(),
|
parser=BooleanOutputParser(),
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
legacy=False,
|
legacy=False,
|
||||||
@ -122,7 +122,7 @@ def test_retry_output_parser_parse_is_not_implemented() -> None:
|
|||||||
def test_retry_with_error_output_parser_parse_with_prompt() -> None:
|
def test_retry_with_error_output_parser_parse_with_prompt() -> None:
|
||||||
n: int = 5 # Success on the (n+1)-th attempt
|
n: int = 5 # Success on the (n+1)-th attempt
|
||||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||||
parser = RetryWithErrorOutputParser(
|
parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||||
@ -136,7 +136,7 @@ def test_retry_with_error_output_parser_parse_with_prompt() -> None:
|
|||||||
def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
|
def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
|
||||||
n: int = 5 # Success on the (n+1)-th attempt
|
n: int = 5 # Success on the (n+1)-th attempt
|
||||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||||
parser = RetryWithErrorOutputParser(
|
parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
||||||
@ -150,7 +150,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
|
|||||||
async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
|
async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
|
||||||
n: int = 5 # Success on the (n+1)-th attempt
|
n: int = 5 # Success on the (n+1)-th attempt
|
||||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||||
parser = RetryWithErrorOutputParser(
|
parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||||
@ -167,7 +167,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
|
|||||||
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
|
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
|
||||||
n: int = 5 # Success on the (n+1)-th attempt
|
n: int = 5 # Success on the (n+1)-th attempt
|
||||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||||
parser = RetryWithErrorOutputParser(
|
parser = RetryWithErrorOutputParser[SuccessfulParseAfterRetries](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
max_retries=n - 1, # n-1 times to retry, that is, n times call
|
||||||
@ -188,7 +188,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
|
|||||||
def test_retry_with_error_output_parser_output_type(
|
def test_retry_with_error_output_parser_output_type(
|
||||||
base_parser: BaseOutputParser,
|
base_parser: BaseOutputParser,
|
||||||
) -> None:
|
) -> None:
|
||||||
parser = RetryWithErrorOutputParser(
|
parser = RetryWithErrorOutputParser[BaseOutputParser](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
legacy=False,
|
legacy=False,
|
||||||
@ -197,7 +197,7 @@ def test_retry_with_error_output_parser_output_type(
|
|||||||
|
|
||||||
|
|
||||||
def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
|
def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
|
||||||
parser = RetryWithErrorOutputParser(
|
parser = RetryWithErrorOutputParser[BooleanOutputParser](
|
||||||
parser=BooleanOutputParser(),
|
parser=BooleanOutputParser(),
|
||||||
retry_chain=RunnablePassthrough(),
|
retry_chain=RunnablePassthrough(),
|
||||||
legacy=False,
|
legacy=False,
|
||||||
@ -222,11 +222,11 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
|
|||||||
def test_retry_output_parser_parse_with_prompt_with_retry_chain(
|
def test_retry_output_parser_parse_with_prompt_with_retry_chain(
|
||||||
completion: str,
|
completion: str,
|
||||||
prompt: PromptValue,
|
prompt: PromptValue,
|
||||||
base_parser: BaseOutputParser[T],
|
base_parser: DatetimeOutputParser,
|
||||||
retry_chain: Runnable[dict[str, Any], str],
|
retry_chain: Runnable[dict[str, Any], str],
|
||||||
expected: T,
|
expected: dt,
|
||||||
) -> None:
|
) -> None:
|
||||||
parser = RetryOutputParser(
|
parser = RetryOutputParser[DatetimeOutputParser](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=retry_chain,
|
retry_chain=retry_chain,
|
||||||
legacy=False,
|
legacy=False,
|
||||||
@ -250,12 +250,12 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain(
|
|||||||
async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
|
async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
|
||||||
completion: str,
|
completion: str,
|
||||||
prompt: PromptValue,
|
prompt: PromptValue,
|
||||||
base_parser: BaseOutputParser[T],
|
base_parser: DatetimeOutputParser,
|
||||||
retry_chain: Runnable[dict[str, Any], str],
|
retry_chain: Runnable[dict[str, Any], str],
|
||||||
expected: T,
|
expected: dt,
|
||||||
) -> None:
|
) -> None:
|
||||||
# test
|
# test
|
||||||
parser = RetryOutputParser(
|
parser = RetryOutputParser[DatetimeOutputParser](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=retry_chain,
|
retry_chain=retry_chain,
|
||||||
legacy=False,
|
legacy=False,
|
||||||
@ -279,12 +279,12 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
|
|||||||
def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
|
def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
|
||||||
completion: str,
|
completion: str,
|
||||||
prompt: PromptValue,
|
prompt: PromptValue,
|
||||||
base_parser: BaseOutputParser[T],
|
base_parser: DatetimeOutputParser,
|
||||||
retry_chain: Runnable[dict[str, Any], str],
|
retry_chain: Runnable[dict[str, Any], str],
|
||||||
expected: T,
|
expected: dt,
|
||||||
) -> None:
|
) -> None:
|
||||||
# test
|
# test
|
||||||
parser = RetryWithErrorOutputParser(
|
parser = RetryWithErrorOutputParser[DatetimeOutputParser](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=retry_chain,
|
retry_chain=retry_chain,
|
||||||
legacy=False,
|
legacy=False,
|
||||||
@ -308,11 +308,11 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
|
|||||||
async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
|
async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
|
||||||
completion: str,
|
completion: str,
|
||||||
prompt: PromptValue,
|
prompt: PromptValue,
|
||||||
base_parser: BaseOutputParser[T],
|
base_parser: DatetimeOutputParser,
|
||||||
retry_chain: Runnable[dict[str, Any], str],
|
retry_chain: Runnable[dict[str, Any], str],
|
||||||
expected: T,
|
expected: dt,
|
||||||
) -> None:
|
) -> None:
|
||||||
parser = RetryWithErrorOutputParser(
|
parser = RetryWithErrorOutputParser[DatetimeOutputParser](
|
||||||
parser=base_parser,
|
parser=base_parser,
|
||||||
retry_chain=retry_chain,
|
retry_chain=retry_chain,
|
||||||
legacy=False,
|
legacy=False,
|
||||||
|
@ -95,5 +95,5 @@ def test_yaml_output_parser_fail() -> None:
|
|||||||
|
|
||||||
def test_yaml_output_parser_output_type() -> None:
|
def test_yaml_output_parser_output_type() -> None:
|
||||||
"""Test YamlOutputParser OutputType."""
|
"""Test YamlOutputParser OutputType."""
|
||||||
yaml_parser = YamlOutputParser(pydantic_object=TestModel)
|
yaml_parser = YamlOutputParser[TestModel](pydantic_object=TestModel)
|
||||||
assert yaml_parser.OutputType is TestModel
|
assert yaml_parser.OutputType is TestModel
|
||||||
|
@ -8,6 +8,6 @@ def test__list_rerank_init() -> None:
|
|||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
LLMListwiseRerank.from_llm(
|
LLMListwiseRerank.from_llm(
|
||||||
llm=ChatOpenAI(api_key="foo"), # type: ignore[arg-type]
|
llm=ChatOpenAI(api_key="foo"),
|
||||||
top_n=10,
|
top_n=10,
|
||||||
)
|
)
|
||||||
|
@ -43,7 +43,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
|||||||
|
|
||||||
def test_multi_vector_retriever_initialization() -> None:
|
def test_multi_vector_retriever_initialization() -> None:
|
||||||
vectorstore = InMemoryVectorstoreWithSearch()
|
vectorstore = InMemoryVectorstoreWithSearch()
|
||||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
retriever = MultiVectorRetriever(
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
docstore=InMemoryStore(),
|
docstore=InMemoryStore(),
|
||||||
doc_id="doc_id",
|
doc_id="doc_id",
|
||||||
@ -58,7 +58,7 @@ def test_multi_vector_retriever_initialization() -> None:
|
|||||||
|
|
||||||
async def test_multi_vector_retriever_initialization_async() -> None:
|
async def test_multi_vector_retriever_initialization_async() -> None:
|
||||||
vectorstore = InMemoryVectorstoreWithSearch()
|
vectorstore = InMemoryVectorstoreWithSearch()
|
||||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
retriever = MultiVectorRetriever(
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
docstore=InMemoryStore(),
|
docstore=InMemoryStore(),
|
||||||
doc_id="doc_id",
|
doc_id="doc_id",
|
||||||
@ -77,7 +77,7 @@ def test_multi_vector_retriever_similarity_search_with_score() -> None:
|
|||||||
vectorstore.add_documents(documents, ids=["1"])
|
vectorstore.add_documents(documents, ids=["1"])
|
||||||
|
|
||||||
# score_threshold = 0.5
|
# score_threshold = 0.5
|
||||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
retriever = MultiVectorRetriever(
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
docstore=InMemoryStore(),
|
docstore=InMemoryStore(),
|
||||||
doc_id="doc_id",
|
doc_id="doc_id",
|
||||||
@ -90,7 +90,7 @@ def test_multi_vector_retriever_similarity_search_with_score() -> None:
|
|||||||
assert results[0].page_content == "test document"
|
assert results[0].page_content == "test document"
|
||||||
|
|
||||||
# score_threshold = 0.9
|
# score_threshold = 0.9
|
||||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
retriever = MultiVectorRetriever(
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
docstore=InMemoryStore(),
|
docstore=InMemoryStore(),
|
||||||
doc_id="doc_id",
|
doc_id="doc_id",
|
||||||
@ -108,7 +108,7 @@ async def test_multi_vector_retriever_similarity_search_with_score_async() -> No
|
|||||||
await vectorstore.aadd_documents(documents, ids=["1"])
|
await vectorstore.aadd_documents(documents, ids=["1"])
|
||||||
|
|
||||||
# score_threshold = 0.5
|
# score_threshold = 0.5
|
||||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
retriever = MultiVectorRetriever(
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
docstore=InMemoryStore(),
|
docstore=InMemoryStore(),
|
||||||
doc_id="doc_id",
|
doc_id="doc_id",
|
||||||
@ -121,7 +121,7 @@ async def test_multi_vector_retriever_similarity_search_with_score_async() -> No
|
|||||||
assert results[0].page_content == "test document"
|
assert results[0].page_content == "test document"
|
||||||
|
|
||||||
# score_threshold = 0.9
|
# score_threshold = 0.9
|
||||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
retriever = MultiVectorRetriever(
|
||||||
vectorstore=vectorstore,
|
vectorstore=vectorstore,
|
||||||
docstore=InMemoryStore(),
|
docstore=InMemoryStore(),
|
||||||
doc_id="doc_id",
|
doc_id="doc_id",
|
||||||
|
@ -171,7 +171,7 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
|
|||||||
assert "the right input" in inputs
|
assert "the right input" in inputs
|
||||||
return {"output": "2"}
|
return {"output": "2"}
|
||||||
|
|
||||||
mock_chain = TransformChain( # type: ignore[call-arg]
|
mock_chain = TransformChain(
|
||||||
input_variables=["the right input"],
|
input_variables=["the right input"],
|
||||||
output_variables=["output"],
|
output_variables=["output"],
|
||||||
transform=run_val,
|
transform=run_val,
|
||||||
|
Loading…
Reference in New Issue
Block a user