From 56bbfd97234ee4beddba196fe419adaee70f6097 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 7 Jul 2025 17:33:18 +0200 Subject: [PATCH] langchain: Add ruff rule RET (#31875) All auto-fixes See https://docs.astral.sh/ruff/rules/#flake8-return-ret --------- Co-authored-by: Mason Daugherty --- libs/langchain/langchain/__init__.py | 101 +++++---- libs/langchain/langchain/agents/agent.py | 52 ++--- libs/langchain/langchain/agents/chat/base.py | 3 +- .../conversational_chat/output_parser.py | 16 +- .../format_scratchpad/openai_functions.py | 3 +- .../langchain/agents/json_chat/base.py | 3 +- .../langchain/agents/mrkl/output_parser.py | 14 +- .../langchain/agents/openai_assistant/base.py | 33 ++- .../agents/openai_functions_agent/base.py | 25 +-- .../openai_functions_multi_agent/base.py | 6 +- .../langchain/agents/openai_tools/base.py | 3 +- .../langchain/agents/output_parsers/json.py | 9 +- .../output_parsers/react_single_input.py | 9 +- .../langchain/agents/output_parsers/xml.py | 5 +- .../langchain/langchain/agents/react/agent.py | 3 +- libs/langchain/langchain/agents/react/base.py | 12 +- .../langchain/agents/react/output_parser.py | 3 +- .../agents/self_ask_with_search/base.py | 3 +- .../langchain/agents/structured_chat/base.py | 6 +- .../agents/structured_chat/output_parser.py | 15 +- .../agents/tool_calling_agent/base.py | 3 +- libs/langchain/langchain/agents/xml/base.py | 3 +- .../callbacks/streaming_aiter_final_only.py | 3 +- .../callbacks/streaming_stdout_final_only.py | 3 +- libs/langchain/langchain/chains/base.py | 22 +- .../chains/combine_documents/map_reduce.py | 29 ++- .../chains/combine_documents/reduce.py | 3 +- .../chains/combine_documents/refine.py | 3 +- .../chains/constitutional_ai/base.py | 3 +- .../chains/elasticsearch_database/base.py | 3 +- libs/langchain/langchain/chains/hyde/base.py | 3 +- libs/langchain/langchain/chains/llm.py | 71 +++---- .../langchain/chains/llm_checker/base.py | 3 +- .../chains/llm_summarization_checker/base.py | 3 +- libs/langchain/langchain/chains/loading.py | 6 +- libs/langchain/langchain/chains/moderation.py | 3 +- .../langchain/chains/openai_functions/base.py | 3 +- .../openai_functions/citation_fuzzy_match.py | 3 +- .../chains/openai_functions/extraction.py | 6 +- .../openai_functions/qa_with_structure.py | 3 +- .../chains/openai_functions/tagging.py | 6 +- .../chains/openai_tools/extraction.py | 3 +- .../chains/query_constructor/base.py | 21 +- .../chains/query_constructor/parser.py | 18 +- libs/langchain/langchain/chains/retrieval.py | 4 +- .../langchain/chains/retrieval_qa/base.py | 6 +- .../langchain/langchain/chains/router/base.py | 18 +- .../langchain/chains/router/llm_router.py | 6 +- .../langchain/chains/sql_database/query.py | 3 +- .../chains/structured_output/base.py | 24 +-- libs/langchain/langchain/chains/transform.py | 11 +- libs/langchain/langchain/chat_models/base.py | 119 +++++------ libs/langchain/langchain/embeddings/base.py | 27 ++- .../evaluation/comparison/eval_chain.py | 2 +- .../evaluation/embedding_distance/base.py | 5 +- .../langchain/langchain/evaluation/loading.py | 3 +- .../evaluation/parsing/json_schema.py | 2 +- .../langchain/evaluation/qa/eval_chain.py | 6 +- libs/langchain/langchain/evaluation/schema.py | 4 +- .../evaluation/scoring/eval_chain.py | 2 +- .../evaluation/string_distance/base.py | 3 +- libs/langchain/langchain/hub.py | 11 +- libs/langchain/langchain/llms/__init__.py | 3 +- libs/langchain/langchain/memory/entity.py | 3 + .../langchain/output_parsers/boolean.py | 2 +- 
.../langchain/langchain/output_parsers/fix.py | 90 ++++---- .../output_parsers/pandas_dataframe.py | 2 +- .../langchain/output_parsers/regex.py | 16 +- .../langchain/output_parsers/regex_dict.py | 9 +- .../langchain/output_parsers/retry.py | 110 +++++----- .../langchain/output_parsers/structured.py | 3 +- .../langchain/output_parsers/yaml.py | 3 +- .../retrievers/contextual_compression.py | 6 +- .../langchain/retrievers/ensemble.py | 19 +- .../langchain/retrievers/merger_retriever.py | 8 +- .../langchain/retrievers/re_phraser.py | 3 +- .../langchain/retrievers/self_query/base.py | 193 +++++++++--------- .../smith/evaluation/runner_utils.py | 77 ++++--- .../smith/evaluation/string_run_evaluator.py | 64 +++--- libs/langchain/langchain/tools/__init__.py | 29 ++- libs/langchain/pyproject.toml | 2 +- .../tests/mock_servers/robot/server.py | 14 +- .../tests/unit_tests/agents/test_agent.py | 3 +- .../unit_tests/agents/test_agent_async.py | 3 +- .../tests/unit_tests/agents/test_chat.py | 3 +- .../tests/unit_tests/agents/test_mrkl.py | 3 +- .../unit_tests/agents/test_structured_chat.py | 7 +- .../tests/unit_tests/chains/test_base.py | 3 +- .../evaluation/agents/test_eval_chain.py | 5 +- .../tests/unit_tests/llms/fake_llm.py | 3 +- .../retrievers/sequential_retriever.py | 5 +- 91 files changed, 663 insertions(+), 835 deletions(-) diff --git a/libs/langchain/langchain/__init__.py b/libs/langchain/langchain/__init__.py index f4e9e521d3e..0628b251a68 100644 --- a/libs/langchain/langchain/__init__.py +++ b/libs/langchain/langchain/__init__.py @@ -48,25 +48,25 @@ def __getattr__(name: str) -> Any: _warn_on_import(name, replacement="langchain.agents.MRKLChain") return MRKLChain - elif name == "ReActChain": + if name == "ReActChain": from langchain.agents import ReActChain _warn_on_import(name, replacement="langchain.agents.ReActChain") return ReActChain - elif name == "SelfAskWithSearchChain": + if name == "SelfAskWithSearchChain": from langchain.agents import SelfAskWithSearchChain _warn_on_import(name, replacement="langchain.agents.SelfAskWithSearchChain") return SelfAskWithSearchChain - elif name == "ConversationChain": + if name == "ConversationChain": from langchain.chains import ConversationChain _warn_on_import(name, replacement="langchain.chains.ConversationChain") return ConversationChain - elif name == "LLMBashChain": + if name == "LLMBashChain": msg = ( "This module has been moved to langchain-experimental. 
" "For more details: " @@ -77,97 +77,97 @@ def __getattr__(name: str) -> Any: ) raise ImportError(msg) - elif name == "LLMChain": + if name == "LLMChain": from langchain.chains import LLMChain _warn_on_import(name, replacement="langchain.chains.LLMChain") return LLMChain - elif name == "LLMCheckerChain": + if name == "LLMCheckerChain": from langchain.chains import LLMCheckerChain _warn_on_import(name, replacement="langchain.chains.LLMCheckerChain") return LLMCheckerChain - elif name == "LLMMathChain": + if name == "LLMMathChain": from langchain.chains import LLMMathChain _warn_on_import(name, replacement="langchain.chains.LLMMathChain") return LLMMathChain - elif name == "QAWithSourcesChain": + if name == "QAWithSourcesChain": from langchain.chains import QAWithSourcesChain _warn_on_import(name, replacement="langchain.chains.QAWithSourcesChain") return QAWithSourcesChain - elif name == "VectorDBQA": + if name == "VectorDBQA": from langchain.chains import VectorDBQA _warn_on_import(name, replacement="langchain.chains.VectorDBQA") return VectorDBQA - elif name == "VectorDBQAWithSourcesChain": + if name == "VectorDBQAWithSourcesChain": from langchain.chains import VectorDBQAWithSourcesChain _warn_on_import(name, replacement="langchain.chains.VectorDBQAWithSourcesChain") return VectorDBQAWithSourcesChain - elif name == "InMemoryDocstore": + if name == "InMemoryDocstore": from langchain_community.docstore import InMemoryDocstore _warn_on_import(name, replacement="langchain.docstore.InMemoryDocstore") return InMemoryDocstore - elif name == "Wikipedia": + if name == "Wikipedia": from langchain_community.docstore import Wikipedia _warn_on_import(name, replacement="langchain.docstore.Wikipedia") return Wikipedia - elif name == "Anthropic": + if name == "Anthropic": from langchain_community.llms import Anthropic _warn_on_import(name, replacement="langchain_community.llms.Anthropic") return Anthropic - elif name == "Banana": + if name == "Banana": from langchain_community.llms import Banana _warn_on_import(name, replacement="langchain_community.llms.Banana") return Banana - elif name == "CerebriumAI": + if name == "CerebriumAI": from langchain_community.llms import CerebriumAI _warn_on_import(name, replacement="langchain_community.llms.CerebriumAI") return CerebriumAI - elif name == "Cohere": + if name == "Cohere": from langchain_community.llms import Cohere _warn_on_import(name, replacement="langchain_community.llms.Cohere") return Cohere - elif name == "ForefrontAI": + if name == "ForefrontAI": from langchain_community.llms import ForefrontAI _warn_on_import(name, replacement="langchain_community.llms.ForefrontAI") return ForefrontAI - elif name == "GooseAI": + if name == "GooseAI": from langchain_community.llms import GooseAI _warn_on_import(name, replacement="langchain_community.llms.GooseAI") return GooseAI - elif name == "HuggingFaceHub": + if name == "HuggingFaceHub": from langchain_community.llms import HuggingFaceHub _warn_on_import(name, replacement="langchain_community.llms.HuggingFaceHub") return HuggingFaceHub - elif name == "HuggingFaceTextGenInference": + if name == "HuggingFaceTextGenInference": from langchain_community.llms import HuggingFaceTextGenInference _warn_on_import( @@ -175,55 +175,55 @@ def __getattr__(name: str) -> Any: ) return HuggingFaceTextGenInference - elif name == "LlamaCpp": + if name == "LlamaCpp": from langchain_community.llms import LlamaCpp _warn_on_import(name, replacement="langchain_community.llms.LlamaCpp") return LlamaCpp - elif name == "Modal": + if name 
== "Modal": from langchain_community.llms import Modal _warn_on_import(name, replacement="langchain_community.llms.Modal") return Modal - elif name == "OpenAI": + if name == "OpenAI": from langchain_community.llms import OpenAI _warn_on_import(name, replacement="langchain_community.llms.OpenAI") return OpenAI - elif name == "Petals": + if name == "Petals": from langchain_community.llms import Petals _warn_on_import(name, replacement="langchain_community.llms.Petals") return Petals - elif name == "PipelineAI": + if name == "PipelineAI": from langchain_community.llms import PipelineAI _warn_on_import(name, replacement="langchain_community.llms.PipelineAI") return PipelineAI - elif name == "SagemakerEndpoint": + if name == "SagemakerEndpoint": from langchain_community.llms import SagemakerEndpoint _warn_on_import(name, replacement="langchain_community.llms.SagemakerEndpoint") return SagemakerEndpoint - elif name == "StochasticAI": + if name == "StochasticAI": from langchain_community.llms import StochasticAI _warn_on_import(name, replacement="langchain_community.llms.StochasticAI") return StochasticAI - elif name == "Writer": + if name == "Writer": from langchain_community.llms import Writer _warn_on_import(name, replacement="langchain_community.llms.Writer") return Writer - elif name == "HuggingFacePipeline": + if name == "HuggingFacePipeline": from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline _warn_on_import( @@ -232,7 +232,7 @@ def __getattr__(name: str) -> Any: ) return HuggingFacePipeline - elif name == "FewShotPromptTemplate": + if name == "FewShotPromptTemplate": from langchain_core.prompts import FewShotPromptTemplate _warn_on_import( @@ -240,7 +240,7 @@ def __getattr__(name: str) -> Any: ) return FewShotPromptTemplate - elif name == "Prompt": + if name == "Prompt": from langchain_core.prompts import PromptTemplate _warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate") @@ -248,19 +248,19 @@ def __getattr__(name: str) -> Any: # it's renamed as prompt template anyways # this is just for backwards compat return PromptTemplate - elif name == "PromptTemplate": + if name == "PromptTemplate": from langchain_core.prompts import PromptTemplate _warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate") return PromptTemplate - elif name == "BasePromptTemplate": + if name == "BasePromptTemplate": from langchain_core.prompts import BasePromptTemplate _warn_on_import(name, replacement="langchain_core.prompts.BasePromptTemplate") return BasePromptTemplate - elif name == "ArxivAPIWrapper": + if name == "ArxivAPIWrapper": from langchain_community.utilities import ArxivAPIWrapper _warn_on_import( @@ -268,7 +268,7 @@ def __getattr__(name: str) -> Any: ) return ArxivAPIWrapper - elif name == "GoldenQueryAPIWrapper": + if name == "GoldenQueryAPIWrapper": from langchain_community.utilities import GoldenQueryAPIWrapper _warn_on_import( @@ -276,7 +276,7 @@ def __getattr__(name: str) -> Any: ) return GoldenQueryAPIWrapper - elif name == "GoogleSearchAPIWrapper": + if name == "GoogleSearchAPIWrapper": from langchain_community.utilities import GoogleSearchAPIWrapper _warn_on_import( @@ -284,7 +284,7 @@ def __getattr__(name: str) -> Any: ) return GoogleSearchAPIWrapper - elif name == "GoogleSerperAPIWrapper": + if name == "GoogleSerperAPIWrapper": from langchain_community.utilities import GoogleSerperAPIWrapper _warn_on_import( @@ -292,7 +292,7 @@ def __getattr__(name: str) -> Any: ) return GoogleSerperAPIWrapper - elif name == "PowerBIDataset": + 
if name == "PowerBIDataset": from langchain_community.utilities import PowerBIDataset _warn_on_import( @@ -300,7 +300,7 @@ def __getattr__(name: str) -> Any: ) return PowerBIDataset - elif name == "SearxSearchWrapper": + if name == "SearxSearchWrapper": from langchain_community.utilities import SearxSearchWrapper _warn_on_import( @@ -308,7 +308,7 @@ def __getattr__(name: str) -> Any: ) return SearxSearchWrapper - elif name == "WikipediaAPIWrapper": + if name == "WikipediaAPIWrapper": from langchain_community.utilities import WikipediaAPIWrapper _warn_on_import( @@ -316,7 +316,7 @@ def __getattr__(name: str) -> Any: ) return WikipediaAPIWrapper - elif name == "WolframAlphaAPIWrapper": + if name == "WolframAlphaAPIWrapper": from langchain_community.utilities import WolframAlphaAPIWrapper _warn_on_import( @@ -324,19 +324,19 @@ def __getattr__(name: str) -> Any: ) return WolframAlphaAPIWrapper - elif name == "SQLDatabase": + if name == "SQLDatabase": from langchain_community.utilities import SQLDatabase _warn_on_import(name, replacement="langchain_community.utilities.SQLDatabase") return SQLDatabase - elif name == "FAISS": + if name == "FAISS": from langchain_community.vectorstores import FAISS _warn_on_import(name, replacement="langchain_community.vectorstores.FAISS") return FAISS - elif name == "ElasticVectorSearch": + if name == "ElasticVectorSearch": from langchain_community.vectorstores import ElasticVectorSearch _warn_on_import( @@ -345,7 +345,7 @@ def __getattr__(name: str) -> Any: return ElasticVectorSearch # For backwards compatibility - elif name == "SerpAPIChain" or name == "SerpAPIWrapper": + if name == "SerpAPIChain" or name == "SerpAPIWrapper": from langchain_community.utilities import SerpAPIWrapper _warn_on_import( @@ -353,7 +353,7 @@ def __getattr__(name: str) -> Any: ) return SerpAPIWrapper - elif name == "verbose": + if name == "verbose": from langchain.globals import _verbose _warn_on_import( @@ -364,7 +364,7 @@ def __getattr__(name: str) -> Any: ) return _verbose - elif name == "debug": + if name == "debug": from langchain.globals import _debug _warn_on_import( @@ -375,7 +375,7 @@ def __getattr__(name: str) -> Any: ) return _debug - elif name == "llm_cache": + if name == "llm_cache": from langchain.globals import _llm_cache _warn_on_import( @@ -386,9 +386,8 @@ def __getattr__(name: str) -> Any: ) return _llm_cache - else: - msg = f"Could not find: {name}" - raise AttributeError(msg) + msg = f"Could not find: {name}" + raise AttributeError(msg) __all__ = [ diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 160fec14832..d0fa9d376a1 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -137,9 +137,8 @@ class BaseSingleActionAgent(BaseModel): return AgentFinish( {"output": "Agent stopped due to iteration limit or time limit."}, "" ) - else: - msg = f"Got unsupported early_stopping_method `{early_stopping_method}`" - raise ValueError(msg) + msg = f"Got unsupported early_stopping_method `{early_stopping_method}`" + raise ValueError(msg) @classmethod def from_llm_and_tools( @@ -308,9 +307,8 @@ class BaseMultiActionAgent(BaseModel): if early_stopping_method == "force": # `force` just returns a constant string return AgentFinish({"output": "Agent stopped due to max iterations."}, "") - else: - msg = f"Got unsupported early_stopping_method `{early_stopping_method}`" - raise ValueError(msg) + msg = f"Got unsupported early_stopping_method `{early_stopping_method}`" + raise 
ValueError(msg) @property def _agent_type(self) -> str: @@ -815,8 +813,7 @@ class Agent(BaseSingleActionAgent): """ full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs) - agent_output = await self.output_parser.aparse(full_output) - return agent_output + return await self.output_parser.aparse(full_output) def get_full_inputs( self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any @@ -833,8 +830,7 @@ class Agent(BaseSingleActionAgent): """ thoughts = self._construct_scratchpad(intermediate_steps) new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} - full_inputs = {**kwargs, **new_inputs} - return full_inputs + return {**kwargs, **new_inputs} @property def input_keys(self) -> list[str]: @@ -970,7 +966,7 @@ class Agent(BaseSingleActionAgent): return AgentFinish( {"output": "Agent stopped due to iteration limit or time limit."}, "" ) - elif early_stopping_method == "generate": + if early_stopping_method == "generate": # Generate does one final forward pass thoughts = "" for action, observation in intermediate_steps: @@ -990,16 +986,14 @@ class Agent(BaseSingleActionAgent): if isinstance(parsed_output, AgentFinish): # If we can extract, we send the correct stuff return parsed_output - else: - # If we can extract, but the tool is not the final tool, - # we just return the full output - return AgentFinish({"output": full_output}, full_output) - else: - msg = ( - "early_stopping_method should be one of `force` or `generate`, " - f"got {early_stopping_method}" - ) - raise ValueError(msg) + # If we can extract, but the tool is not the final tool, + # we just return the full output + return AgentFinish({"output": full_output}, full_output) + msg = ( + "early_stopping_method should be one of `force` or `generate`, " + f"got {early_stopping_method}" + ) + raise ValueError(msg) def tool_run_logging_kwargs(self) -> builtins.dict: """Return logging kwargs for tool run.""" @@ -1179,8 +1173,7 @@ class AgentExecutor(Chain): """ if isinstance(self.agent, Runnable): return cast(RunnableAgentType, self.agent) - else: - return self.agent + return self.agent def save(self, file_path: Union[Path, str]) -> None: """Raise error - saving not supported for Agent Executors. @@ -1249,8 +1242,7 @@ class AgentExecutor(Chain): """ if self.return_intermediate_steps: return self._action_agent.return_values + ["intermediate_steps"] - else: - return self._action_agent.return_values + return self._action_agent.return_values def lookup_tool(self, name: str) -> BaseTool: """Lookup tool by name. @@ -1304,10 +1296,7 @@ class AgentExecutor(Chain): msg = "Expected a single AgentFinish output, but got multiple values." 
raise ValueError(msg) return values[-1] - else: - return [ - (a.action, a.observation) for a in values if isinstance(a, AgentStep) - ] + return [(a.action, a.observation) for a in values if isinstance(a, AgentStep)] def _take_next_step( self, @@ -1727,10 +1716,9 @@ class AgentExecutor(Chain): and self.trim_intermediate_steps > 0 ): return intermediate_steps[-self.trim_intermediate_steps :] - elif callable(self.trim_intermediate_steps): + if callable(self.trim_intermediate_steps): return self.trim_intermediate_steps(intermediate_steps) - else: - return intermediate_steps + return intermediate_steps @override def stream( diff --git a/libs/langchain/langchain/agents/chat/base.py b/libs/langchain/langchain/agents/chat/base.py index 718b97603b8..3409e1b607d 100644 --- a/libs/langchain/langchain/agents/chat/base.py +++ b/libs/langchain/langchain/agents/chat/base.py @@ -61,8 +61,7 @@ class ChatAgent(Agent): f"(but I haven't seen any of it! I only see what " f"you return as final answer):\n{agent_scratchpad}" ) - else: - return agent_scratchpad + return agent_scratchpad @classmethod def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: diff --git a/libs/langchain/langchain/agents/conversational_chat/output_parser.py b/libs/langchain/langchain/agents/conversational_chat/output_parser.py index d8192f7bedc..b6867a20415 100644 --- a/libs/langchain/langchain/agents/conversational_chat/output_parser.py +++ b/libs/langchain/langchain/agents/conversational_chat/output_parser.py @@ -39,15 +39,13 @@ class ConvoOutputParser(AgentOutputParser): # If the action indicates a final answer, return an AgentFinish if action == "Final Answer": return AgentFinish({"output": action_input}, text) - else: - # Otherwise, return an AgentAction with the specified action and - # input - return AgentAction(action, action_input, text) - else: - # If the necessary keys aren't present in the response, raise an - # exception - msg = f"Missing 'action' or 'action_input' in LLM output: {text}" - raise OutputParserException(msg) + # Otherwise, return an AgentAction with the specified action and + # input + return AgentAction(action, action_input, text) + # If the necessary keys aren't present in the response, raise an + # exception + msg = f"Missing 'action' or 'action_input' in LLM output: {text}" + raise OutputParserException(msg) except Exception as e: # If any other exception is raised during parsing, also raise an # OutputParserException diff --git a/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py b/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py index 172c4a677ea..ee08a8f8a27 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py +++ b/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py @@ -23,8 +23,7 @@ def _convert_agent_action_to_messages( return list(agent_action.message_log) + [ _create_function_message(agent_action, observation) ] - else: - return [AIMessage(content=agent_action.log)] + return [AIMessage(content=agent_action.log)] def _create_function_message( diff --git a/libs/langchain/langchain/agents/json_chat/base.py b/libs/langchain/langchain/agents/json_chat/base.py index 6c61f133579..fc58c4b7a34 100644 --- a/libs/langchain/langchain/agents/json_chat/base.py +++ b/libs/langchain/langchain/agents/json_chat/base.py @@ -182,7 +182,7 @@ def create_json_chat_agent( else: llm_to_use = llm - agent = ( + return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_log_to_messages( x["intermediate_steps"], 
template_tool_response=template_tool_response @@ -192,4 +192,3 @@ def create_json_chat_agent( | llm_to_use | JSONAgentOutputParser() ) - return agent diff --git a/libs/langchain/langchain/agents/mrkl/output_parser.py b/libs/langchain/langchain/agents/mrkl/output_parser.py index d9390c698d7..9be800f297b 100644 --- a/libs/langchain/langchain/agents/mrkl/output_parser.py +++ b/libs/langchain/langchain/agents/mrkl/output_parser.py @@ -55,9 +55,8 @@ class MRKLOutputParser(AgentOutputParser): return AgentFinish( {"output": text[start_index:end_index].strip()}, text[:end_index] ) - else: - msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}" - raise OutputParserException(msg) + msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}" + raise OutputParserException(msg) if action_match: action = action_match.group(1).strip() @@ -69,7 +68,7 @@ class MRKLOutputParser(AgentOutputParser): return AgentAction(action, tool_input, text) - elif includes_answer: + if includes_answer: return AgentFinish( {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text ) @@ -82,7 +81,7 @@ class MRKLOutputParser(AgentOutputParser): llm_output=text, send_to_llm=True, ) - elif not re.search( + if not re.search( r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL ): msg = f"Could not parse LLM output: `{text}`" @@ -92,9 +91,8 @@ class MRKLOutputParser(AgentOutputParser): llm_output=text, send_to_llm=True, ) - else: - msg = f"Could not parse LLM output: `{text}`" - raise OutputParserException(msg) + msg = f"Could not parse LLM output: `{text}`" + raise OutputParserException(msg) @property def _type(self) -> str: diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index 5afbf4a23e8..915a4a0b270 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -128,8 +128,7 @@ def _get_assistants_tool( """ if _is_assistants_builtin_tool(tool): return tool # type: ignore[return-value] - else: - return convert_to_openai_tool(tool) + return convert_to_openai_tool(tool) OutputType = Union[ @@ -510,12 +509,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): for action, output in intermediate_steps if action.tool_call_id in required_tool_call_ids ] - submit_tool_outputs = { + return { "tool_outputs": tool_outputs, "run_id": last_action.run_id, "thread_id": last_action.thread_id, } - return submit_tool_outputs def _create_run(self, input_dict: dict) -> Any: params = { @@ -558,12 +556,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): "run_metadata", ) } - run = self.client.beta.threads.create_and_run( + return self.client.beta.threads.create_and_run( assistant_id=self.assistant_id, thread=thread, **params, ) - return run def _get_response(self, run: Any) -> Any: # TODO: Pagination @@ -612,7 +609,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): run_id=run.id, thread_id=run.thread_id, ) - elif run.status == "requires_action": + if run.status == "requires_action": if not self.as_agent: return run.required_action.submit_tool_outputs.tool_calls actions = [] @@ -639,10 +636,9 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): ) ) return actions - else: - run_info = json.dumps(run.dict(), indent=2) - msg = f"Unexpected run status: {run.status}. 
Full run info:\n\n{run_info})" - raise ValueError(msg) + run_info = json.dumps(run.dict(), indent=2) + msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})" + raise ValueError(msg) def _wait_for_run(self, run_id: str, thread_id: str) -> Any: in_progress = True @@ -668,12 +664,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): for action, output in intermediate_steps if action.tool_call_id in required_tool_call_ids ] - submit_tool_outputs = { + return { "tool_outputs": tool_outputs, "run_id": last_action.run_id, "thread_id": last_action.thread_id, } - return submit_tool_outputs async def _acreate_run(self, input_dict: dict) -> Any: params = { @@ -716,12 +711,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): "run_metadata", ) } - run = await self.async_client.beta.threads.create_and_run( + return await self.async_client.beta.threads.create_and_run( assistant_id=self.assistant_id, thread=thread, **params, ) - return run async def _aget_response(self, run: Any) -> Any: # TODO: Pagination @@ -766,7 +760,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): run_id=run.id, thread_id=run.thread_id, ) - elif run.status == "requires_action": + if run.status == "requires_action": if not self.as_agent: return run.required_action.submit_tool_outputs.tool_calls actions = [] @@ -793,10 +787,9 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): ) ) return actions - else: - run_info = json.dumps(run.dict(), indent=2) - msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})" - raise ValueError(msg) + run_info = json.dumps(run.dict(), indent=2) + msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})" + raise ValueError(msg) async def _await_for_run(self, run_id: str, thread_id: str) -> Any: in_progress = True diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index fa182072478..7c0baea759e 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -132,8 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): messages, callbacks=callbacks, ) - agent_decision = self.output_parser._parse_ai_message(predicted_message) - return agent_decision + return self.output_parser._parse_ai_message(predicted_message) async def aplan( self, @@ -164,8 +163,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): predicted_message = await self.llm.apredict_messages( messages, functions=self.functions, callbacks=callbacks ) - agent_decision = self.output_parser._parse_ai_message(predicted_message) - return agent_decision + return self.output_parser._parse_ai_message(predicted_message) def return_stopped_response( self, @@ -192,22 +190,20 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): return AgentFinish( {"output": "Agent stopped due to iteration limit or time limit."}, "" ) - elif early_stopping_method == "generate": + if early_stopping_method == "generate": # Generate does one final forward pass agent_decision = self.plan( intermediate_steps, with_functions=False, **kwargs ) if isinstance(agent_decision, AgentFinish): return agent_decision - else: - msg = f"got AgentAction with no functions provided: {agent_decision}" - raise ValueError(msg) - else: - msg = ( - "early_stopping_method should be one of `force` or `generate`, " - f"got {early_stopping_method}" - ) + msg = f"got 
AgentAction with no functions provided: {agent_decision}" raise ValueError(msg) + msg = ( + "early_stopping_method should be one of `force` or `generate`, " + f"got {early_stopping_method}" + ) + raise ValueError(msg) @classmethod def create_prompt( @@ -358,7 +354,7 @@ def create_openai_functions_agent( ) raise ValueError(msg) llm_with_tools = llm.bind(functions=[convert_to_openai_function(t) for t in tools]) - agent = ( + return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_to_openai_function_messages( x["intermediate_steps"] @@ -368,4 +364,3 @@ def create_openai_functions_agent( | llm_with_tools | OpenAIFunctionsAgentOutputParser() ) - return agent diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index 9a949a1ff3d..7824dfffce7 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -224,8 +224,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): predicted_message = self.llm.predict_messages( messages, functions=self.functions, callbacks=callbacks ) - agent_decision = _parse_ai_message(predicted_message) - return agent_decision + return _parse_ai_message(predicted_message) async def aplan( self, @@ -254,8 +253,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): predicted_message = await self.llm.apredict_messages( messages, functions=self.functions, callbacks=callbacks ) - agent_decision = _parse_ai_message(predicted_message) - return agent_decision + return _parse_ai_message(predicted_message) @classmethod def create_prompt( diff --git a/libs/langchain/langchain/agents/openai_tools/base.py b/libs/langchain/langchain/agents/openai_tools/base.py index 6e53662aecd..4dad4505738 100644 --- a/libs/langchain/langchain/agents/openai_tools/base.py +++ b/libs/langchain/langchain/agents/openai_tools/base.py @@ -96,7 +96,7 @@ def create_openai_tools_agent( tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools] ) - agent = ( + return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_to_openai_tool_messages( x["intermediate_steps"] @@ -106,4 +106,3 @@ def create_openai_tools_agent( | llm_with_tools | OpenAIToolsAgentOutputParser() ) - return agent diff --git a/libs/langchain/langchain/agents/output_parsers/json.py b/libs/langchain/langchain/agents/output_parsers/json.py index cca8f1c5054..c78c3562d7b 100644 --- a/libs/langchain/langchain/agents/output_parsers/json.py +++ b/libs/langchain/langchain/agents/output_parsers/json.py @@ -49,11 +49,10 @@ class JSONAgentOutputParser(AgentOutputParser): response = response[0] if response["action"] == "Final Answer": return AgentFinish({"output": response["action_input"]}, text) - else: - action_input = response.get("action_input", {}) - if action_input is None: - action_input = {} - return AgentAction(response["action"], action_input, text) + action_input = response.get("action_input", {}) + if action_input is None: + action_input = {} + return AgentAction(response["action"], action_input, text) except Exception as e: msg = f"Could not parse LLM output: {text}" raise OutputParserException(msg) from e diff --git a/libs/langchain/langchain/agents/output_parsers/react_single_input.py b/libs/langchain/langchain/agents/output_parsers/react_single_input.py index d87aedfa11c..90f058be4ab 100644 --- a/libs/langchain/langchain/agents/output_parsers/react_single_input.py +++ 
b/libs/langchain/langchain/agents/output_parsers/react_single_input.py @@ -65,7 +65,7 @@ class ReActSingleInputOutputParser(AgentOutputParser): return AgentAction(action, tool_input, text) - elif includes_answer: + if includes_answer: return AgentFinish( {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text ) @@ -78,7 +78,7 @@ llm_output=text, send_to_llm=True, ) - elif not re.search( + if not re.search( r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL ): msg = f"Could not parse LLM output: `{text}`" @@ -88,9 +88,8 @@ llm_output=text, send_to_llm=True, ) - else: - msg = f"Could not parse LLM output: `{text}`" - raise OutputParserException(msg) + msg = f"Could not parse LLM output: `{text}`" + raise OutputParserException(msg) @property def _type(self) -> str: diff --git a/libs/langchain/langchain/agents/output_parsers/xml.py b/libs/langchain/langchain/agents/output_parsers/xml.py index 730d069ae6a..e8a7587b1cd 100644 --- a/libs/langchain/langchain/agents/output_parsers/xml.py +++ b/libs/langchain/langchain/agents/output_parsers/xml.py @@ -36,13 +36,12 @@ class XMLAgentOutputParser(AgentOutputParser): if "</tool_input>" in _tool_input: _tool_input = _tool_input.split("</tool_input>")[0] return AgentAction(tool=_tool, tool_input=_tool_input, log=text) - elif "<final_answer>" in text: + if "<final_answer>" in text: _, answer = text.split("<final_answer>") if "</final_answer>" in answer: answer = answer.split("</final_answer>")[0] return AgentFinish(return_values={"output": answer}, log=text) - else: - raise ValueError + raise ValueError def get_format_instructions(self) -> str: raise NotImplementedError diff --git a/libs/langchain/langchain/agents/react/agent.py b/libs/langchain/langchain/agents/react/agent.py index a19b821fda3..8fcbb46724a 100644 --- a/libs/langchain/langchain/agents/react/agent.py +++ b/libs/langchain/langchain/agents/react/agent.py @@ -134,7 +134,7 @@ def create_react_agent( else: llm_with_stop = llm output_parser = output_parser or ReActSingleInputOutputParser() - agent = ( + return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]), ) @@ -142,4 +142,3 @@ | prompt | llm_with_stop | output_parser ) - return agent diff --git a/libs/langchain/langchain/agents/react/base.py b/libs/langchain/langchain/agents/react/base.py index c3c55206f6a..7430e639650 100644 --- a/libs/langchain/langchain/agents/react/base.py +++ b/libs/langchain/langchain/agents/react/base.py @@ -96,9 +96,8 @@ class DocstoreExplorer: if isinstance(result, Document): self.document = result return self._summary - else: - self.document = None - return result + self.document = None + return result def lookup(self, term: str) -> str: """Lookup a term in document (if saved).""" @@ -113,11 +112,10 @@ lookups = [p for p in self._paragraphs if self.lookup_str in p.lower()] if len(lookups) == 0: return "No Results" - elif self.lookup_index >= len(lookups): + if self.lookup_index >= len(lookups): return "No More Results" - else: - result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})" - return f"{result_prefix} {lookups[self.lookup_index]}" + result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})" + return f"{result_prefix} {lookups[self.lookup_index]}" @property def _summary(self) -> str: diff --git a/libs/langchain/langchain/agents/react/output_parser.py b/libs/langchain/langchain/agents/react/output_parser.py index a632f22527d..eceaa27edac 100644 ---
a/libs/langchain/langchain/agents/react/output_parser.py +++ b/libs/langchain/langchain/agents/react/output_parser.py @@ -26,8 +26,7 @@ class ReActOutputParser(AgentOutputParser): action, action_input = re_matches.group(1), re_matches.group(2) if action == "Finish": return AgentFinish({"output": action_input}, text) - else: - return AgentAction(action, action_input, text) + return AgentAction(action, action_input, text) @property def _type(self) -> str: diff --git a/libs/langchain/langchain/agents/self_ask_with_search/base.py b/libs/langchain/langchain/agents/self_ask_with_search/base.py index a04414fb0e2..63302caa751 100644 --- a/libs/langchain/langchain/agents/self_ask_with_search/base.py +++ b/libs/langchain/langchain/agents/self_ask_with_search/base.py @@ -195,7 +195,7 @@ def create_self_ask_with_search_agent( raise ValueError(msg) llm_with_stop = llm.bind(stop=["\nIntermediate answer:"]) - agent = ( + return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_log_to_str( x["intermediate_steps"], @@ -209,4 +209,3 @@ def create_self_ask_with_search_agent( | llm_with_stop | SelfAskOutputParser() ) - return agent diff --git a/libs/langchain/langchain/agents/structured_chat/base.py b/libs/langchain/langchain/agents/structured_chat/base.py index f333955338f..bc1464dbc5e 100644 --- a/libs/langchain/langchain/agents/structured_chat/base.py +++ b/libs/langchain/langchain/agents/structured_chat/base.py @@ -62,8 +62,7 @@ class StructuredChatAgent(Agent): f"(but I haven't seen any of it! I only see what " f"you return as final answer):\n{agent_scratchpad}" ) - else: - return agent_scratchpad + return agent_scratchpad @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: @@ -292,7 +291,7 @@ def create_structured_chat_agent( else: llm_with_stop = llm - agent = ( + return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]), ) @@ -300,4 +299,3 @@ def create_structured_chat_agent( | llm_with_stop | JSONAgentOutputParser() ) - return agent diff --git a/libs/langchain/langchain/agents/structured_chat/output_parser.py b/libs/langchain/langchain/agents/structured_chat/output_parser.py index 985ca6156ea..0e6e072c970 100644 --- a/libs/langchain/langchain/agents/structured_chat/output_parser.py +++ b/libs/langchain/langchain/agents/structured_chat/output_parser.py @@ -42,12 +42,10 @@ class StructuredChatOutputParser(AgentOutputParser): response = response[0] if response["action"] == "Final Answer": return AgentFinish({"output": response["action_input"]}, text) - else: - return AgentAction( - response["action"], response.get("action_input", {}), text - ) - else: - return AgentFinish({"output": text}, text) + return AgentAction( + response["action"], response.get("action_input", {}), text + ) + return AgentFinish({"output": text}, text) except Exception as e: msg = f"Could not parse LLM output: {text}" raise OutputParserException(msg) from e @@ -93,10 +91,9 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser): llm=llm, parser=base_parser ) return cls(output_fixing_parser=output_fixing_parser) - elif base_parser is not None: + if base_parser is not None: return cls(base_parser=base_parser) - else: - return cls() + return cls() @property def _type(self) -> str: diff --git a/libs/langchain/langchain/agents/tool_calling_agent/base.py b/libs/langchain/langchain/agents/tool_calling_agent/base.py index a7624f582a9..fd412fea703 100644 --- a/libs/langchain/langchain/agents/tool_calling_agent/base.py +++ 
b/libs/langchain/langchain/agents/tool_calling_agent/base.py @@ -100,7 +100,7 @@ def create_tool_calling_agent( ) llm_with_tools = llm.bind_tools(tools) - agent = ( + return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"]) ) @@ -108,4 +108,3 @@ def create_tool_calling_agent( | llm_with_tools | ToolsAgentOutputParser() ) - return agent diff --git a/libs/langchain/langchain/agents/xml/base.py b/libs/langchain/langchain/agents/xml/base.py index e1658ba680b..347d474dfc2 100644 --- a/libs/langchain/langchain/agents/xml/base.py +++ b/libs/langchain/langchain/agents/xml/base.py @@ -221,7 +221,7 @@ def create_xml_agent( else: llm_with_stop = llm - agent = ( + return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_xml(x["intermediate_steps"]), ) @@ -229,4 +229,3 @@ def create_xml_agent( | llm_with_stop | XMLAgentOutputParser() ) - return agent diff --git a/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py b/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py index fd1be579811..db46a3e146b 100644 --- a/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py +++ b/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py @@ -24,8 +24,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler): def check_if_answer_reached(self) -> bool: if self.strip_tokens: return self.last_tokens_stripped == self.answer_prefix_tokens_stripped - else: - return self.last_tokens == self.answer_prefix_tokens + return self.last_tokens == self.answer_prefix_tokens def __init__( self, diff --git a/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py b/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py index 5a963abf74c..674e5579175 100644 --- a/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py +++ b/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py @@ -25,8 +25,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): def check_if_answer_reached(self) -> bool: if self.strip_tokens: return self.last_tokens_stripped == self.answer_prefix_tokens_stripped - else: - return self.last_tokens == self.answer_prefix_tokens + return self.last_tokens == self.answer_prefix_tokens def __init__( self, diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index dc0d131a630..f624c0ae1ac 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -261,8 +261,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): """ if verbose is None: return _get_verbosity() - else: - return verbose + return verbose @property @abstractmethod @@ -474,8 +473,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): self.memory.save_context(inputs, outputs) if return_only_outputs: return outputs - else: - return {**inputs, **outputs} + return {**inputs, **outputs} async def aprep_outputs( self, @@ -500,8 +498,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): await self.memory.asave_context(inputs, outputs) if return_only_outputs: return outputs - else: - return {**inputs, **outputs} + return {**inputs, **outputs} def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]: """Prepare chain inputs, including adding inputs from memory. @@ -628,12 +625,11 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): " but none were provided." 
) raise ValueError(msg) - else: - msg = ( - f"`run` supported with either positional arguments or keyword arguments" - f" but not both. Got args: {args} and kwargs: {kwargs}." - ) - raise ValueError(msg) + msg = ( + f"`run` supported with either positional arguments or keyword arguments" + f" but not both. Got args: {args} and kwargs: {kwargs}." + ) + raise ValueError(msg) @deprecated("0.1.0", alternative="ainvoke", removal="1.0") async def arun( @@ -687,7 +683,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): f"one output key. Got {self.output_keys}." ) raise ValueError(msg) - elif args and not kwargs: + if args and not kwargs: if len(args) != 1: msg = "`run` supports only one positional argument." raise ValueError(msg) diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py index 2c7c199ce0c..bd5201e826b 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py @@ -208,28 +208,25 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): if isinstance(self.reduce_documents_chain, ReduceDocumentsChain): if self.reduce_documents_chain.collapse_documents_chain: return self.reduce_documents_chain.collapse_documents_chain - else: - return self.reduce_documents_chain.combine_documents_chain - else: - msg = ( - f"`reduce_documents_chain` is of type " - f"{type(self.reduce_documents_chain)} so it does not have " - f"this attribute." - ) - raise ValueError(msg) + return self.reduce_documents_chain.combine_documents_chain + msg = ( + f"`reduce_documents_chain` is of type " + f"{type(self.reduce_documents_chain)} so it does not have " + f"this attribute." + ) + raise ValueError(msg) @property def combine_document_chain(self) -> BaseCombineDocumentsChain: """Kept for backward compatibility.""" if isinstance(self.reduce_documents_chain, ReduceDocumentsChain): return self.reduce_documents_chain.combine_documents_chain - else: - msg = ( - f"`reduce_documents_chain` is of type " - f"{type(self.reduce_documents_chain)} so it does not have " - f"this attribute." - ) - raise ValueError(msg) + msg = ( + f"`reduce_documents_chain` is of type " + f"{type(self.reduce_documents_chain)} so it does not have " + f"this attribute." 
+ ) + raise ValueError(msg) def combine_docs( self, diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py index 0814c68a8dd..fab6b53eea1 100644 --- a/libs/langchain/langchain/chains/combine_documents/reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/reduce.py @@ -225,8 +225,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): def _collapse_chain(self) -> BaseCombineDocumentsChain: if self.collapse_documents_chain is not None: return self.collapse_documents_chain - else: - return self.combine_documents_chain + return self.combine_documents_chain def combine_docs( self, diff --git a/libs/langchain/langchain/chains/combine_documents/refine.py b/libs/langchain/langchain/chains/combine_documents/refine.py index f3bf793e76c..34ac3d65042 100644 --- a/libs/langchain/langchain/chains/combine_documents/refine.py +++ b/libs/langchain/langchain/chains/combine_documents/refine.py @@ -222,8 +222,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): base_inputs: dict = { self.document_variable_name: self.document_prompt.format(**document_info) } - inputs = {**base_inputs, **kwargs} - return inputs + return {**base_inputs, **kwargs} @property def _chain_type(self) -> str: diff --git a/libs/langchain/langchain/chains/constitutional_ai/base.py b/libs/langchain/langchain/chains/constitutional_ai/base.py index 9b7a1f6166d..cdcbb89d2d8 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/base.py +++ b/libs/langchain/langchain/chains/constitutional_ai/base.py @@ -201,8 +201,7 @@ class ConstitutionalChain(Chain): ) -> list[ConstitutionalPrinciple]: if names is None: return list(PRINCIPLES.values()) - else: - return [PRINCIPLES[name] for name in names] + return [PRINCIPLES[name] for name in names] @classmethod def from_llm( diff --git a/libs/langchain/langchain/chains/elasticsearch_database/base.py b/libs/langchain/langchain/chains/elasticsearch_database/base.py index e27ffa9d349..80ecc381be7 100644 --- a/libs/langchain/langchain/chains/elasticsearch_database/base.py +++ b/libs/langchain/langchain/chains/elasticsearch_database/base.py @@ -80,8 +80,7 @@ class ElasticsearchDatabaseChain(Chain): """ if not self.return_intermediate_steps: return [self.output_key] - else: - return [self.output_key, INTERMEDIATE_STEPS_KEY] + return [self.output_key, INTERMEDIATE_STEPS_KEY] def _list_indices(self) -> list[str]: all_indices = [ diff --git a/libs/langchain/langchain/chains/hyde/base.py b/libs/langchain/langchain/chains/hyde/base.py index b4f1a87bc3a..3ee532ac583 100644 --- a/libs/langchain/langchain/chains/hyde/base.py +++ b/libs/langchain/langchain/chains/hyde/base.py @@ -47,8 +47,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): """Output keys for Hyde's LLM chain.""" if isinstance(self.llm_chain, LLMChain): return self.llm_chain.output_keys - else: - return ["text"] + return ["text"] def embed_documents(self, texts: list[str]) -> list[list[float]]: """Call the base embeddings.""" diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index 9489614224a..e02ce49e7e1 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -116,8 +116,7 @@ class LLMChain(Chain): """ if self.return_final_only: return [self.output_key] - else: - return [self.output_key, "full_generation"] + return [self.output_key, "full_generation"] def _call( self, @@ -142,17 +141,16 @@ class LLMChain(Chain): callbacks=callbacks, **self.llm_kwargs, 
) - else: - results = self.llm.bind(stop=stop, **self.llm_kwargs).batch( - cast(list, prompts), {"callbacks": callbacks} - ) - generations: list[list[Generation]] = [] - for res in results: - if isinstance(res, BaseMessage): - generations.append([ChatGeneration(message=res)]) - else: - generations.append([Generation(text=res)]) - return LLMResult(generations=generations) + results = self.llm.bind(stop=stop, **self.llm_kwargs).batch( + cast(list, prompts), {"callbacks": callbacks} + ) + generations: list[list[Generation]] = [] + for res in results: + if isinstance(res, BaseMessage): + generations.append([ChatGeneration(message=res)]) + else: + generations.append([Generation(text=res)]) + return LLMResult(generations=generations) async def agenerate( self, @@ -169,17 +167,16 @@ class LLMChain(Chain): callbacks=callbacks, **self.llm_kwargs, ) - else: - results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch( - cast(list, prompts), {"callbacks": callbacks} - ) - generations: list[list[Generation]] = [] - for res in results: - if isinstance(res, BaseMessage): - generations.append([ChatGeneration(message=res)]) - else: - generations.append([Generation(text=res)]) - return LLMResult(generations=generations) + results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch( + cast(list, prompts), {"callbacks": callbacks} + ) + generations: list[list[Generation]] = [] + for res in results: + if isinstance(res, BaseMessage): + generations.append([ChatGeneration(message=res)]) + else: + generations.append([Generation(text=res)]) + return LLMResult(generations=generations) def prep_prompts( self, @@ -344,8 +341,7 @@ class LLMChain(Chain): result = self.predict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: return self.prompt.output_parser.parse(result) - else: - return result + return result async def apredict_and_parse( self, callbacks: Callbacks = None, **kwargs: Any @@ -358,8 +354,7 @@ class LLMChain(Chain): result = await self.apredict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: return self.prompt.output_parser.parse(result) - else: - return result + return result def apply_and_parse( self, input_list: list[dict[str, Any]], callbacks: Callbacks = None @@ -380,8 +375,7 @@ class LLMChain(Chain): self.prompt.output_parser.parse(res[self.output_key]) for res in generation ] - else: - return generation + return generation async def aapply_and_parse( self, input_list: list[dict[str, Any]], callbacks: Callbacks = None @@ -411,15 +405,14 @@ class LLMChain(Chain): def _get_language_model(llm_like: Runnable) -> BaseLanguageModel: if isinstance(llm_like, BaseLanguageModel): return llm_like - elif isinstance(llm_like, RunnableBinding): + if isinstance(llm_like, RunnableBinding): return _get_language_model(llm_like.bound) - elif isinstance(llm_like, RunnableWithFallbacks): + if isinstance(llm_like, RunnableWithFallbacks): return _get_language_model(llm_like.runnable) - elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)): + if isinstance(llm_like, (RunnableBranch, DynamicRunnable)): return _get_language_model(llm_like.default) - else: - msg = ( - f"Unable to extract BaseLanguageModel from llm_like object of type " - f"{type(llm_like)}" - ) - raise ValueError(msg) + msg = ( + f"Unable to extract BaseLanguageModel from llm_like object of type " + f"{type(llm_like)}" + ) + raise ValueError(msg) diff --git a/libs/langchain/langchain/chains/llm_checker/base.py b/libs/langchain/langchain/chains/llm_checker/base.py index 
56d8ba7b748..365fc5195cf 100644 --- a/libs/langchain/langchain/chains/llm_checker/base.py +++ b/libs/langchain/langchain/chains/llm_checker/base.py @@ -55,13 +55,12 @@ def _load_question_to_checked_assertions_chain( check_assertions_chain, revised_answer_chain, ] - question_to_checked_assertions_chain = SequentialChain( + return SequentialChain( chains=chains, # type: ignore[arg-type] input_variables=["question"], output_variables=["revised_statement"], verbose=True, ) - return question_to_checked_assertions_chain @deprecated( diff --git a/libs/langchain/langchain/chains/llm_summarization_checker/base.py b/libs/langchain/langchain/chains/llm_summarization_checker/base.py index fd96eeac2e9..6fe584b5cbd 100644 --- a/libs/langchain/langchain/chains/llm_summarization_checker/base.py +++ b/libs/langchain/langchain/chains/llm_summarization_checker/base.py @@ -32,7 +32,7 @@ def _load_sequential_chain( are_all_true_prompt: PromptTemplate, verbose: bool = False, ) -> SequentialChain: - chain = SequentialChain( + return SequentialChain( chains=[ LLMChain( llm=llm, @@ -63,7 +63,6 @@ def _load_sequential_chain( output_variables=["all_true", "revised_summary"], verbose=verbose, ) - return chain @deprecated( diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index a16fd0eac54..932f64cfe31 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -311,8 +311,7 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain: prompt = load_prompt(config.pop("prompt_path")) if llm_chain: return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type] - else: - return LLMMathChain(llm=llm, prompt=prompt, **config) + return LLMMathChain(llm=llm, prompt=prompt, **config) def _load_map_rerank_documents_chain( @@ -609,8 +608,7 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain: return LLMRequestsChain( llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config ) - else: - return LLMRequestsChain(llm_chain=llm_chain, **config) + return LLMRequestsChain(llm_chain=llm_chain, **config) type_to_loader_dict = { diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index eadf98f9b06..552fde0bcb1 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -100,8 +100,7 @@ class OpenAIModerationChain(Chain): error_str = "Text was found that violates OpenAI's content policy." 
if self.error: raise ValueError(error_str) - else: - return error_str + return error_str return text def _call( diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py index c6b37a01fec..a6c020621c5 100644 --- a/libs/langchain/langchain/chains/openai_functions/base.py +++ b/libs/langchain/langchain/chains/openai_functions/base.py @@ -132,7 +132,7 @@ def create_openai_fn_chain( } if len(openai_functions) == 1 and enforce_single_function_usage: llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]} - llm_chain = LLMChain( + return LLMChain( llm=llm, prompt=prompt, output_parser=output_parser, @@ -140,7 +140,6 @@ def create_openai_fn_chain( output_key=output_key, **kwargs, ) - return llm_chain @deprecated( diff --git a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py index 01e8a9c83bc..6200f9d9bf7 100644 --- a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py +++ b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py @@ -149,10 +149,9 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain: ] prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type] - chain = LLMChain( + return LLMChain( llm=llm, prompt=prompt, llm_kwargs=llm_kwargs, output_parser=output_parser, ) - return chain diff --git a/libs/langchain/langchain/chains/openai_functions/extraction.py b/libs/langchain/langchain/chains/openai_functions/extraction.py index 96e3578b704..e703dba2a76 100644 --- a/libs/langchain/langchain/chains/openai_functions/extraction.py +++ b/libs/langchain/langchain/chains/openai_functions/extraction.py @@ -103,7 +103,7 @@ def create_extraction_chain( extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) output_parser = JsonKeyOutputFunctionsParser(key_name="info") llm_kwargs = get_llm_kwargs(function) - chain = LLMChain( + return LLMChain( llm=llm, prompt=extraction_prompt, llm_kwargs=llm_kwargs, @@ -111,7 +111,6 @@ def create_extraction_chain( tags=tags, verbose=verbose, ) - return chain @deprecated( @@ -187,11 +186,10 @@ def create_extraction_chain_pydantic( pydantic_schema=PydanticSchema, attr_name="info" ) llm_kwargs = get_llm_kwargs(function) - chain = LLMChain( + return LLMChain( llm=llm, prompt=extraction_prompt, llm_kwargs=llm_kwargs, output_parser=output_parser, verbose=verbose, ) - return chain diff --git a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py index 4520a5b15e6..46ecb86c80c 100644 --- a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py +++ b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py @@ -100,14 +100,13 @@ def create_qa_with_structure_chain( ] prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type] - chain = LLMChain( + return LLMChain( llm=llm, prompt=prompt, llm_kwargs=llm_kwargs, output_parser=_output_parser, verbose=verbose, ) - return chain @deprecated( diff --git a/libs/langchain/langchain/chains/openai_functions/tagging.py b/libs/langchain/langchain/chains/openai_functions/tagging.py index 234226996cc..74a115502ae 100644 --- a/libs/langchain/langchain/chains/openai_functions/tagging.py +++ b/libs/langchain/langchain/chains/openai_functions/tagging.py @@ -91,14 +91,13 @@ def create_tagging_chain( prompt = prompt or 
ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) output_parser = JsonOutputFunctionsParser() llm_kwargs = get_llm_kwargs(function) - chain = LLMChain( + return LLMChain( llm=llm, prompt=prompt, llm_kwargs=llm_kwargs, output_parser=output_parser, **kwargs, ) - return chain @deprecated( @@ -164,11 +163,10 @@ def create_tagging_chain_pydantic( prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema) llm_kwargs = get_llm_kwargs(function) - chain = LLMChain( + return LLMChain( llm=llm, prompt=prompt, llm_kwargs=llm_kwargs, output_parser=output_parser, **kwargs, ) - return chain diff --git a/libs/langchain/langchain/chains/openai_tools/extraction.py b/libs/langchain/langchain/chains/openai_tools/extraction.py index 14ad2c73cf9..d76d373f3b3 100644 --- a/libs/langchain/langchain/chains/openai_tools/extraction.py +++ b/libs/langchain/langchain/chains/openai_tools/extraction.py @@ -76,5 +76,4 @@ def create_extraction_chain_pydantic( functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas] tools = [{"type": "function", "function": d} for d in functions] model = llm.bind(tools=tools) - chain = prompt | model | PydanticToolsParser(tools=pydantic_schemas) - return chain + return prompt | model | PydanticToolsParser(tools=pydantic_schemas) diff --git a/libs/langchain/langchain/chains/query_constructor/base.py b/libs/langchain/langchain/chains/query_constructor/base.py index 5d3e57048d0..32718cfd2a7 100644 --- a/libs/langchain/langchain/chains/query_constructor/base.py +++ b/libs/langchain/langchain/chains/query_constructor/base.py @@ -91,13 +91,12 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): filter_directive = cast( Optional[FilterDirective], get_parser().parse(raw_filter) ) - fixed = fix_filter_directive( + return fix_filter_directive( filter_directive, allowed_comparators=allowed_comparators, allowed_operators=allowed_operators, allowed_attributes=allowed_attributes, ) - return fixed else: ast_parse = get_parser( @@ -131,13 +130,13 @@ def fix_filter_directive( ) or not filter: return filter - elif isinstance(filter, Comparison): + if isinstance(filter, Comparison): if allowed_comparators and filter.comparator not in allowed_comparators: return None if allowed_attributes and filter.attribute not in allowed_attributes: return None return filter - elif isinstance(filter, Operation): + if isinstance(filter, Operation): if allowed_operators and filter.operator not in allowed_operators: return None args = [ @@ -155,15 +154,13 @@ def fix_filter_directive( ] if not args: return None - elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR): + if len(args) == 1 and filter.operator in (Operator.AND, Operator.OR): return args[0] - else: - return Operation( - operator=filter.operator, - arguments=args, - ) - else: - return filter + return Operation( + operator=filter.operator, + arguments=args, + ) + return filter def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str: diff --git a/libs/langchain/langchain/chains/query_constructor/parser.py b/libs/langchain/langchain/chains/query_constructor/parser.py index f6002518de7..826835f7334 100644 --- a/libs/langchain/langchain/chains/query_constructor/parser.py +++ b/libs/langchain/langchain/chains/query_constructor/parser.py @@ -101,10 +101,9 @@ class QueryTransformer(Transformer): ) raise ValueError(msg) return Comparison(comparator=func, attribute=args[0], value=args[1]) - elif len(args) 
== 1 and func in (Operator.AND, Operator.OR): + if len(args) == 1 and func in (Operator.AND, Operator.OR): return args[0] - else: - return Operation(operator=func, arguments=args) + return Operation(operator=func, arguments=args) def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]: if func_name in set(Comparator): @@ -118,7 +117,7 @@ class QueryTransformer(Transformer): ) raise ValueError(msg) return Comparator(func_name) - elif func_name in set(Operator): + if func_name in set(Operator): if ( self.allowed_operators is not None and func_name not in self.allowed_operators @@ -129,12 +128,11 @@ class QueryTransformer(Transformer): ) raise ValueError(msg) return Operator(func_name) - else: - msg = ( - f"Received unrecognized function {func_name}. Valid functions are " - f"{list(Operator) + list(Comparator)}" - ) - raise ValueError(msg) + msg = ( + f"Received unrecognized function {func_name}. Valid functions are " + f"{list(Operator) + list(Comparator)}" + ) + raise ValueError(msg) def args(self, *items: Any) -> tuple: return items diff --git a/libs/langchain/langchain/chains/retrieval.py b/libs/langchain/langchain/chains/retrieval.py index b27036ac597..97098a2a917 100644 --- a/libs/langchain/langchain/chains/retrieval.py +++ b/libs/langchain/langchain/chains/retrieval.py @@ -60,10 +60,8 @@ def create_retrieval_chain( else: retrieval_docs = (lambda x: x["input"]) | retriever - retrieval_chain = ( + return ( RunnablePassthrough.assign( context=retrieval_docs.with_config(run_name="retrieve_documents"), ).assign(answer=combine_docs_chain) ).with_config(run_name="retrieval_chain") - - return retrieval_chain diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index 491891e44fc..fbce2eac11c 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -157,8 +157,7 @@ class BaseRetrievalQA(Chain): if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} - else: - return {self.output_key: answer} + return {self.output_key: answer} @abstractmethod async def _aget_docs( @@ -200,8 +199,7 @@ class BaseRetrievalQA(Chain): if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} - else: - return {self.output_key: answer} + return {self.output_key: answer} @deprecated( diff --git a/libs/langchain/langchain/chains/router/base.py b/libs/langchain/langchain/chains/router/base.py index 46e058b7693..2091cf5e5ee 100644 --- a/libs/langchain/langchain/chains/router/base.py +++ b/libs/langchain/langchain/chains/router/base.py @@ -97,15 +97,14 @@ class MultiRouteChain(Chain): ) if not route.destination: return self.default_chain(route.next_inputs, callbacks=callbacks) - elif route.destination in self.destination_chains: + if route.destination in self.destination_chains: return self.destination_chains[route.destination]( route.next_inputs, callbacks=callbacks ) - elif self.silent_errors: + if self.silent_errors: return self.default_chain(route.next_inputs, callbacks=callbacks) - else: - msg = f"Received invalid destination chain name '{route.destination}'" - raise ValueError(msg) + msg = f"Received invalid destination chain name '{route.destination}'" + raise ValueError(msg) async def _acall( self, @@ -123,14 +122,13 @@ class MultiRouteChain(Chain): return await self.default_chain.acall( route.next_inputs, callbacks=callbacks ) - elif route.destination in self.destination_chains: + if 
route.destination in self.destination_chains: return await self.destination_chains[route.destination].acall( route.next_inputs, callbacks=callbacks ) - elif self.silent_errors: + if self.silent_errors: return await self.default_chain.acall( route.next_inputs, callbacks=callbacks ) - else: - msg = f"Received invalid destination chain name '{route.destination}'" - raise ValueError(msg) + msg = f"Received invalid destination chain name '{route.destination}'" + raise ValueError(msg) diff --git a/libs/langchain/langchain/chains/router/llm_router.py b/libs/langchain/langchain/chains/router/llm_router.py index 2cf6a442516..2b74952d75b 100644 --- a/libs/langchain/langchain/chains/router/llm_router.py +++ b/libs/langchain/langchain/chains/router/llm_router.py @@ -136,11 +136,10 @@ class LLMRouterChain(RouterChain): callbacks = _run_manager.get_child() prediction = self.llm_chain.predict(callbacks=callbacks, **inputs) - output = cast( + return cast( dict[str, Any], self.llm_chain.prompt.output_parser.parse(prediction), ) - return output async def _acall( self, @@ -149,11 +148,10 @@ class LLMRouterChain(RouterChain): ) -> dict[str, Any]: _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() callbacks = _run_manager.get_child() - output = cast( + return cast( dict[str, Any], await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs), ) - return output @classmethod def from_llm( diff --git a/libs/langchain/langchain/chains/sql_database/query.py b/libs/langchain/langchain/chains/sql_database/query.py index ce7018ffa02..09e8e12319b 100644 --- a/libs/langchain/langchain/chains/sql_database/query.py +++ b/libs/langchain/langchain/chains/sql_database/query.py @@ -141,8 +141,7 @@ def create_sql_query_chain( f"{db.dialect}" ) raise ValueError(msg) - else: - table_info_kwargs["get_col_comments"] = True + table_info_kwargs["get_col_comments"] = True inputs = { "input": lambda x: x["question"] + "\nSQLQuery: ", diff --git a/libs/langchain/langchain/chains/structured_output/base.py b/libs/langchain/langchain/chains/structured_output/base.py index c4c431d6d15..8e64a268f92 100644 --- a/libs/langchain/langchain/chains/structured_output/base.py +++ b/libs/langchain/langchain/chains/structured_output/base.py @@ -143,8 +143,7 @@ def create_openai_fn_runnable( output_parser = output_parser or get_openai_output_parser(functions) if prompt: return prompt | llm.bind(**llm_kwargs_) | output_parser - else: - return llm.bind(**llm_kwargs_) | output_parser + return llm.bind(**llm_kwargs_) | output_parser @deprecated( @@ -413,7 +412,7 @@ def create_structured_output_runnable( first_tool_only=return_single, ) - elif mode == "openai-functions": + if mode == "openai-functions": return _create_openai_functions_structured_output_runnable( output_schema, llm, @@ -422,7 +421,7 @@ def create_structured_output_runnable( enforce_single_function_usage=force_function_usage, **kwargs, # llm-specific kwargs ) - elif mode == "openai-json": + if mode == "openai-json": if force_function_usage: msg = ( "enforce_single_function_usage is not supported for mode='openai-json'." @@ -431,12 +430,11 @@ def create_structured_output_runnable( return _create_openai_json_runnable( output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs ) - else: - msg = ( - f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', " - f"'openai-json'." - ) - raise ValueError(msg) + msg = ( + f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', " + f"'openai-json'." 
+ ) + raise ValueError(msg) def _create_openai_tools_runnable( @@ -460,8 +458,7 @@ def _create_openai_tools_runnable( ) if prompt: return prompt | llm.bind(**llm_kwargs) | output_parser - else: - return llm.bind(**llm_kwargs) | output_parser + return llm.bind(**llm_kwargs) | output_parser def _get_openai_tool_output_parser( @@ -535,8 +532,7 @@ def _create_openai_json_runnable( prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2)) return prompt | llm | output_parser - else: - return llm | output_parser + return llm | output_parser def _create_openai_functions_structured_output_runnable( diff --git a/libs/langchain/langchain/chains/transform.py b/libs/langchain/langchain/chains/transform.py index fae0b5a7bfd..004d4a43efc 100644 --- a/libs/langchain/langchain/chains/transform.py +++ b/libs/langchain/langchain/chains/transform.py @@ -77,9 +77,8 @@ class TransformChain(Chain): ) -> dict[str, Any]: if self.atransform_cb is not None: return await self.atransform_cb(inputs) - else: - self._log_once( - "TransformChain's atransform is not provided, falling" - " back to synchronous transform" - ) - return self.transform_cb(inputs) + self._log_once( + "TransformChain's atransform is not provided, falling" + " back to synchronous transform" + ) + return self.transform_cb(inputs) diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index ef27bcbf47e..dee5c5f7ff7 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -322,16 +322,15 @@ def init_chat_model( return _init_chat_model_helper( cast(str, model), model_provider=model_provider, **kwargs ) - else: - if model: - kwargs["model"] = model - if model_provider: - kwargs["model_provider"] = model_provider - return _ConfigurableModel( - default_config=kwargs, - config_prefix=config_prefix, - configurable_fields=configurable_fields, - ) + if model: + kwargs["model"] = model + if model_provider: + kwargs["model_provider"] = model_provider + return _ConfigurableModel( + default_config=kwargs, + config_prefix=config_prefix, + configurable_fields=configurable_fields, + ) def _init_chat_model_helper( @@ -343,42 +342,42 @@ def _init_chat_model_helper( from langchain_openai import ChatOpenAI return ChatOpenAI(model=model, **kwargs) - elif model_provider == "anthropic": + if model_provider == "anthropic": _check_pkg("langchain_anthropic") from langchain_anthropic import ChatAnthropic return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore] - elif model_provider == "azure_openai": + if model_provider == "azure_openai": _check_pkg("langchain_openai") from langchain_openai import AzureChatOpenAI return AzureChatOpenAI(model=model, **kwargs) - elif model_provider == "azure_ai": + if model_provider == "azure_ai": _check_pkg("langchain_azure_ai") from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel return AzureAIChatCompletionsModel(model=model, **kwargs) - elif model_provider == "cohere": + if model_provider == "cohere": _check_pkg("langchain_cohere") from langchain_cohere import ChatCohere return ChatCohere(model=model, **kwargs) - elif model_provider == "google_vertexai": + if model_provider == "google_vertexai": _check_pkg("langchain_google_vertexai") from langchain_google_vertexai import ChatVertexAI return ChatVertexAI(model=model, **kwargs) - elif model_provider == "google_genai": + if model_provider == "google_genai": _check_pkg("langchain_google_genai") from langchain_google_genai import 
ChatGoogleGenerativeAI return ChatGoogleGenerativeAI(model=model, **kwargs) - elif model_provider == "fireworks": + if model_provider == "fireworks": _check_pkg("langchain_fireworks") from langchain_fireworks import ChatFireworks return ChatFireworks(model=model, **kwargs) - elif model_provider == "ollama": + if model_provider == "ollama": try: _check_pkg("langchain_ollama") from langchain_ollama import ChatOllama @@ -393,74 +392,72 @@ def _init_chat_model_helper( _check_pkg("langchain_ollama") return ChatOllama(model=model, **kwargs) - elif model_provider == "together": + if model_provider == "together": _check_pkg("langchain_together") from langchain_together import ChatTogether return ChatTogether(model=model, **kwargs) - elif model_provider == "mistralai": + if model_provider == "mistralai": _check_pkg("langchain_mistralai") from langchain_mistralai import ChatMistralAI return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore] - elif model_provider == "huggingface": + if model_provider == "huggingface": _check_pkg("langchain_huggingface") from langchain_huggingface import ChatHuggingFace return ChatHuggingFace(model_id=model, **kwargs) - elif model_provider == "groq": + if model_provider == "groq": _check_pkg("langchain_groq") from langchain_groq import ChatGroq return ChatGroq(model=model, **kwargs) - elif model_provider == "bedrock": + if model_provider == "bedrock": _check_pkg("langchain_aws") from langchain_aws import ChatBedrock # TODO: update to use model= once ChatBedrock supports return ChatBedrock(model_id=model, **kwargs) - elif model_provider == "bedrock_converse": + if model_provider == "bedrock_converse": _check_pkg("langchain_aws") from langchain_aws import ChatBedrockConverse return ChatBedrockConverse(model=model, **kwargs) - elif model_provider == "google_anthropic_vertex": + if model_provider == "google_anthropic_vertex": _check_pkg("langchain_google_vertexai") from langchain_google_vertexai.model_garden import ChatAnthropicVertex return ChatAnthropicVertex(model=model, **kwargs) - elif model_provider == "deepseek": + if model_provider == "deepseek": _check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek") from langchain_deepseek import ChatDeepSeek return ChatDeepSeek(model=model, **kwargs) - elif model_provider == "nvidia": + if model_provider == "nvidia": _check_pkg("langchain_nvidia_ai_endpoints") from langchain_nvidia_ai_endpoints import ChatNVIDIA return ChatNVIDIA(model=model, **kwargs) - elif model_provider == "ibm": + if model_provider == "ibm": _check_pkg("langchain_ibm") from langchain_ibm import ChatWatsonx return ChatWatsonx(model_id=model, **kwargs) - elif model_provider == "xai": + if model_provider == "xai": _check_pkg("langchain_xai") from langchain_xai import ChatXAI return ChatXAI(model=model, **kwargs) - elif model_provider == "perplexity": + if model_provider == "perplexity": _check_pkg("langchain_perplexity") from langchain_perplexity import ChatPerplexity return ChatPerplexity(model=model, **kwargs) - else: - supported = ", ".join(_SUPPORTED_PROVIDERS) - msg = ( - f"Unsupported {model_provider=}.\n\nSupported model providers are: " - f"{supported}" - ) - raise ValueError(msg) + supported = ", ".join(_SUPPORTED_PROVIDERS) + msg = ( + f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}" + ) + raise ValueError(msg) _SUPPORTED_PROVIDERS = { @@ -490,26 +487,25 @@ _SUPPORTED_PROVIDERS = { def _attempt_infer_model_provider(model_name: str) -> Optional[str]: if any(model_name.startswith(pre) 
for pre in ("gpt-3", "gpt-4", "o1", "o3")): return "openai" - elif model_name.startswith("claude"): + if model_name.startswith("claude"): return "anthropic" - elif model_name.startswith("command"): + if model_name.startswith("command"): return "cohere" - elif model_name.startswith("accounts/fireworks"): + if model_name.startswith("accounts/fireworks"): return "fireworks" - elif model_name.startswith("gemini"): + if model_name.startswith("gemini"): return "google_vertexai" - elif model_name.startswith("amazon."): + if model_name.startswith("amazon."): return "bedrock" - elif model_name.startswith("mistral"): + if model_name.startswith("mistral"): return "mistralai" - elif model_name.startswith("deepseek"): + if model_name.startswith("deepseek"): return "deepseek" - elif model_name.startswith("grok"): + if model_name.startswith("grok"): return "xai" - elif model_name.startswith("sonar"): + if model_name.startswith("sonar"): return "perplexity" - else: - return None + return None def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]: @@ -595,14 +591,13 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): ) return queue - elif self._default_config and (model := self._model()) and hasattr(model, name): + if self._default_config and (model := self._model()) and hasattr(model, name): return getattr(model, name) - else: - msg = f"{name} is not a BaseChatModel attribute" - if self._default_config: - msg += " and is not implemented on the default model" - msg += "." - raise AttributeError(msg) + msg = f"{name} is not a BaseChatModel attribute" + if self._default_config: + msg += " and is not implemented on the default model" + msg += "." + raise AttributeError(msg) def _model(self, config: Optional[RunnableConfig] = None) -> Runnable: params = {**self._default_config, **self._model_params(config)} @@ -728,10 +723,9 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): ) # If multiple configs default to Runnable.batch which uses executor to invoke # in parallel. - else: - return super().batch( - inputs, config=config, return_exceptions=return_exceptions, **kwargs - ) + return super().batch( + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ) async def abatch( self, @@ -751,10 +745,9 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): ) # If multiple configs default to Runnable.batch which uses executor to invoke # in parallel. 
- else: - return await super().abatch( - inputs, config=config, return_exceptions=return_exceptions, **kwargs - ) + return await super().abatch( + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ) def batch_as_completed( self, diff --git a/libs/langchain/langchain/embeddings/base.py b/libs/langchain/langchain/embeddings/base.py index 308a0b50247..eea04a10f94 100644 --- a/libs/langchain/langchain/embeddings/base.py +++ b/libs/langchain/langchain/embeddings/base.py @@ -192,41 +192,40 @@ def init_embeddings( from langchain_openai import OpenAIEmbeddings return OpenAIEmbeddings(model=model_name, **kwargs) - elif provider == "azure_openai": + if provider == "azure_openai": from langchain_openai import AzureOpenAIEmbeddings return AzureOpenAIEmbeddings(model=model_name, **kwargs) - elif provider == "google_vertexai": + if provider == "google_vertexai": from langchain_google_vertexai import VertexAIEmbeddings return VertexAIEmbeddings(model=model_name, **kwargs) - elif provider == "bedrock": + if provider == "bedrock": from langchain_aws import BedrockEmbeddings return BedrockEmbeddings(model_id=model_name, **kwargs) - elif provider == "cohere": + if provider == "cohere": from langchain_cohere import CohereEmbeddings return CohereEmbeddings(model=model_name, **kwargs) - elif provider == "mistralai": + if provider == "mistralai": from langchain_mistralai import MistralAIEmbeddings return MistralAIEmbeddings(model=model_name, **kwargs) - elif provider == "huggingface": + if provider == "huggingface": from langchain_huggingface import HuggingFaceEmbeddings return HuggingFaceEmbeddings(model_name=model_name, **kwargs) - elif provider == "ollama": + if provider == "ollama": from langchain_ollama import OllamaEmbeddings return OllamaEmbeddings(model=model_name, **kwargs) - else: - msg = ( - f"Provider '{provider}' is not supported.\n" - f"Supported providers and their required packages:\n" - f"{_get_provider_list()}" - ) - raise ValueError(msg) + msg = ( + f"Provider '{provider}' is not supported.\n" + f"Supported providers and their required packages:\n" + f"{_get_provider_list()}" + ) + raise ValueError(msg) __all__ = [ diff --git a/libs/langchain/langchain/evaluation/comparison/eval_chain.py b/libs/langchain/langchain/evaluation/comparison/eval_chain.py index c50d22cb3bf..3802d915a5d 100644 --- a/libs/langchain/langchain/evaluation/comparison/eval_chain.py +++ b/libs/langchain/langchain/evaluation/comparison/eval_chain.py @@ -70,7 +70,7 @@ def resolve_pairwise_criteria( Criteria.DEPTH, ] return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria} - elif isinstance(criteria, Criteria): + if isinstance(criteria, Criteria): criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]} elif isinstance(criteria, str): if criteria in _SUPPORTED_CRITERIA: diff --git a/libs/langchain/langchain/evaluation/embedding_distance/base.py b/libs/langchain/langchain/evaluation/embedding_distance/base.py index c072168194d..ca374ba5648 100644 --- a/libs/langchain/langchain/evaluation/embedding_distance/base.py +++ b/libs/langchain/langchain/evaluation/embedding_distance/base.py @@ -186,9 +186,8 @@ class _EmbeddingDistanceChainMixin(Chain): } if metric in metrics: return metrics[metric] - else: - msg = f"Invalid metric: {metric}" - raise ValueError(msg) + msg = f"Invalid metric: {metric}" + raise ValueError(msg) @staticmethod def _cosine_distance(a: Any, b: Any) -> Any: diff --git a/libs/langchain/langchain/evaluation/loading.py b/libs/langchain/langchain/evaluation/loading.py index 
a328cada57b..25c0f67c208 100644 --- a/libs/langchain/langchain/evaluation/loading.py +++ b/libs/langchain/langchain/evaluation/loading.py @@ -162,8 +162,7 @@ def load_evaluator( ) raise ValueError(msg) from e return evaluator_cls.from_llm(llm=llm, **kwargs) - else: - return evaluator_cls(**kwargs) + return evaluator_cls(**kwargs) def load_evaluators( diff --git a/libs/langchain/langchain/evaluation/parsing/json_schema.py b/libs/langchain/langchain/evaluation/parsing/json_schema.py index d06f2806fdf..7d26d2343ab 100644 --- a/libs/langchain/langchain/evaluation/parsing/json_schema.py +++ b/libs/langchain/langchain/evaluation/parsing/json_schema.py @@ -70,7 +70,7 @@ class JsonSchemaEvaluator(StringEvaluator): def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]: if isinstance(node, str): return parse_json_markdown(node) - elif hasattr(node, "schema") and callable(getattr(node, "schema")): + if hasattr(node, "schema") and callable(getattr(node, "schema")): # Pydantic model return getattr(node, "schema")() return node diff --git a/libs/langchain/langchain/evaluation/qa/eval_chain.py b/libs/langchain/langchain/evaluation/qa/eval_chain.py index d0e85b28c2c..fcf8571fd24 100644 --- a/libs/langchain/langchain/evaluation/qa/eval_chain.py +++ b/libs/langchain/langchain/evaluation/qa/eval_chain.py @@ -24,7 +24,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]: if match: if match.group(1).upper() == "CORRECT": return "CORRECT", 1 - elif match.group(1).upper() == "INCORRECT": + if match.group(1).upper() == "INCORRECT": return "INCORRECT", 0 try: first_word = ( @@ -32,7 +32,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]: ) if first_word.upper() == "CORRECT": return "CORRECT", 1 - elif first_word.upper() == "INCORRECT": + if first_word.upper() == "INCORRECT": return "INCORRECT", 0 last_word = ( text.strip() @@ -41,7 +41,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]: ) if last_word.upper() == "CORRECT": return "CORRECT", 1 - elif last_word.upper() == "INCORRECT": + if last_word.upper() == "INCORRECT": return "INCORRECT", 0 except IndexError: pass diff --git a/libs/langchain/langchain/evaluation/schema.py b/libs/langchain/langchain/evaluation/schema.py index f9508716ad6..f83d8dbeedc 100644 --- a/libs/langchain/langchain/evaluation/schema.py +++ b/libs/langchain/langchain/evaluation/schema.py @@ -123,12 +123,12 @@ class _EvalArgsMixin: if self.requires_input and input is None: msg = f"{self.__class__.__name__} requires an input string." raise ValueError(msg) - elif input is not None and not self.requires_input: + if input is not None and not self.requires_input: warn(self._skip_input_warning) if self.requires_reference and reference is None: msg = f"{self.__class__.__name__} requires a reference string." 
raise ValueError(msg) - elif reference is not None and not self.requires_reference: + if reference is not None and not self.requires_reference: warn(self._skip_reference_warning) diff --git a/libs/langchain/langchain/evaluation/scoring/eval_chain.py b/libs/langchain/langchain/evaluation/scoring/eval_chain.py index afcc573b195..aa3d307fb59 100644 --- a/libs/langchain/langchain/evaluation/scoring/eval_chain.py +++ b/libs/langchain/langchain/evaluation/scoring/eval_chain.py @@ -70,7 +70,7 @@ def resolve_criteria( Criteria.DEPTH, ] return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria} - elif isinstance(criteria, Criteria): + if isinstance(criteria, Criteria): criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]} elif isinstance(criteria, str): if criteria in _SUPPORTED_CRITERIA: diff --git a/libs/langchain/langchain/evaluation/string_distance/base.py b/libs/langchain/langchain/evaluation/string_distance/base.py index a386b2653aa..db6ae3155a1 100644 --- a/libs/langchain/langchain/evaluation/string_distance/base.py +++ b/libs/langchain/langchain/evaluation/string_distance/base.py @@ -138,8 +138,7 @@ class _RapidFuzzChainMixin(Chain): module = module_map[distance] if normalize_score: return module.normalized_distance - else: - return module.distance + return module.distance @property def metric(self) -> Callable: diff --git a/libs/langchain/langchain/hub.py b/libs/langchain/langchain/hub.py index 9af458d5a82..40035385107 100644 --- a/libs/langchain/langchain/hub.py +++ b/libs/langchain/langchain/hub.py @@ -21,10 +21,9 @@ def _get_client( ls_client = LangSmithClient(api_url, api_key=api_key) if hasattr(ls_client, "push_prompt") and hasattr(ls_client, "pull_prompt"): return ls_client - else: - from langchainhub import Client as LangChainHubClient + from langchainhub import Client as LangChainHubClient - return LangChainHubClient(api_url, api_key=api_key) + return LangChainHubClient(api_url, api_key=api_key) except ImportError: try: from langchainhub import Client as LangChainHubClient @@ -82,14 +81,13 @@ def push( # Then it's langchainhub manifest_json = dumps(object) - message = client.push( + return client.push( repo_full_name, manifest_json, parent_commit_hash=parent_commit_hash, new_repo_is_public=new_repo_is_public, new_repo_description=new_repo_description, ) - return message def pull( @@ -113,8 +111,7 @@ def pull( # Then it's langsmith if hasattr(client, "pull_prompt"): - response = client.pull_prompt(owner_repo_commit, include_model=include_model) - return response + return client.pull_prompt(owner_repo_commit, include_model=include_model) # Then it's langchainhub if hasattr(client, "pull_repo"): diff --git a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py index 1666e48b4a6..30337d8f360 100644 --- a/libs/langchain/langchain/llms/__init__.py +++ b/libs/langchain/langchain/llms/__init__.py @@ -561,8 +561,7 @@ def __getattr__(name: str) -> Any: k: v() for k, v in get_type_to_cls_dict().items() } return type_to_cls_dict - else: - return getattr(llms, name) + return getattr(llms, name) __all__ = [ diff --git a/libs/langchain/langchain/memory/entity.py b/libs/langchain/langchain/memory/entity.py index 8c199ccac38..c52eb5e6b85 100644 --- a/libs/langchain/langchain/memory/entity.py +++ b/libs/langchain/langchain/memory/entity.py @@ -154,6 +154,7 @@ class UpstashRedisEntityStore(BaseEntityStore): logger.debug( f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}" ) + return None def delete(self, key: str) -> None: 
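The bare `return None` lines added to the entity stores above are RET503, "missing explicit return": once one branch of a function returns an expression, ruff requires every other exit to return explicitly. A hedged sketch, with a made-up Store class standing in for the Redis/SQLite stores — the early `return self.delete(key)` shape is an assumption about surrounding code the hunk does not show:

    from typing import Optional

    class Store:
        def __init__(self) -> None:
            self.data: dict[str, str] = {}

        def delete(self, key: str) -> None:
            self.data.pop(key, None)

        def set(self, key: str, value: Optional[str]) -> None:
            if not value:
                return self.delete(key)  # this branch returns an expression...
            self.data[key] = value
            return None  # ...so RET503 makes the fall-through exit explicit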
self.redis_client.delete(f"{self.full_key_prefix}:{key}") @@ -255,6 +256,7 @@ class RedisEntityStore(BaseEntityStore): logger.debug( f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}" ) + return None def delete(self, key: str) -> None: self.redis_client.delete(f"{self.full_key_prefix}:{key}") @@ -362,6 +364,7 @@ class SQLiteEntityStore(BaseEntityStore): f'"{self.full_table_name}" (key, value) VALUES (?, ?)' ) self._execute_query(query, (key, value)) + return None def delete(self, key: str) -> None: """Deletes a key-value pair, safely quoting the table name.""" diff --git a/libs/langchain/langchain/output_parsers/boolean.py b/libs/langchain/langchain/output_parsers/boolean.py index 740ed9c8979..1210f58ee12 100644 --- a/libs/langchain/langchain/output_parsers/boolean.py +++ b/libs/langchain/langchain/output_parsers/boolean.py @@ -34,7 +34,7 @@ class BooleanOutputParser(BaseOutputParser[bool]): ) raise ValueError(msg) return True - elif self.false_val.upper() in truthy: + if self.false_val.upper() in truthy: if self.true_val.upper() in truthy: msg = ( f"Ambiguous response. Both {self.true_val} and {self.false_val} " diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index a3f30ca00fa..e8f26cb6220 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -71,31 +71,30 @@ class OutputFixingParser(BaseOutputParser[T]): except OutputParserException as e: if retries == self.max_retries: raise e + retries += 1 + if self.legacy and hasattr(self.retry_chain, "run"): + completion = self.retry_chain.run( + instructions=self.parser.get_format_instructions(), + completion=completion, + error=repr(e), + ) else: - retries += 1 - if self.legacy and hasattr(self.retry_chain, "run"): - completion = self.retry_chain.run( - instructions=self.parser.get_format_instructions(), - completion=completion, - error=repr(e), + try: + completion = self.retry_chain.invoke( + { + "instructions": self.parser.get_format_instructions(), # noqa: E501 + "completion": completion, + "error": repr(e), + } + ) + except (NotImplementedError, AttributeError): + # Case: self.parser does not have get_format_instructions + completion = self.retry_chain.invoke( + { + "completion": completion, + "error": repr(e), + } ) - else: - try: - completion = self.retry_chain.invoke( - { - "instructions": self.parser.get_format_instructions(), # noqa: E501 - "completion": completion, - "error": repr(e), - } - ) - except (NotImplementedError, AttributeError): - # Case: self.parser does not have get_format_instructions - completion = self.retry_chain.invoke( - { - "completion": completion, - "error": repr(e), - } - ) msg = "Failed to parse" raise OutputParserException(msg) @@ -109,31 +108,30 @@ class OutputFixingParser(BaseOutputParser[T]): except OutputParserException as e: if retries == self.max_retries: raise e + retries += 1 + if self.legacy and hasattr(self.retry_chain, "arun"): + completion = await self.retry_chain.arun( + instructions=self.parser.get_format_instructions(), + completion=completion, + error=repr(e), + ) else: - retries += 1 - if self.legacy and hasattr(self.retry_chain, "arun"): - completion = await self.retry_chain.arun( - instructions=self.parser.get_format_instructions(), - completion=completion, - error=repr(e), + try: + completion = await self.retry_chain.ainvoke( + { + "instructions": self.parser.get_format_instructions(), # noqa: E501 + "completion": completion, + "error": repr(e), + } + 
) + except (NotImplementedError, AttributeError): + # Case: self.parser does not have get_format_instructions + completion = await self.retry_chain.ainvoke( + { + "completion": completion, + "error": repr(e), + } ) - else: - try: - completion = await self.retry_chain.ainvoke( - { - "instructions": self.parser.get_format_instructions(), # noqa: E501 - "completion": completion, - "error": repr(e), - } - ) - except (NotImplementedError, AttributeError): - # Case: self.parser does not have get_format_instructions - completion = await self.retry_chain.ainvoke( - { - "completion": completion, - "error": repr(e), - } - ) msg = "Failed to parse" raise OutputParserException(msg) diff --git a/libs/langchain/langchain/output_parsers/pandas_dataframe.py b/libs/langchain/langchain/output_parsers/pandas_dataframe.py index 3c0bc96c191..e1e94c338be 100644 --- a/libs/langchain/langchain/output_parsers/pandas_dataframe.py +++ b/libs/langchain/langchain/output_parsers/pandas_dataframe.py @@ -64,7 +64,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]): msg = f"Invalid array format in '{original_request_params}'. \ Please check the format instructions." raise OutputParserException(msg) - elif ( + if ( isinstance(parsed_array[0], int) and parsed_array[-1] > self.dataframe.index.max() ): diff --git a/libs/langchain/langchain/output_parsers/regex.py b/libs/langchain/langchain/output_parsers/regex.py index dad22bbd221..b3b33909a86 100644 --- a/libs/langchain/langchain/output_parsers/regex.py +++ b/libs/langchain/langchain/output_parsers/regex.py @@ -30,12 +30,10 @@ class RegexParser(BaseOutputParser[dict[str, str]]): match = re.search(self.regex, text) if match: return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)} - else: - if self.default_output_key is None: - msg = f"Could not parse output: {text}" - raise ValueError(msg) - else: - return { - key: text if key == self.default_output_key else "" - for key in self.output_keys - } + if self.default_output_key is None: + msg = f"Could not parse output: {text}" + raise ValueError(msg) + return { + key: text if key == self.default_output_key else "" + for key in self.output_keys + } diff --git a/libs/langchain/langchain/output_parsers/regex_dict.py b/libs/langchain/langchain/output_parsers/regex_dict.py index 20f01e6528a..b61e5149ac5 100644 --- a/libs/langchain/langchain/output_parsers/regex_dict.py +++ b/libs/langchain/langchain/output_parsers/regex_dict.py @@ -33,14 +33,11 @@ class RegexDictParser(BaseOutputParser[dict[str, str]]): {expected_format} on text {text}" ) raise ValueError(msg) - elif len(matches) > 1: + if len(matches) > 1: msg = f"Multiple matches found for output key: {output_key} with \ expected format {expected_format} on text {text}" raise ValueError(msg) - elif ( - self.no_update_value is not None and matches[0] == self.no_update_value - ): + if self.no_update_value is not None and matches[0] == self.no_update_value: continue - else: - result[output_key] = matches[0] + result[output_key] = matches[0] return result diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index dfc1000b513..f9eb15dc94f 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -107,20 +107,19 @@ class RetryOutputParser(BaseOutputParser[T]): except OutputParserException as e: if retries == self.max_retries: raise e + retries += 1 + if self.legacy and hasattr(self.retry_chain, "run"): + completion = 
self.retry_chain.run( + prompt=prompt_value.to_string(), + completion=completion, + ) else: - retries += 1 - if self.legacy and hasattr(self.retry_chain, "run"): - completion = self.retry_chain.run( - prompt=prompt_value.to_string(), - completion=completion, - ) - else: - completion = self.retry_chain.invoke( - { - "prompt": prompt_value.to_string(), - "completion": completion, - } - ) + completion = self.retry_chain.invoke( + { + "prompt": prompt_value.to_string(), + "completion": completion, + } + ) msg = "Failed to parse" raise OutputParserException(msg) @@ -143,21 +142,20 @@ class RetryOutputParser(BaseOutputParser[T]): except OutputParserException as e: if retries == self.max_retries: raise e + retries += 1 + if self.legacy and hasattr(self.retry_chain, "arun"): + completion = await self.retry_chain.arun( + prompt=prompt_value.to_string(), + completion=completion, + error=repr(e), + ) else: - retries += 1 - if self.legacy and hasattr(self.retry_chain, "arun"): - completion = await self.retry_chain.arun( - prompt=prompt_value.to_string(), - completion=completion, - error=repr(e), - ) - else: - completion = await self.retry_chain.ainvoke( - { - "prompt": prompt_value.to_string(), - "completion": completion, - } - ) + completion = await self.retry_chain.ainvoke( + { + "prompt": prompt_value.to_string(), + "completion": completion, + } + ) msg = "Failed to parse" raise OutputParserException(msg) @@ -234,22 +232,21 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): except OutputParserException as e: if retries == self.max_retries: raise e + retries += 1 + if self.legacy and hasattr(self.retry_chain, "run"): + completion = self.retry_chain.run( + prompt=prompt_value.to_string(), + completion=completion, + error=repr(e), + ) else: - retries += 1 - if self.legacy and hasattr(self.retry_chain, "run"): - completion = self.retry_chain.run( - prompt=prompt_value.to_string(), - completion=completion, - error=repr(e), - ) - else: - completion = self.retry_chain.invoke( - { - "completion": completion, - "prompt": prompt_value.to_string(), - "error": repr(e), - } - ) + completion = self.retry_chain.invoke( + { + "completion": completion, + "prompt": prompt_value.to_string(), + "error": repr(e), + } + ) msg = "Failed to parse" raise OutputParserException(msg) @@ -263,22 +260,21 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): except OutputParserException as e: if retries == self.max_retries: raise e + retries += 1 + if self.legacy and hasattr(self.retry_chain, "arun"): + completion = await self.retry_chain.arun( + prompt=prompt_value.to_string(), + completion=completion, + error=repr(e), + ) else: - retries += 1 - if self.legacy and hasattr(self.retry_chain, "arun"): - completion = await self.retry_chain.arun( - prompt=prompt_value.to_string(), - completion=completion, - error=repr(e), - ) - else: - completion = await self.retry_chain.ainvoke( - { - "prompt": prompt_value.to_string(), - "completion": completion, - "error": repr(e), - } - ) + completion = await self.retry_chain.ainvoke( + { + "prompt": prompt_value.to_string(), + "completion": completion, + "error": repr(e), + } + ) msg = "Failed to parse" raise OutputParserException(msg) diff --git a/libs/langchain/langchain/output_parsers/structured.py b/libs/langchain/langchain/output_parsers/structured.py index 181acd7b026..a1da7f4ff3e 100644 --- a/libs/langchain/langchain/output_parsers/structured.py +++ b/libs/langchain/langchain/output_parsers/structured.py @@ -89,8 +89,7 @@ class StructuredOutputParser(BaseOutputParser[dict[str, 
Any]]): ) if only_json: return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str) - else: - return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str) + return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str) def parse(self, text: str) -> dict[str, Any]: expected_keys = [rs.name for rs in self.response_schemas] diff --git a/libs/langchain/langchain/output_parsers/yaml.py b/libs/langchain/langchain/output_parsers/yaml.py index 8a6b4234a45..4d86599ee71 100644 --- a/libs/langchain/langchain/output_parsers/yaml.py +++ b/libs/langchain/langchain/output_parsers/yaml.py @@ -33,8 +33,7 @@ class YamlOutputParser(BaseOutputParser[T]): json_object = yaml.safe_load(yaml_str) if hasattr(self.pydantic_object, "model_validate"): return self.pydantic_object.model_validate(json_object) - else: - return self.pydantic_object.parse_obj(json_object) + return self.pydantic_object.parse_obj(json_object) except (yaml.YAMLError, ValidationError) as e: name = self.pydantic_object.__name__ diff --git a/libs/langchain/langchain/retrievers/contextual_compression.py b/libs/langchain/langchain/retrievers/contextual_compression.py index 6a4134160e9..313cc8880c9 100644 --- a/libs/langchain/langchain/retrievers/contextual_compression.py +++ b/libs/langchain/langchain/retrievers/contextual_compression.py @@ -45,8 +45,7 @@ class ContextualCompressionRetriever(BaseRetriever): docs, query, callbacks=run_manager.get_child() ) return list(compressed_docs) - else: - return [] + return [] async def _aget_relevant_documents( self, @@ -71,5 +70,4 @@ class ContextualCompressionRetriever(BaseRetriever): docs, query, callbacks=run_manager.get_child() ) return list(compressed_docs) - else: - return [] + return [] diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py index 1bb17e220ca..b6c2e7d16b9 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -174,9 +174,7 @@ class EnsembleRetriever(BaseRetriever): """ # Get fused result of the retrievers. - fused_documents = self.rank_fusion(query, run_manager) - - return fused_documents + return self.rank_fusion(query, run_manager) async def _aget_relevant_documents( self, @@ -195,9 +193,7 @@ class EnsembleRetriever(BaseRetriever): """ # Get fused result of the retrievers. 
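The pattern in structured.py, yaml.py, and contextual_compression.py above is RET505, "unnecessary `else` after `return`": when the `if` branch returns, the `else` wrapper adds a level of indentation without changing control flow. A minimal sketch under illustrative names:

    def first_word(text: str) -> str:
        if text:
            return text.split()[0]
        else:  # RET505: this 'else' is redundant after a return
            return ""

    def first_word_fixed(text: str) -> str:
        if text:
            return text.split()[0]
        return ""  # the auto-fix dedents the fallback

Applied repeatedly, this is what flattens the long `elif` provider ladders in chat_models/base.py and embeddings/base.py into sequences of early returns.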
- fused_documents = await self.arank_fusion(query, run_manager) - - return fused_documents + return await self.arank_fusion(query, run_manager) def rank_fusion( self, @@ -236,9 +232,7 @@ class EnsembleRetriever(BaseRetriever): ] # apply rank fusion - fused_documents = self.weighted_reciprocal_rank(retriever_docs) - - return fused_documents + return self.weighted_reciprocal_rank(retriever_docs) async def arank_fusion( self, @@ -280,9 +274,7 @@ class EnsembleRetriever(BaseRetriever): ] # apply rank fusion - fused_documents = self.weighted_reciprocal_rank(retriever_docs) - - return fused_documents + return self.weighted_reciprocal_rank(retriever_docs) def weighted_reciprocal_rank( self, doc_lists: list[list[Document]] @@ -318,7 +310,7 @@ class EnsembleRetriever(BaseRetriever): # Docs are deduplicated by their contents then sorted by their scores all_docs = chain.from_iterable(doc_lists) - sorted_docs = sorted( + return sorted( unique_by_key( all_docs, lambda doc: ( @@ -332,4 +324,3 @@ class EnsembleRetriever(BaseRetriever): doc.page_content if self.id_key is None else doc.metadata[self.id_key] ], ) - return sorted_docs diff --git a/libs/langchain/langchain/retrievers/merger_retriever.py b/libs/langchain/langchain/retrievers/merger_retriever.py index 5a192ef8e4c..ce7663c7c85 100644 --- a/libs/langchain/langchain/retrievers/merger_retriever.py +++ b/libs/langchain/langchain/retrievers/merger_retriever.py @@ -31,9 +31,7 @@ class MergerRetriever(BaseRetriever): """ # Merge the results of the retrievers. - merged_documents = self.merge_documents(query, run_manager) - - return merged_documents + return self.merge_documents(query, run_manager) async def _aget_relevant_documents( self, @@ -52,9 +50,7 @@ class MergerRetriever(BaseRetriever): """ # Merge the results of the retrievers. 
- merged_documents = await self.amerge_documents(query, run_manager) - - return merged_documents + return await self.amerge_documents(query, run_manager) def merge_documents( self, query: str, run_manager: CallbackManagerForRetrieverRun diff --git a/libs/langchain/langchain/retrievers/re_phraser.py b/libs/langchain/langchain/retrievers/re_phraser.py index 9f8fc6432cd..9de82e2e2ca 100644 --- a/libs/langchain/langchain/retrievers/re_phraser.py +++ b/libs/langchain/langchain/retrievers/re_phraser.py @@ -74,10 +74,9 @@ class RePhraseQueryRetriever(BaseRetriever): query, {"callbacks": run_manager.get_child()} ) logger.info(f"Re-phrased question: {re_phrased_question}") - docs = self.retriever.invoke( + return self.retriever.invoke( re_phrased_question, config={"callbacks": run_manager.get_child()} ) - return docs async def _aget_relevant_documents( self, diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index d5ac1cb38d6..257b4e80ff2 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -119,117 +119,116 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor: } if isinstance(vectorstore, DatabricksVectorSearch): return DatabricksVectorSearchTranslator() - elif isinstance(vectorstore, MyScale): + if isinstance(vectorstore, MyScale): return MyScaleTranslator(metadata_key=vectorstore.metadata_column) - elif isinstance(vectorstore, Redis): + if isinstance(vectorstore, Redis): return RedisTranslator.from_vectorstore(vectorstore) - elif isinstance(vectorstore, TencentVectorDB): + if isinstance(vectorstore, TencentVectorDB): fields = [ field.name for field in (vectorstore.meta_fields or []) if field.index ] return TencentVectorDBTranslator(fields) - elif vectorstore.__class__ in BUILTIN_TRANSLATORS: + if vectorstore.__class__ in BUILTIN_TRANSLATORS: return BUILTIN_TRANSLATORS[vectorstore.__class__]() + try: + from langchain_astradb.vectorstores import AstraDBVectorStore + except ImportError: + pass else: - try: - from langchain_astradb.vectorstores import AstraDBVectorStore - except ImportError: - pass - else: - if isinstance(vectorstore, AstraDBVectorStore): - return AstraDBTranslator() + if isinstance(vectorstore, AstraDBVectorStore): + return AstraDBTranslator() - try: - from langchain_elasticsearch.vectorstores import ElasticsearchStore - except ImportError: - pass - else: - if isinstance(vectorstore, ElasticsearchStore): - return ElasticsearchTranslator() + try: + from langchain_elasticsearch.vectorstores import ElasticsearchStore + except ImportError: + pass + else: + if isinstance(vectorstore, ElasticsearchStore): + return ElasticsearchTranslator() - try: - from langchain_pinecone import PineconeVectorStore - except ImportError: - pass - else: - if isinstance(vectorstore, PineconeVectorStore): - return PineconeTranslator() + try: + from langchain_pinecone import PineconeVectorStore + except ImportError: + pass + else: + if isinstance(vectorstore, PineconeVectorStore): + return PineconeTranslator() - try: - from langchain_milvus import Milvus - except ImportError: - pass - else: - if isinstance(vectorstore, Milvus): - return MilvusTranslator() + try: + from langchain_milvus import Milvus + except ImportError: + pass + else: + if isinstance(vectorstore, Milvus): + return MilvusTranslator() - try: - from langchain_mongodb import MongoDBAtlasVectorSearch - except ImportError: - pass - else: - if isinstance(vectorstore, 
MongoDBAtlasVectorSearch): - return MongoDBAtlasTranslator() + try: + from langchain_mongodb import MongoDBAtlasVectorSearch + except ImportError: + pass + else: + if isinstance(vectorstore, MongoDBAtlasVectorSearch): + return MongoDBAtlasTranslator() - try: - from langchain_neo4j import Neo4jVector - except ImportError: - pass - else: - if isinstance(vectorstore, Neo4jVector): - return Neo4jTranslator() + try: + from langchain_neo4j import Neo4jVector + except ImportError: + pass + else: + if isinstance(vectorstore, Neo4jVector): + return Neo4jTranslator() - try: - # Trying langchain_chroma import if exists - from langchain_chroma import Chroma - except ImportError: - pass - else: - if isinstance(vectorstore, Chroma): - return ChromaTranslator() + try: + # Trying langchain_chroma import if exists + from langchain_chroma import Chroma + except ImportError: + pass + else: + if isinstance(vectorstore, Chroma): + return ChromaTranslator() - try: - from langchain_postgres import PGVector - from langchain_postgres import PGVectorTranslator as NewPGVectorTranslator - except ImportError: - pass - else: - if isinstance(vectorstore, PGVector): - return NewPGVectorTranslator() + try: + from langchain_postgres import PGVector + from langchain_postgres import PGVectorTranslator as NewPGVectorTranslator + except ImportError: + pass + else: + if isinstance(vectorstore, PGVector): + return NewPGVectorTranslator() - try: - from langchain_qdrant import QdrantVectorStore - except ImportError: - pass - else: - if isinstance(vectorstore, QdrantVectorStore): - return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key) + try: + from langchain_qdrant import QdrantVectorStore + except ImportError: + pass + else: + if isinstance(vectorstore, QdrantVectorStore): + return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key) - try: - # Added in langchain-community==0.2.11 - from langchain_community.query_constructors.hanavector import HanaTranslator - from langchain_community.vectorstores import HanaDB - except ImportError: - pass - else: - if isinstance(vectorstore, HanaDB): - return HanaTranslator() + try: + # Added in langchain-community==0.2.11 + from langchain_community.query_constructors.hanavector import HanaTranslator + from langchain_community.vectorstores import HanaDB + except ImportError: + pass + else: + if isinstance(vectorstore, HanaDB): + return HanaTranslator() - try: - # Trying langchain_weaviate (weaviate v4) import if exists - from langchain_weaviate.vectorstores import WeaviateVectorStore + try: + # Trying langchain_weaviate (weaviate v4) import if exists + from langchain_weaviate.vectorstores import WeaviateVectorStore - except ImportError: - pass - else: - if isinstance(vectorstore, WeaviateVectorStore): - return WeaviateTranslator() + except ImportError: + pass + else: + if isinstance(vectorstore, WeaviateVectorStore): + return WeaviateTranslator() - msg = ( - f"Self query retriever with Vector Store type {vectorstore.__class__}" - f" not supported." - ) - raise ValueError(msg) + msg = ( + f"Self query retriever with Vector Store type {vectorstore.__class__}" + f" not supported." 
+ ) + raise ValueError(msg) class SelfQueryRetriever(BaseRetriever): @@ -289,14 +288,12 @@ class SelfQueryRetriever(BaseRetriever): def _get_docs_with_query( self, query: str, search_kwargs: dict[str, Any] ) -> list[Document]: - docs = self.vectorstore.search(query, self.search_type, **search_kwargs) - return docs + return self.vectorstore.search(query, self.search_type, **search_kwargs) async def _aget_docs_with_query( self, query: str, search_kwargs: dict[str, Any] ) -> list[Document]: - docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs) - return docs + return await self.vectorstore.asearch(query, self.search_type, **search_kwargs) def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun @@ -315,8 +312,7 @@ class SelfQueryRetriever(BaseRetriever): if self.verbose: logger.info(f"Generated Query: {structured_query}") new_query, search_kwargs = self._prepare_query(query, structured_query) - docs = self._get_docs_with_query(new_query, search_kwargs) - return docs + return self._get_docs_with_query(new_query, search_kwargs) async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun @@ -335,8 +331,7 @@ class SelfQueryRetriever(BaseRetriever): if self.verbose: logger.info(f"Generated Query: {structured_query}") new_query, search_kwargs = self._prepare_query(query, structured_query) - docs = await self._aget_docs_with_query(new_query, search_kwargs) - return docs + return await self._aget_docs_with_query(new_query, search_kwargs) @classmethod def from_llm( diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 0d828c9b5a6..10d8b796fa0 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -191,13 +191,13 @@ def _wrap_in_chain_factory( ) raise ValueError(msg) return lambda: chain - elif isinstance(llm_or_chain_factory, BaseLanguageModel): + if isinstance(llm_or_chain_factory, BaseLanguageModel): return llm_or_chain_factory - elif isinstance(llm_or_chain_factory, Runnable): + if isinstance(llm_or_chain_factory, Runnable): # Memory may exist here, but it's not elegant to check all those cases. lcf = llm_or_chain_factory return lambda: lcf - elif callable(llm_or_chain_factory): + if callable(llm_or_chain_factory): if is_traceable_function(llm_or_chain_factory): runnable_ = as_runnable(cast(Callable, llm_or_chain_factory)) return lambda: runnable_ @@ -215,15 +215,14 @@ def _wrap_in_chain_factory( # It's not uncommon to do an LLM constructor instead of raw LLM, # so we'll unpack it for the user. return _model - elif is_traceable_function(cast(Callable, _model)): + if is_traceable_function(cast(Callable, _model)): runnable_ = as_runnable(cast(Callable, _model)) return lambda: runnable_ - elif not isinstance(_model, Runnable): + if not isinstance(_model, Runnable): # This is unlikely to happen - a constructor for a model function return lambda: RunnableLambda(constructor) - else: - # Typical correct case - return constructor + # Typical correct case + return constructor return llm_or_chain_factory @@ -272,9 +271,8 @@ def _get_prompt(inputs: dict[str, Any]) -> str: raise InputFormatError(msg) if len(prompts) == 1: return prompts[0] - else: - msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts." - raise InputFormatError(msg) + msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts." 
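The `raise ValueError(msg)` rewrites above follow the sibling rule RET506, "unnecessary `else` after `raise`": a branch that raises never falls through, so the `else` around the remaining code can be dropped. A short sketch, names illustrative:

    def require_positive(n: int) -> int:
        if n <= 0:
            msg = f"expected a positive value, got {n}"
            raise ValueError(msg)
        return n  # previously wrapped in 'else:'; RET506 dedents it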
+ raise InputFormatError(msg) class ChatModelInput(TypedDict): @@ -321,12 +319,11 @@ def _get_messages(inputs: dict[str, Any]) -> dict: ) raise InputFormatError(msg) return input_copy - else: - msg = ( - f"Chat Run expects single List[dict] or List[List[dict]] 'messages'" - f" input. Got {inputs}" - ) - raise InputFormatError(msg) + msg = ( + f"Chat Run expects single List[dict] or List[List[dict]] 'messages'" + f" input. Got {inputs}" + ) + raise InputFormatError(msg) ## Shared data validation utilities @@ -707,31 +704,29 @@ async def _arun_llm( callbacks=callbacks, tags=tags or [], metadata=metadata or {} ), ) - else: - msg = ( - "Input mapper returned invalid format" - f" {prompt_or_messages}" - "\nExpected a single string or list of chat messages." - ) - raise InputFormatError(msg) + msg = ( + "Input mapper returned invalid format" + f" {prompt_or_messages}" + "\nExpected a single string or list of chat messages." + ) + raise InputFormatError(msg) - else: - try: - prompt = _get_prompt(inputs) - llm_output: Union[str, BaseMessage] = await llm.ainvoke( - prompt, - config=RunnableConfig( - callbacks=callbacks, tags=tags or [], metadata=metadata or {} - ), - ) - except InputFormatError: - llm_inputs = _get_messages(inputs) - llm_output = await llm.ainvoke( - **llm_inputs, - config=RunnableConfig( - callbacks=callbacks, tags=tags or [], metadata=metadata or {} - ), - ) + try: + prompt = _get_prompt(inputs) + llm_output: Union[str, BaseMessage] = await llm.ainvoke( + prompt, + config=RunnableConfig( + callbacks=callbacks, tags=tags or [], metadata=metadata or {} + ), + ) + except InputFormatError: + llm_inputs = _get_messages(inputs) + llm_output = await llm.ainvoke( + **llm_inputs, + config=RunnableConfig( + callbacks=callbacks, tags=tags or [], metadata=metadata or {} + ), + ) return llm_output diff --git a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py index 720f5271285..d08d029165e 100644 --- a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py +++ b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py @@ -28,8 +28,7 @@ def _get_messages_from_run_dict(messages: list[dict]) -> list[BaseMessage]: first_message = messages[0] if "lc" in first_message: return [load(dumpd(message)) for message in messages] - else: - return messages_from_dict(messages) + return messages_from_dict(messages) class StringRunMapper(Serializable): @@ -106,25 +105,23 @@ class LLMStringRunMapper(StringRunMapper): if run.run_type != "llm": msg = "LLM RunMapper only supports LLM runs." raise ValueError(msg) - elif not run.outputs: + if not run.outputs: if run.error: msg = f"Cannot evaluate errored LLM run {run.id}: {run.error}" raise ValueError(msg) - else: - msg = f"Run {run.id} has no outputs. Cannot evaluate this run." - raise ValueError(msg) - else: - try: - inputs = self.serialize_inputs(run.inputs) - except Exception as e: - msg = f"Could not parse LM input from run inputs {run.inputs}" - raise ValueError(msg) from e - try: - output_ = self.serialize_outputs(run.outputs) - except Exception as e: - msg = f"Could not parse LM prediction from run outputs {run.outputs}" - raise ValueError(msg) from e - return {"input": inputs, "prediction": output_} + msg = f"Run {run.id} has no outputs. Cannot evaluate this run." 
+            raise ValueError(msg)
+        try:
+            inputs = self.serialize_inputs(run.inputs)
+        except Exception as e:
+            msg = f"Could not parse LM input from run inputs {run.inputs}"
+            raise ValueError(msg) from e
+        try:
+            output_ = self.serialize_outputs(run.outputs)
+        except Exception as e:
+            msg = f"Could not parse LM prediction from run outputs {run.outputs}"
+            raise ValueError(msg) from e
+        return {"input": inputs, "prediction": output_}
 
 
 class ChainStringRunMapper(StringRunMapper):
@@ -142,14 +139,13 @@ class ChainStringRunMapper(StringRunMapper):
     def _get_key(self, source: dict, key: Optional[str], which: str) -> str:
         if key is not None:
             return source[key]
-        elif len(source) == 1:
+        if len(source) == 1:
             return next(iter(source.values()))
-        else:
-            msg = (
-                f"Could not map run {which} with multiple keys: "
-                f"{source}\nPlease manually specify a {which}_key"
-            )
-            raise ValueError(msg)
+        msg = (
+            f"Could not map run {which} with multiple keys: "
+            f"{source}\nPlease manually specify a {which}_key"
+        )
+        raise ValueError(msg)
 
     def map(self, run: Run) -> dict[str, str]:
         """Maps the Run to a dictionary."""
@@ -168,7 +164,7 @@ class ChainStringRunMapper(StringRunMapper):
                 f" '{self.input_key}'."
             )
             raise ValueError(msg)
-        elif self.prediction_key is not None and self.prediction_key not in run.outputs:
+        if self.prediction_key is not None and self.prediction_key not in run.outputs:
             available_keys = ", ".join(run.outputs.keys())
             msg = (
                 f"Run with ID {run.id} doesn't have the expected prediction key"
@@ -178,13 +174,12 @@ class ChainStringRunMapper(StringRunMapper):
             )
             raise ValueError(msg)
 
-        else:
-            input_ = self._get_key(run.inputs, self.input_key, "input")
-            prediction = self._get_key(run.outputs, self.prediction_key, "prediction")
-            return {
-                "input": input_,
-                "prediction": prediction,
-            }
+        input_ = self._get_key(run.inputs, self.input_key, "input")
+        prediction = self._get_key(run.outputs, self.prediction_key, "prediction")
+        return {
+            "input": input_,
+            "prediction": prediction,
+        }
 
 
 class ToolStringRunMapper(StringRunMapper):
@@ -224,8 +219,7 @@ class StringExampleMapper(Serializable):
                     " specify a reference_key."
                 )
                 raise ValueError(msg)
-            else:
-                output = list(example.outputs.values())[0]
+            output = list(example.outputs.values())[0]
         elif self.reference_key not in example.outputs:
             msg = (
                 f"Example {example.id} does not have reference key"
diff --git a/libs/langchain/langchain/tools/__init__.py b/libs/langchain/langchain/tools/__init__.py
index 1126065407f..55ea272212c 100644
--- a/libs/langchain/langchain/tools/__init__.py
+++ b/libs/langchain/langchain/tools/__init__.py
@@ -65,24 +65,23 @@ def _import_python_tool_PythonREPLTool() -> Any:
 def __getattr__(name: str) -> Any:
     if name == "PythonAstREPLTool":
         return _import_python_tool_PythonAstREPLTool()
-    elif name == "PythonREPLTool":
+    if name == "PythonREPLTool":
         return _import_python_tool_PythonREPLTool()
-    else:
-        from langchain_community import tools
+    from langchain_community import tools
 
-        # If not in interactive env, raise warning.
-        if not is_interactive_env():
-            warnings.warn(
-                "Importing tools from langchain is deprecated. Importing from "
-                "langchain will no longer be supported as of langchain==0.2.0. "
-                "Please import from langchain-community instead:\n\n"
-                f"`from langchain_community.tools import {name}`.\n\n"
-                "To install langchain-community run "
-                "`pip install -U langchain-community`.",
-                category=LangChainDeprecationWarning,
-            )
+    # If not in interactive env, raise warning.
+    if not is_interactive_env():
+        warnings.warn(
+            "Importing tools from langchain is deprecated. Importing from "
+            "langchain will no longer be supported as of langchain==0.2.0. "
+            "Please import from langchain-community instead:\n\n"
+            f"`from langchain_community.tools import {name}`.\n\n"
+            "To install langchain-community run "
+            "`pip install -U langchain-community`.",
+            category=LangChainDeprecationWarning,
+        )
 
-        return getattr(tools, name)
+    return getattr(tools, name)
 
 
 __all__ = [
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index aff79ed3f81..9b1ca739161 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
 ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
 
 [tool.ruff.lint]
-select = ["A", "C4", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
+select = ["A", "C4", "D", "E", "EM", "F", "I", "PGH003", "PIE", "RET", "S", "SIM", "T201", "UP", "W"]
 pydocstyle.convention = "google"
 pyupgrade.keep-runtime-typing = true
diff --git a/libs/langchain/tests/mock_servers/robot/server.py b/libs/langchain/tests/mock_servers/robot/server.py
index 3ef8d2c54e2..959efe6c9af 100644
--- a/libs/langchain/tests/mock_servers/robot/server.py
+++ b/libs/langchain/tests/mock_servers/robot/server.py
@@ -139,8 +139,7 @@ async def get_state(
 async def ask_for_passphrase(said_please: bool) -> dict[str, Any]:
     if said_please:
         return {"passphrase": f"The passphrase is {PASS_PHRASE}"}
-    else:
-        return {"passphrase": "I won't share the passphrase without saying 'please'."}
+    return {"passphrase": "I won't share the passphrase without saying 'please'."}
 
 
 @app.delete(
@@ -153,12 +152,11 @@ async def recycle(password: SecretPassPhrase) -> dict[str, Any]:
     if password.pw == PASS_PHRASE:
         _ROBOT_STATE["destruct"] = True
         return {"status": "Self-destruct initiated", "state": _ROBOT_STATE}
-    else:
-        _ROBOT_STATE["destruct"] = False
-        raise HTTPException(
-            status_code=400,
-            detail="Pass phrase required. You should have thought to ask for it.",
-        )
+    _ROBOT_STATE["destruct"] = False
+    raise HTTPException(
+        status_code=400,
+        detail="Pass phrase required. You should have thought to ask for it.",
+    )
 
 
 @app.post(
diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py
index 4d9f168317c..53ada9bd702 100644
--- a/libs/langchain/tests/unit_tests/agents/test_agent.py
+++ b/libs/langchain/tests/unit_tests/agents/test_agent.py
@@ -100,14 +100,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
         ),
     ]
 
-    agent = initialize_agent(
+    return initialize_agent(
         tools,
         fake_llm,
         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
         verbose=True,
         **kwargs,
     )
-    return agent
 
 
 def test_agent_bad_action() -> None:
diff --git a/libs/langchain/tests/unit_tests/agents/test_agent_async.py b/libs/langchain/tests/unit_tests/agents/test_agent_async.py
index 42c2b39f327..62c1039ed7b 100644
--- a/libs/langchain/tests/unit_tests/agents/test_agent_async.py
+++ b/libs/langchain/tests/unit_tests/agents/test_agent_async.py
@@ -71,14 +71,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
         ),
     ]
 
-    agent = initialize_agent(
+    return initialize_agent(
         tools,
         fake_llm,
         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
         verbose=True,
         **kwargs,
     )
-    return agent
 
 
 async def test_agent_bad_action() -> None:
diff --git a/libs/langchain/tests/unit_tests/agents/test_chat.py b/libs/langchain/tests/unit_tests/agents/test_chat.py
index dbc703dd8bd..28b91bb3816 100644
--- a/libs/langchain/tests/unit_tests/agents/test_chat.py
+++ b/libs/langchain/tests/unit_tests/agents/test_chat.py
@@ -11,8 +11,7 @@ def get_action_and_input(text: str) -> tuple[str, str]:
     output = output_parser.parse(text)
     if isinstance(output, AgentAction):
         return output.tool, str(output.tool_input)
-    else:
-        return "Final Answer", output.return_values["output"]
+    return "Final Answer", output.return_values["output"]
 
 
 def test_parse_with_language() -> None:
diff --git a/libs/langchain/tests/unit_tests/agents/test_mrkl.py b/libs/langchain/tests/unit_tests/agents/test_mrkl.py
index cb4af994acb..a464b1eea92 100644
--- a/libs/langchain/tests/unit_tests/agents/test_mrkl.py
+++ b/libs/langchain/tests/unit_tests/agents/test_mrkl.py
@@ -16,8 +16,7 @@ def get_action_and_input(text: str) -> tuple[str, str]:
     output = MRKLOutputParser().parse(text)
     if isinstance(output, AgentAction):
         return output.tool, str(output.tool_input)
-    else:
-        return "Final Answer", output.return_values["output"]
+    return "Final Answer", output.return_values["output"]
 
 
 def test_get_action_and_input() -> None:
diff --git a/libs/langchain/tests/unit_tests/agents/test_structured_chat.py b/libs/langchain/tests/unit_tests/agents/test_structured_chat.py
index 2c583bc978a..61bb6cdee4c 100644
--- a/libs/langchain/tests/unit_tests/agents/test_structured_chat.py
+++ b/libs/langchain/tests/unit_tests/agents/test_structured_chat.py
@@ -21,11 +21,10 @@ def get_action_and_input(text: str) -> tuple[str, str]:
     output = output_parser.parse(text)
     if isinstance(output, AgentAction):
         return output.tool, str(output.tool_input)
-    elif isinstance(output, AgentFinish):
+    if isinstance(output, AgentFinish):
        return output.return_values["output"], output.log
-    else:
-        msg = "Unexpected output type"
-        raise ValueError(msg)
+    msg = "Unexpected output type"
+    raise ValueError(msg)
 
 
 def test_parse_with_language() -> None:
diff --git a/libs/langchain/tests/unit_tests/chains/test_base.py b/libs/langchain/tests/unit_tests/chains/test_base.py
index 0059528ed9f..231df806f9e 100644
--- a/libs/langchain/tests/unit_tests/chains/test_base.py
+++ b/libs/langchain/tests/unit_tests/chains/test_base.py
@@ -58,8 +58,7 @@ class FakeChain(Chain):
     ) -> dict[str, str]:
         if self.be_correct:
             return {"bar": "baz"}
-        else:
-            return {"baz": "bar"}
+        return {"baz": "bar"}
 
 
 def test_bad_inputs() -> None:
diff --git a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py
index 85f3c869bd7..7395036f8fe 100644
--- a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py
+++ b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py
@@ -54,9 +54,8 @@ class _FakeTrajectoryChatModel(FakeChatModel):
             response = self.queries[list(self.queries.keys())[self.response_index]]
             self.response_index = self.response_index + 1
             return response
-        else:
-            prompt = messages[0].content
-            return self.queries[prompt]
+        prompt = messages[0].content
+        return self.queries[prompt]
 
 
 def test_trajectory_output_parser_parse() -> None:
diff --git a/libs/langchain/tests/unit_tests/llms/fake_llm.py b/libs/langchain/tests/unit_tests/llms/fake_llm.py
index 6b4ec98b71e..a558d0faa69 100644
--- a/libs/langchain/tests/unit_tests/llms/fake_llm.py
+++ b/libs/langchain/tests/unit_tests/llms/fake_llm.py
@@ -45,8 +45,7 @@ class FakeLLM(LLM):
             return self.queries[prompt]
         if stop is None:
             return "foo"
-        else:
-            return "bar"
+        return "bar"
 
     @property
     def _identifying_params(self) -> dict[str, Any]:
diff --git a/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py
index fb1722cdc85..6452caff9f4 100644
--- a/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py
+++ b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py
@@ -14,9 +14,8 @@ class SequentialRetriever(BaseRetriever):
     ) -> list[Document]:
         if self.response_index >= len(self.sequential_responses):
             return []
-        else:
-            self.response_index += 1
-            return self.sequential_responses[self.response_index - 1]
+        self.response_index += 1
+        return self.sequential_responses[self.response_index - 1]
 
     async def _aget_relevant_documents(  # type: ignore[override]
         self,
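
Note for reviewers: the hunks above are mechanical applications of ruff's flake8-return (RET) auto-fixes, chiefly RET505 (superfluous-else-return: an `else` branch after a `return` is dropped and its body dedented) and RET504 (unnecessary-assign: a value assigned and immediately returned is returned directly). The sketch below is illustrative only; the function and variable names (`lookup`, `normalize`, `table`) are hypothetical and do not come from this diff.

# Hypothetical examples of the two rewrites applied throughout this patch.

def lookup(table: dict[str, str], key: str) -> str:
    # Before the fix this read:
    #     if key in table:
    #         return table[key]
    #     else:
    #         return "<missing>"
    # RET505 removes the superfluous `else` and dedents its body.
    if key in table:
        return table[key]
    return "<missing>"


def normalize(text: str) -> str:
    # Before the fix this read:
    #     result = text.strip().lower()
    #     return result
    # RET504 drops the intermediate assignment and returns the expression.
    return text.strip().lower()

The rule family is enabled by the one-line addition of "RET" to the `select` list under `[tool.ruff.lint]` in the pyproject.toml hunk above; running `ruff check --fix` then performs rewrites of this shape automatically, which is why the hunks in this patch are behavior-preserving.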