langchain: Add ruff rule RET (#31875)

All auto-fixes
See https://docs.astral.sh/ruff/rules/#flake8-return-ret

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Christophe Bornet 2025-07-07 17:33:18 +02:00 committed by GitHub
parent fceebbb387
commit 56bbfd9723
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
91 changed files with 663 additions and 835 deletions

View File

@ -48,25 +48,25 @@ def __getattr__(name: str) -> Any:
_warn_on_import(name, replacement="langchain.agents.MRKLChain")
return MRKLChain
elif name == "ReActChain":
if name == "ReActChain":
from langchain.agents import ReActChain
_warn_on_import(name, replacement="langchain.agents.ReActChain")
return ReActChain
elif name == "SelfAskWithSearchChain":
if name == "SelfAskWithSearchChain":
from langchain.agents import SelfAskWithSearchChain
_warn_on_import(name, replacement="langchain.agents.SelfAskWithSearchChain")
return SelfAskWithSearchChain
elif name == "ConversationChain":
if name == "ConversationChain":
from langchain.chains import ConversationChain
_warn_on_import(name, replacement="langchain.chains.ConversationChain")
return ConversationChain
elif name == "LLMBashChain":
if name == "LLMBashChain":
msg = (
"This module has been moved to langchain-experimental. "
"For more details: "
@ -77,97 +77,97 @@ def __getattr__(name: str) -> Any:
)
raise ImportError(msg)
elif name == "LLMChain":
if name == "LLMChain":
from langchain.chains import LLMChain
_warn_on_import(name, replacement="langchain.chains.LLMChain")
return LLMChain
elif name == "LLMCheckerChain":
if name == "LLMCheckerChain":
from langchain.chains import LLMCheckerChain
_warn_on_import(name, replacement="langchain.chains.LLMCheckerChain")
return LLMCheckerChain
elif name == "LLMMathChain":
if name == "LLMMathChain":
from langchain.chains import LLMMathChain
_warn_on_import(name, replacement="langchain.chains.LLMMathChain")
return LLMMathChain
elif name == "QAWithSourcesChain":
if name == "QAWithSourcesChain":
from langchain.chains import QAWithSourcesChain
_warn_on_import(name, replacement="langchain.chains.QAWithSourcesChain")
return QAWithSourcesChain
elif name == "VectorDBQA":
if name == "VectorDBQA":
from langchain.chains import VectorDBQA
_warn_on_import(name, replacement="langchain.chains.VectorDBQA")
return VectorDBQA
elif name == "VectorDBQAWithSourcesChain":
if name == "VectorDBQAWithSourcesChain":
from langchain.chains import VectorDBQAWithSourcesChain
_warn_on_import(name, replacement="langchain.chains.VectorDBQAWithSourcesChain")
return VectorDBQAWithSourcesChain
elif name == "InMemoryDocstore":
if name == "InMemoryDocstore":
from langchain_community.docstore import InMemoryDocstore
_warn_on_import(name, replacement="langchain.docstore.InMemoryDocstore")
return InMemoryDocstore
elif name == "Wikipedia":
if name == "Wikipedia":
from langchain_community.docstore import Wikipedia
_warn_on_import(name, replacement="langchain.docstore.Wikipedia")
return Wikipedia
elif name == "Anthropic":
if name == "Anthropic":
from langchain_community.llms import Anthropic
_warn_on_import(name, replacement="langchain_community.llms.Anthropic")
return Anthropic
elif name == "Banana":
if name == "Banana":
from langchain_community.llms import Banana
_warn_on_import(name, replacement="langchain_community.llms.Banana")
return Banana
elif name == "CerebriumAI":
if name == "CerebriumAI":
from langchain_community.llms import CerebriumAI
_warn_on_import(name, replacement="langchain_community.llms.CerebriumAI")
return CerebriumAI
elif name == "Cohere":
if name == "Cohere":
from langchain_community.llms import Cohere
_warn_on_import(name, replacement="langchain_community.llms.Cohere")
return Cohere
elif name == "ForefrontAI":
if name == "ForefrontAI":
from langchain_community.llms import ForefrontAI
_warn_on_import(name, replacement="langchain_community.llms.ForefrontAI")
return ForefrontAI
elif name == "GooseAI":
if name == "GooseAI":
from langchain_community.llms import GooseAI
_warn_on_import(name, replacement="langchain_community.llms.GooseAI")
return GooseAI
elif name == "HuggingFaceHub":
if name == "HuggingFaceHub":
from langchain_community.llms import HuggingFaceHub
_warn_on_import(name, replacement="langchain_community.llms.HuggingFaceHub")
return HuggingFaceHub
elif name == "HuggingFaceTextGenInference":
if name == "HuggingFaceTextGenInference":
from langchain_community.llms import HuggingFaceTextGenInference
_warn_on_import(
@ -175,55 +175,55 @@ def __getattr__(name: str) -> Any:
)
return HuggingFaceTextGenInference
elif name == "LlamaCpp":
if name == "LlamaCpp":
from langchain_community.llms import LlamaCpp
_warn_on_import(name, replacement="langchain_community.llms.LlamaCpp")
return LlamaCpp
elif name == "Modal":
if name == "Modal":
from langchain_community.llms import Modal
_warn_on_import(name, replacement="langchain_community.llms.Modal")
return Modal
elif name == "OpenAI":
if name == "OpenAI":
from langchain_community.llms import OpenAI
_warn_on_import(name, replacement="langchain_community.llms.OpenAI")
return OpenAI
elif name == "Petals":
if name == "Petals":
from langchain_community.llms import Petals
_warn_on_import(name, replacement="langchain_community.llms.Petals")
return Petals
elif name == "PipelineAI":
if name == "PipelineAI":
from langchain_community.llms import PipelineAI
_warn_on_import(name, replacement="langchain_community.llms.PipelineAI")
return PipelineAI
elif name == "SagemakerEndpoint":
if name == "SagemakerEndpoint":
from langchain_community.llms import SagemakerEndpoint
_warn_on_import(name, replacement="langchain_community.llms.SagemakerEndpoint")
return SagemakerEndpoint
elif name == "StochasticAI":
if name == "StochasticAI":
from langchain_community.llms import StochasticAI
_warn_on_import(name, replacement="langchain_community.llms.StochasticAI")
return StochasticAI
elif name == "Writer":
if name == "Writer":
from langchain_community.llms import Writer
_warn_on_import(name, replacement="langchain_community.llms.Writer")
return Writer
elif name == "HuggingFacePipeline":
if name == "HuggingFacePipeline":
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
_warn_on_import(
@ -232,7 +232,7 @@ def __getattr__(name: str) -> Any:
)
return HuggingFacePipeline
elif name == "FewShotPromptTemplate":
if name == "FewShotPromptTemplate":
from langchain_core.prompts import FewShotPromptTemplate
_warn_on_import(
@ -240,7 +240,7 @@ def __getattr__(name: str) -> Any:
)
return FewShotPromptTemplate
elif name == "Prompt":
if name == "Prompt":
from langchain_core.prompts import PromptTemplate
_warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")
@ -248,19 +248,19 @@ def __getattr__(name: str) -> Any:
# it's renamed as prompt template anyways
# this is just for backwards compat
return PromptTemplate
elif name == "PromptTemplate":
if name == "PromptTemplate":
from langchain_core.prompts import PromptTemplate
_warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")
return PromptTemplate
elif name == "BasePromptTemplate":
if name == "BasePromptTemplate":
from langchain_core.prompts import BasePromptTemplate
_warn_on_import(name, replacement="langchain_core.prompts.BasePromptTemplate")
return BasePromptTemplate
elif name == "ArxivAPIWrapper":
if name == "ArxivAPIWrapper":
from langchain_community.utilities import ArxivAPIWrapper
_warn_on_import(
@ -268,7 +268,7 @@ def __getattr__(name: str) -> Any:
)
return ArxivAPIWrapper
elif name == "GoldenQueryAPIWrapper":
if name == "GoldenQueryAPIWrapper":
from langchain_community.utilities import GoldenQueryAPIWrapper
_warn_on_import(
@ -276,7 +276,7 @@ def __getattr__(name: str) -> Any:
)
return GoldenQueryAPIWrapper
elif name == "GoogleSearchAPIWrapper":
if name == "GoogleSearchAPIWrapper":
from langchain_community.utilities import GoogleSearchAPIWrapper
_warn_on_import(
@ -284,7 +284,7 @@ def __getattr__(name: str) -> Any:
)
return GoogleSearchAPIWrapper
elif name == "GoogleSerperAPIWrapper":
if name == "GoogleSerperAPIWrapper":
from langchain_community.utilities import GoogleSerperAPIWrapper
_warn_on_import(
@ -292,7 +292,7 @@ def __getattr__(name: str) -> Any:
)
return GoogleSerperAPIWrapper
elif name == "PowerBIDataset":
if name == "PowerBIDataset":
from langchain_community.utilities import PowerBIDataset
_warn_on_import(
@ -300,7 +300,7 @@ def __getattr__(name: str) -> Any:
)
return PowerBIDataset
elif name == "SearxSearchWrapper":
if name == "SearxSearchWrapper":
from langchain_community.utilities import SearxSearchWrapper
_warn_on_import(
@ -308,7 +308,7 @@ def __getattr__(name: str) -> Any:
)
return SearxSearchWrapper
elif name == "WikipediaAPIWrapper":
if name == "WikipediaAPIWrapper":
from langchain_community.utilities import WikipediaAPIWrapper
_warn_on_import(
@ -316,7 +316,7 @@ def __getattr__(name: str) -> Any:
)
return WikipediaAPIWrapper
elif name == "WolframAlphaAPIWrapper":
if name == "WolframAlphaAPIWrapper":
from langchain_community.utilities import WolframAlphaAPIWrapper
_warn_on_import(
@ -324,19 +324,19 @@ def __getattr__(name: str) -> Any:
)
return WolframAlphaAPIWrapper
elif name == "SQLDatabase":
if name == "SQLDatabase":
from langchain_community.utilities import SQLDatabase
_warn_on_import(name, replacement="langchain_community.utilities.SQLDatabase")
return SQLDatabase
elif name == "FAISS":
if name == "FAISS":
from langchain_community.vectorstores import FAISS
_warn_on_import(name, replacement="langchain_community.vectorstores.FAISS")
return FAISS
elif name == "ElasticVectorSearch":
if name == "ElasticVectorSearch":
from langchain_community.vectorstores import ElasticVectorSearch
_warn_on_import(
@ -345,7 +345,7 @@ def __getattr__(name: str) -> Any:
return ElasticVectorSearch
# For backwards compatibility
elif name == "SerpAPIChain" or name == "SerpAPIWrapper":
if name == "SerpAPIChain" or name == "SerpAPIWrapper":
from langchain_community.utilities import SerpAPIWrapper
_warn_on_import(
@ -353,7 +353,7 @@ def __getattr__(name: str) -> Any:
)
return SerpAPIWrapper
elif name == "verbose":
if name == "verbose":
from langchain.globals import _verbose
_warn_on_import(
@ -364,7 +364,7 @@ def __getattr__(name: str) -> Any:
)
return _verbose
elif name == "debug":
if name == "debug":
from langchain.globals import _debug
_warn_on_import(
@ -375,7 +375,7 @@ def __getattr__(name: str) -> Any:
)
return _debug
elif name == "llm_cache":
if name == "llm_cache":
from langchain.globals import _llm_cache
_warn_on_import(
@ -386,7 +386,6 @@ def __getattr__(name: str) -> Any:
)
return _llm_cache
else:
msg = f"Could not find: {name}"
raise AttributeError(msg)

View File

@ -137,7 +137,6 @@ class BaseSingleActionAgent(BaseModel):
return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
else:
msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
raise ValueError(msg)
@ -308,7 +307,6 @@ class BaseMultiActionAgent(BaseModel):
if early_stopping_method == "force":
# `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
else:
msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
raise ValueError(msg)
@ -815,8 +813,7 @@ class Agent(BaseSingleActionAgent):
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
agent_output = await self.output_parser.aparse(full_output)
return agent_output
return await self.output_parser.aparse(full_output)
def get_full_inputs(
self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any
@ -833,8 +830,7 @@ class Agent(BaseSingleActionAgent):
"""
thoughts = self._construct_scratchpad(intermediate_steps)
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
full_inputs = {**kwargs, **new_inputs}
return full_inputs
return {**kwargs, **new_inputs}
@property
def input_keys(self) -> list[str]:
@ -970,7 +966,7 @@ class Agent(BaseSingleActionAgent):
return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
elif early_stopping_method == "generate":
if early_stopping_method == "generate":
# Generate does one final forward pass
thoughts = ""
for action, observation in intermediate_steps:
@ -990,11 +986,9 @@ class Agent(BaseSingleActionAgent):
if isinstance(parsed_output, AgentFinish):
# If we can extract, we send the correct stuff
return parsed_output
else:
# If we can extract, but the tool is not the final tool,
# we just return the full output
return AgentFinish({"output": full_output}, full_output)
else:
msg = (
"early_stopping_method should be one of `force` or `generate`, "
f"got {early_stopping_method}"
@ -1179,7 +1173,6 @@ class AgentExecutor(Chain):
"""
if isinstance(self.agent, Runnable):
return cast(RunnableAgentType, self.agent)
else:
return self.agent
def save(self, file_path: Union[Path, str]) -> None:
@ -1249,7 +1242,6 @@ class AgentExecutor(Chain):
"""
if self.return_intermediate_steps:
return self._action_agent.return_values + ["intermediate_steps"]
else:
return self._action_agent.return_values
def lookup_tool(self, name: str) -> BaseTool:
@ -1304,10 +1296,7 @@ class AgentExecutor(Chain):
msg = "Expected a single AgentFinish output, but got multiple values."
raise ValueError(msg)
return values[-1]
else:
return [
(a.action, a.observation) for a in values if isinstance(a, AgentStep)
]
return [(a.action, a.observation) for a in values if isinstance(a, AgentStep)]
def _take_next_step(
self,
@ -1727,9 +1716,8 @@ class AgentExecutor(Chain):
and self.trim_intermediate_steps > 0
):
return intermediate_steps[-self.trim_intermediate_steps :]
elif callable(self.trim_intermediate_steps):
if callable(self.trim_intermediate_steps):
return self.trim_intermediate_steps(intermediate_steps)
else:
return intermediate_steps
@override

View File

@ -61,7 +61,6 @@ class ChatAgent(Agent):
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
@classmethod

View File

@ -39,11 +39,9 @@ class ConvoOutputParser(AgentOutputParser):
# If the action indicates a final answer, return an AgentFinish
if action == "Final Answer":
return AgentFinish({"output": action_input}, text)
else:
# Otherwise, return an AgentAction with the specified action and
# input
return AgentAction(action, action_input, text)
else:
# If the necessary keys aren't present in the response, raise an
# exception
msg = f"Missing 'action' or 'action_input' in LLM output: {text}"

View File

@ -23,7 +23,6 @@ def _convert_agent_action_to_messages(
return list(agent_action.message_log) + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]

View File

@ -182,7 +182,7 @@ def create_json_chat_agent(
else:
llm_to_use = llm
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_messages(
x["intermediate_steps"], template_tool_response=template_tool_response
@ -192,4 +192,3 @@ def create_json_chat_agent(
| llm_to_use
| JSONAgentOutputParser()
)
return agent

View File

@ -55,7 +55,6 @@ class MRKLOutputParser(AgentOutputParser):
return AgentFinish(
{"output": text[start_index:end_index].strip()}, text[:end_index]
)
else:
msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
raise OutputParserException(msg)
@ -69,7 +68,7 @@ class MRKLOutputParser(AgentOutputParser):
return AgentAction(action, tool_input, text)
elif includes_answer:
if includes_answer:
return AgentFinish(
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
)
@ -82,7 +81,7 @@ class MRKLOutputParser(AgentOutputParser):
llm_output=text,
send_to_llm=True,
)
elif not re.search(
if not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
):
msg = f"Could not parse LLM output: `{text}`"
@ -92,7 +91,6 @@ class MRKLOutputParser(AgentOutputParser):
llm_output=text,
send_to_llm=True,
)
else:
msg = f"Could not parse LLM output: `{text}`"
raise OutputParserException(msg)

View File

@ -128,7 +128,6 @@ def _get_assistants_tool(
"""
if _is_assistants_builtin_tool(tool):
return tool # type: ignore[return-value]
else:
return convert_to_openai_tool(tool)
@ -510,12 +509,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
for action, output in intermediate_steps
if action.tool_call_id in required_tool_call_ids
]
submit_tool_outputs = {
return {
"tool_outputs": tool_outputs,
"run_id": last_action.run_id,
"thread_id": last_action.thread_id,
}
return submit_tool_outputs
def _create_run(self, input_dict: dict) -> Any:
params = {
@ -558,12 +556,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
"run_metadata",
)
}
run = self.client.beta.threads.create_and_run(
return self.client.beta.threads.create_and_run(
assistant_id=self.assistant_id,
thread=thread,
**params,
)
return run
def _get_response(self, run: Any) -> Any:
# TODO: Pagination
@ -612,7 +609,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
run_id=run.id,
thread_id=run.thread_id,
)
elif run.status == "requires_action":
if run.status == "requires_action":
if not self.as_agent:
return run.required_action.submit_tool_outputs.tool_calls
actions = []
@ -639,7 +636,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
)
return actions
else:
run_info = json.dumps(run.dict(), indent=2)
msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
raise ValueError(msg)
@ -668,12 +664,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
for action, output in intermediate_steps
if action.tool_call_id in required_tool_call_ids
]
submit_tool_outputs = {
return {
"tool_outputs": tool_outputs,
"run_id": last_action.run_id,
"thread_id": last_action.thread_id,
}
return submit_tool_outputs
async def _acreate_run(self, input_dict: dict) -> Any:
params = {
@ -716,12 +711,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
"run_metadata",
)
}
run = await self.async_client.beta.threads.create_and_run(
return await self.async_client.beta.threads.create_and_run(
assistant_id=self.assistant_id,
thread=thread,
**params,
)
return run
async def _aget_response(self, run: Any) -> Any:
# TODO: Pagination
@ -766,7 +760,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
run_id=run.id,
thread_id=run.thread_id,
)
elif run.status == "requires_action":
if run.status == "requires_action":
if not self.as_agent:
return run.required_action.submit_tool_outputs.tool_calls
actions = []
@ -793,7 +787,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
)
return actions
else:
run_info = json.dumps(run.dict(), indent=2)
msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
raise ValueError(msg)

View File

@ -132,8 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
messages,
callbacks=callbacks,
)
agent_decision = self.output_parser._parse_ai_message(predicted_message)
return agent_decision
return self.output_parser._parse_ai_message(predicted_message)
async def aplan(
self,
@ -164,8 +163,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = self.output_parser._parse_ai_message(predicted_message)
return agent_decision
return self.output_parser._parse_ai_message(predicted_message)
def return_stopped_response(
self,
@ -192,17 +190,15 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
elif early_stopping_method == "generate":
if early_stopping_method == "generate":
# Generate does one final forward pass
agent_decision = self.plan(
intermediate_steps, with_functions=False, **kwargs
)
if isinstance(agent_decision, AgentFinish):
return agent_decision
else:
msg = f"got AgentAction with no functions provided: {agent_decision}"
raise ValueError(msg)
else:
msg = (
"early_stopping_method should be one of `force` or `generate`, "
f"got {early_stopping_method}"
@ -358,7 +354,7 @@ def create_openai_functions_agent(
)
raise ValueError(msg)
llm_with_tools = llm.bind(functions=[convert_to_openai_function(t) for t in tools])
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_openai_function_messages(
x["intermediate_steps"]
@ -368,4 +364,3 @@ def create_openai_functions_agent(
| llm_with_tools
| OpenAIFunctionsAgentOutputParser()
)
return agent

View File

@ -224,8 +224,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
return _parse_ai_message(predicted_message)
async def aplan(
self,
@ -254,8 +253,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
return _parse_ai_message(predicted_message)
@classmethod
def create_prompt(

View File

@ -96,7 +96,7 @@ def create_openai_tools_agent(
tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools]
)
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_openai_tool_messages(
x["intermediate_steps"]
@ -106,4 +106,3 @@ def create_openai_tools_agent(
| llm_with_tools
| OpenAIToolsAgentOutputParser()
)
return agent

View File

@ -49,7 +49,6 @@ class JSONAgentOutputParser(AgentOutputParser):
response = response[0]
if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text)
else:
action_input = response.get("action_input", {})
if action_input is None:
action_input = {}

View File

@ -65,7 +65,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):
return AgentAction(action, tool_input, text)
elif includes_answer:
if includes_answer:
return AgentFinish(
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
)
@ -78,7 +78,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):
llm_output=text,
send_to_llm=True,
)
elif not re.search(
if not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
):
msg = f"Could not parse LLM output: `{text}`"
@ -88,7 +88,6 @@ class ReActSingleInputOutputParser(AgentOutputParser):
llm_output=text,
send_to_llm=True,
)
else:
msg = f"Could not parse LLM output: `{text}`"
raise OutputParserException(msg)

View File

@ -36,12 +36,11 @@ class XMLAgentOutputParser(AgentOutputParser):
if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0]
return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
elif "<final_answer>" in text:
if "<final_answer>" in text:
_, answer = text.split("<final_answer>")
if "</final_answer>" in answer:
answer = answer.split("</final_answer>")[0]
return AgentFinish(return_values={"output": answer}, log=text)
else:
raise ValueError
def get_format_instructions(self) -> str:

View File

@ -134,7 +134,7 @@ def create_react_agent(
else:
llm_with_stop = llm
output_parser = output_parser or ReActSingleInputOutputParser()
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
)
@ -142,4 +142,3 @@ def create_react_agent(
| llm_with_stop
| output_parser
)
return agent

View File

@ -96,7 +96,6 @@ class DocstoreExplorer:
if isinstance(result, Document):
self.document = result
return self._summary
else:
self.document = None
return result
@ -113,9 +112,8 @@ class DocstoreExplorer:
lookups = [p for p in self._paragraphs if self.lookup_str in p.lower()]
if len(lookups) == 0:
return "No Results"
elif self.lookup_index >= len(lookups):
if self.lookup_index >= len(lookups):
return "No More Results"
else:
result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
return f"{result_prefix} {lookups[self.lookup_index]}"

View File

@ -26,7 +26,6 @@ class ReActOutputParser(AgentOutputParser):
action, action_input = re_matches.group(1), re_matches.group(2)
if action == "Finish":
return AgentFinish({"output": action_input}, text)
else:
return AgentAction(action, action_input, text)
@property

View File

@ -195,7 +195,7 @@ def create_self_ask_with_search_agent(
raise ValueError(msg)
llm_with_stop = llm.bind(stop=["\nIntermediate answer:"])
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(
x["intermediate_steps"],
@ -209,4 +209,3 @@ def create_self_ask_with_search_agent(
| llm_with_stop
| SelfAskOutputParser()
)
return agent

View File

@ -62,7 +62,6 @@ class StructuredChatAgent(Agent):
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
@classmethod
@ -292,7 +291,7 @@ def create_structured_chat_agent(
else:
llm_with_stop = llm
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
)
@ -300,4 +299,3 @@ def create_structured_chat_agent(
| llm_with_stop
| JSONAgentOutputParser()
)
return agent

View File

@ -42,11 +42,9 @@ class StructuredChatOutputParser(AgentOutputParser):
response = response[0]
if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text)
else:
return AgentAction(
response["action"], response.get("action_input", {}), text
)
else:
return AgentFinish({"output": text}, text)
except Exception as e:
msg = f"Could not parse LLM output: {text}"
@ -93,9 +91,8 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
llm=llm, parser=base_parser
)
return cls(output_fixing_parser=output_fixing_parser)
elif base_parser is not None:
if base_parser is not None:
return cls(base_parser=base_parser)
else:
return cls()
@property

View File

@ -100,7 +100,7 @@ def create_tool_calling_agent(
)
llm_with_tools = llm.bind_tools(tools)
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"])
)
@ -108,4 +108,3 @@ def create_tool_calling_agent(
| llm_with_tools
| ToolsAgentOutputParser()
)
return agent

View File

@ -221,7 +221,7 @@ def create_xml_agent(
else:
llm_with_stop = llm
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_xml(x["intermediate_steps"]),
)
@ -229,4 +229,3 @@ def create_xml_agent(
| llm_with_stop
| XMLAgentOutputParser()
)
return agent

View File

@ -24,7 +24,6 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def check_if_answer_reached(self) -> bool:
if self.strip_tokens:
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
else:
return self.last_tokens == self.answer_prefix_tokens
def __init__(

View File

@ -25,7 +25,6 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def check_if_answer_reached(self) -> bool:
if self.strip_tokens:
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
else:
return self.last_tokens == self.answer_prefix_tokens
def __init__(

View File

@ -261,7 +261,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@ -474,7 +473,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
async def aprep_outputs(
@ -500,7 +498,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
await self.memory.asave_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
@ -628,7 +625,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
" but none were provided."
)
raise ValueError(msg)
else:
msg = (
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
@ -687,7 +683,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
f"one output key. Got {self.output_keys}."
)
raise ValueError(msg)
elif args and not kwargs:
if args and not kwargs:
if len(args) != 1:
msg = "`run` supports only one positional argument."
raise ValueError(msg)

View File

@ -208,9 +208,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
if self.reduce_documents_chain.collapse_documents_chain:
return self.reduce_documents_chain.collapse_documents_chain
else:
return self.reduce_documents_chain.combine_documents_chain
else:
msg = (
f"`reduce_documents_chain` is of type "
f"{type(self.reduce_documents_chain)} so it does not have "
@ -223,7 +221,6 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
"""Kept for backward compatibility."""
if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
return self.reduce_documents_chain.combine_documents_chain
else:
msg = (
f"`reduce_documents_chain` is of type "
f"{type(self.reduce_documents_chain)} so it does not have "

View File

@ -225,7 +225,6 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
def _collapse_chain(self) -> BaseCombineDocumentsChain:
if self.collapse_documents_chain is not None:
return self.collapse_documents_chain
else:
return self.combine_documents_chain
def combine_docs(

View File

@ -222,8 +222,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
base_inputs: dict = {
self.document_variable_name: self.document_prompt.format(**document_info)
}
inputs = {**base_inputs, **kwargs}
return inputs
return {**base_inputs, **kwargs}
@property
def _chain_type(self) -> str:

View File

@ -201,7 +201,6 @@ class ConstitutionalChain(Chain):
) -> list[ConstitutionalPrinciple]:
if names is None:
return list(PRINCIPLES.values())
else:
return [PRINCIPLES[name] for name in names]
@classmethod

View File

@ -80,7 +80,6 @@ class ElasticsearchDatabaseChain(Chain):
"""
if not self.return_intermediate_steps:
return [self.output_key]
else:
return [self.output_key, INTERMEDIATE_STEPS_KEY]
def _list_indices(self) -> list[str]:

View File

@ -47,7 +47,6 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
"""Output keys for Hyde's LLM chain."""
if isinstance(self.llm_chain, LLMChain):
return self.llm_chain.output_keys
else:
return ["text"]
def embed_documents(self, texts: list[str]) -> list[list[float]]:

View File

@ -116,7 +116,6 @@ class LLMChain(Chain):
"""
if self.return_final_only:
return [self.output_key]
else:
return [self.output_key, "full_generation"]
def _call(
@ -142,7 +141,6 @@ class LLMChain(Chain):
callbacks=callbacks,
**self.llm_kwargs,
)
else:
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
cast(list, prompts), {"callbacks": callbacks}
)
@ -169,7 +167,6 @@ class LLMChain(Chain):
callbacks=callbacks,
**self.llm_kwargs,
)
else:
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
cast(list, prompts), {"callbacks": callbacks}
)
@ -344,7 +341,6 @@ class LLMChain(Chain):
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
async def apredict_and_parse(
@ -358,7 +354,6 @@ class LLMChain(Chain):
result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
@ -380,7 +375,6 @@ class LLMChain(Chain):
self.prompt.output_parser.parse(res[self.output_key])
for res in generation
]
else:
return generation
async def aapply_and_parse(
@ -411,13 +405,12 @@ class LLMChain(Chain):
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
if isinstance(llm_like, BaseLanguageModel):
return llm_like
elif isinstance(llm_like, RunnableBinding):
if isinstance(llm_like, RunnableBinding):
return _get_language_model(llm_like.bound)
elif isinstance(llm_like, RunnableWithFallbacks):
if isinstance(llm_like, RunnableWithFallbacks):
return _get_language_model(llm_like.runnable)
elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
if isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
return _get_language_model(llm_like.default)
else:
msg = (
f"Unable to extract BaseLanguageModel from llm_like object of type "
f"{type(llm_like)}"

View File

@ -55,13 +55,12 @@ def _load_question_to_checked_assertions_chain(
check_assertions_chain,
revised_answer_chain,
]
question_to_checked_assertions_chain = SequentialChain(
return SequentialChain(
chains=chains, # type: ignore[arg-type]
input_variables=["question"],
output_variables=["revised_statement"],
verbose=True,
)
return question_to_checked_assertions_chain
@deprecated(

View File

@ -32,7 +32,7 @@ def _load_sequential_chain(
are_all_true_prompt: PromptTemplate,
verbose: bool = False,
) -> SequentialChain:
chain = SequentialChain(
return SequentialChain(
chains=[
LLMChain(
llm=llm,
@ -63,7 +63,6 @@ def _load_sequential_chain(
output_variables=["all_true", "revised_summary"],
verbose=verbose,
)
return chain
@deprecated(

View File

@ -311,7 +311,6 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
prompt = load_prompt(config.pop("prompt_path"))
if llm_chain:
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
else:
return LLMMathChain(llm=llm, prompt=prompt, **config)
@ -609,7 +608,6 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
return LLMRequestsChain(
llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config
)
else:
return LLMRequestsChain(llm_chain=llm_chain, **config)

View File

@ -100,7 +100,6 @@ class OpenAIModerationChain(Chain):
error_str = "Text was found that violates OpenAI's content policy."
if self.error:
raise ValueError(error_str)
else:
return error_str
return text

View File

@ -132,7 +132,7 @@ def create_openai_fn_chain(
}
if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
llm_chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
output_parser=output_parser,
@ -140,7 +140,6 @@ def create_openai_fn_chain(
output_key=output_key,
**kwargs,
)
return llm_chain
@deprecated(

View File

@ -149,10 +149,9 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
]
prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
)
return chain

View File

@ -103,7 +103,7 @@ def create_extraction_chain(
extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=extraction_prompt,
llm_kwargs=llm_kwargs,
@ -111,7 +111,6 @@ def create_extraction_chain(
tags=tags,
verbose=verbose,
)
return chain
@deprecated(
@ -187,11 +186,10 @@ def create_extraction_chain_pydantic(
pydantic_schema=PydanticSchema, attr_name="info"
)
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=extraction_prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
verbose=verbose,
)
return chain

View File

@ -100,14 +100,13 @@ def create_qa_with_structure_chain(
]
prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=_output_parser,
verbose=verbose,
)
return chain
@deprecated(

View File

@ -91,14 +91,13 @@ def create_tagging_chain(
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = JsonOutputFunctionsParser()
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
**kwargs,
)
return chain
@deprecated(
@ -164,11 +163,10 @@ def create_tagging_chain_pydantic(
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
**kwargs,
)
return chain

View File

@ -76,5 +76,4 @@ def create_extraction_chain_pydantic(
functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
tools = [{"type": "function", "function": d} for d in functions]
model = llm.bind(tools=tools)
chain = prompt | model | PydanticToolsParser(tools=pydantic_schemas)
return chain
return prompt | model | PydanticToolsParser(tools=pydantic_schemas)

View File

@ -91,13 +91,12 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
filter_directive = cast(
Optional[FilterDirective], get_parser().parse(raw_filter)
)
fixed = fix_filter_directive(
return fix_filter_directive(
filter_directive,
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
allowed_attributes=allowed_attributes,
)
return fixed
else:
ast_parse = get_parser(
@ -131,13 +130,13 @@ def fix_filter_directive(
) or not filter:
return filter
elif isinstance(filter, Comparison):
if isinstance(filter, Comparison):
if allowed_comparators and filter.comparator not in allowed_comparators:
return None
if allowed_attributes and filter.attribute not in allowed_attributes:
return None
return filter
elif isinstance(filter, Operation):
if isinstance(filter, Operation):
if allowed_operators and filter.operator not in allowed_operators:
return None
args = [
@ -155,14 +154,12 @@ def fix_filter_directive(
]
if not args:
return None
elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
if len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
return args[0]
else:
return Operation(
operator=filter.operator,
arguments=args,
)
else:
return filter

View File

@ -101,9 +101,8 @@ class QueryTransformer(Transformer):
)
raise ValueError(msg)
return Comparison(comparator=func, attribute=args[0], value=args[1])
elif len(args) == 1 and func in (Operator.AND, Operator.OR):
if len(args) == 1 and func in (Operator.AND, Operator.OR):
return args[0]
else:
return Operation(operator=func, arguments=args)
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
@ -118,7 +117,7 @@ class QueryTransformer(Transformer):
)
raise ValueError(msg)
return Comparator(func_name)
elif func_name in set(Operator):
if func_name in set(Operator):
if (
self.allowed_operators is not None
and func_name not in self.allowed_operators
@ -129,7 +128,6 @@ class QueryTransformer(Transformer):
)
raise ValueError(msg)
return Operator(func_name)
else:
msg = (
f"Received unrecognized function {func_name}. Valid functions are "
f"{list(Operator) + list(Comparator)}"

View File

@ -60,10 +60,8 @@ def create_retrieval_chain(
else:
retrieval_docs = (lambda x: x["input"]) | retriever
retrieval_chain = (
return (
RunnablePassthrough.assign(
context=retrieval_docs.with_config(run_name="retrieve_documents"),
).assign(answer=combine_docs_chain)
).with_config(run_name="retrieval_chain")
return retrieval_chain

View File

@ -157,7 +157,6 @@ class BaseRetrievalQA(Chain):
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
@abstractmethod
@ -200,7 +199,6 @@ class BaseRetrievalQA(Chain):
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}

View File

@ -97,13 +97,12 @@ class MultiRouteChain(Chain):
)
if not route.destination:
return self.default_chain(route.next_inputs, callbacks=callbacks)
elif route.destination in self.destination_chains:
if route.destination in self.destination_chains:
return self.destination_chains[route.destination](
route.next_inputs, callbacks=callbacks
)
elif self.silent_errors:
if self.silent_errors:
return self.default_chain(route.next_inputs, callbacks=callbacks)
else:
msg = f"Received invalid destination chain name '{route.destination}'"
raise ValueError(msg)
@ -123,14 +122,13 @@ class MultiRouteChain(Chain):
return await self.default_chain.acall(
route.next_inputs, callbacks=callbacks
)
elif route.destination in self.destination_chains:
if route.destination in self.destination_chains:
return await self.destination_chains[route.destination].acall(
route.next_inputs, callbacks=callbacks
)
elif self.silent_errors:
if self.silent_errors:
return await self.default_chain.acall(
route.next_inputs, callbacks=callbacks
)
else:
msg = f"Received invalid destination chain name '{route.destination}'"
raise ValueError(msg)

View File

@ -136,11 +136,10 @@ class LLMRouterChain(RouterChain):
callbacks = _run_manager.get_child()
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
output = cast(
return cast(
dict[str, Any],
self.llm_chain.prompt.output_parser.parse(prediction),
)
return output
async def _acall(
self,
@ -149,11 +148,10 @@ class LLMRouterChain(RouterChain):
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
output = cast(
return cast(
dict[str, Any],
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
)
return output
@classmethod
def from_llm(

View File

@ -141,7 +141,6 @@ def create_sql_query_chain(
f"{db.dialect}"
)
raise ValueError(msg)
else:
table_info_kwargs["get_col_comments"] = True
inputs = {

View File

@ -143,7 +143,6 @@ def create_openai_fn_runnable(
output_parser = output_parser or get_openai_output_parser(functions)
if prompt:
return prompt | llm.bind(**llm_kwargs_) | output_parser
else:
return llm.bind(**llm_kwargs_) | output_parser
@ -413,7 +412,7 @@ def create_structured_output_runnable(
first_tool_only=return_single,
)
elif mode == "openai-functions":
if mode == "openai-functions":
return _create_openai_functions_structured_output_runnable(
output_schema,
llm,
@ -422,7 +421,7 @@ def create_structured_output_runnable(
enforce_single_function_usage=force_function_usage,
**kwargs, # llm-specific kwargs
)
elif mode == "openai-json":
if mode == "openai-json":
if force_function_usage:
msg = (
"enforce_single_function_usage is not supported for mode='openai-json'."
@ -431,7 +430,6 @@ def create_structured_output_runnable(
return _create_openai_json_runnable(
output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs
)
else:
msg = (
f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', "
f"'openai-json'."
@ -460,7 +458,6 @@ def _create_openai_tools_runnable(
)
if prompt:
return prompt | llm.bind(**llm_kwargs) | output_parser
else:
return llm.bind(**llm_kwargs) | output_parser
@ -535,7 +532,6 @@ def _create_openai_json_runnable(
prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2))
return prompt | llm | output_parser
else:
return llm | output_parser

View File

@ -77,7 +77,6 @@ class TransformChain(Chain):
) -> dict[str, Any]:
if self.atransform_cb is not None:
return await self.atransform_cb(inputs)
else:
self._log_once(
"TransformChain's atransform is not provided, falling"
" back to synchronous transform"

View File

@ -322,7 +322,6 @@ def init_chat_model(
return _init_chat_model_helper(
cast(str, model), model_provider=model_provider, **kwargs
)
else:
if model:
kwargs["model"] = model
if model_provider:
@ -343,42 +342,42 @@ def _init_chat_model_helper(
from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model, **kwargs)
elif model_provider == "anthropic":
if model_provider == "anthropic":
_check_pkg("langchain_anthropic")
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
elif model_provider == "azure_openai":
if model_provider == "azure_openai":
_check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(model=model, **kwargs)
elif model_provider == "azure_ai":
if model_provider == "azure_ai":
_check_pkg("langchain_azure_ai")
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
return AzureAIChatCompletionsModel(model=model, **kwargs)
elif model_provider == "cohere":
if model_provider == "cohere":
_check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere
return ChatCohere(model=model, **kwargs)
elif model_provider == "google_vertexai":
if model_provider == "google_vertexai":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI
return ChatVertexAI(model=model, **kwargs)
elif model_provider == "google_genai":
if model_provider == "google_genai":
_check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(model=model, **kwargs)
elif model_provider == "fireworks":
if model_provider == "fireworks":
_check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks
return ChatFireworks(model=model, **kwargs)
elif model_provider == "ollama":
if model_provider == "ollama":
try:
_check_pkg("langchain_ollama")
from langchain_ollama import ChatOllama
@ -393,72 +392,70 @@ def _init_chat_model_helper(
_check_pkg("langchain_ollama")
return ChatOllama(model=model, **kwargs)
elif model_provider == "together":
if model_provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
return ChatTogether(model=model, **kwargs)
elif model_provider == "mistralai":
if model_provider == "mistralai":
_check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
elif model_provider == "huggingface":
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
return ChatHuggingFace(model_id=model, **kwargs)
elif model_provider == "groq":
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
return ChatGroq(model=model, **kwargs)
elif model_provider == "bedrock":
if model_provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
# TODO: update to use model= once ChatBedrock supports
return ChatBedrock(model_id=model, **kwargs)
elif model_provider == "bedrock_converse":
if model_provider == "bedrock_converse":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrockConverse
return ChatBedrockConverse(model=model, **kwargs)
elif model_provider == "google_anthropic_vertex":
if model_provider == "google_anthropic_vertex":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
return ChatAnthropicVertex(model=model, **kwargs)
elif model_provider == "deepseek":
if model_provider == "deepseek":
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(model=model, **kwargs)
elif model_provider == "nvidia":
if model_provider == "nvidia":
_check_pkg("langchain_nvidia_ai_endpoints")
from langchain_nvidia_ai_endpoints import ChatNVIDIA
return ChatNVIDIA(model=model, **kwargs)
elif model_provider == "ibm":
if model_provider == "ibm":
_check_pkg("langchain_ibm")
from langchain_ibm import ChatWatsonx
return ChatWatsonx(model_id=model, **kwargs)
elif model_provider == "xai":
if model_provider == "xai":
_check_pkg("langchain_xai")
from langchain_xai import ChatXAI
return ChatXAI(model=model, **kwargs)
elif model_provider == "perplexity":
if model_provider == "perplexity":
_check_pkg("langchain_perplexity")
from langchain_perplexity import ChatPerplexity
return ChatPerplexity(model=model, **kwargs)
else:
supported = ", ".join(_SUPPORTED_PROVIDERS)
msg = (
f"Unsupported {model_provider=}.\n\nSupported model providers are: "
f"{supported}"
f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
)
raise ValueError(msg)
@ -490,25 +487,24 @@ _SUPPORTED_PROVIDERS = {
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
return "openai"
elif model_name.startswith("claude"):
if model_name.startswith("claude"):
return "anthropic"
elif model_name.startswith("command"):
if model_name.startswith("command"):
return "cohere"
elif model_name.startswith("accounts/fireworks"):
if model_name.startswith("accounts/fireworks"):
return "fireworks"
elif model_name.startswith("gemini"):
if model_name.startswith("gemini"):
return "google_vertexai"
elif model_name.startswith("amazon."):
if model_name.startswith("amazon."):
return "bedrock"
elif model_name.startswith("mistral"):
if model_name.startswith("mistral"):
return "mistralai"
elif model_name.startswith("deepseek"):
if model_name.startswith("deepseek"):
return "deepseek"
elif model_name.startswith("grok"):
if model_name.startswith("grok"):
return "xai"
elif model_name.startswith("sonar"):
if model_name.startswith("sonar"):
return "perplexity"
else:
return None
@ -595,9 +591,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
)
return queue
elif self._default_config and (model := self._model()) and hasattr(model, name):
if self._default_config and (model := self._model()) and hasattr(model, name):
return getattr(model, name)
else:
msg = f"{name} is not a BaseChatModel attribute"
if self._default_config:
msg += " and is not implemented on the default model"
@ -728,7 +723,6 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
)
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
else:
return super().batch(
inputs, config=config, return_exceptions=return_exceptions, **kwargs
)
@ -751,7 +745,6 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
)
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
else:
return await super().abatch(
inputs, config=config, return_exceptions=return_exceptions, **kwargs
)

View File

@ -192,35 +192,34 @@ def init_embeddings(
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings(model=model_name, **kwargs)
elif provider == "azure_openai":
if provider == "azure_openai":
from langchain_openai import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
elif provider == "google_vertexai":
if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings
return VertexAIEmbeddings(model=model_name, **kwargs)
elif provider == "bedrock":
if provider == "bedrock":
from langchain_aws import BedrockEmbeddings
return BedrockEmbeddings(model_id=model_name, **kwargs)
elif provider == "cohere":
if provider == "cohere":
from langchain_cohere import CohereEmbeddings
return CohereEmbeddings(model=model_name, **kwargs)
elif provider == "mistralai":
if provider == "mistralai":
from langchain_mistralai import MistralAIEmbeddings
return MistralAIEmbeddings(model=model_name, **kwargs)
elif provider == "huggingface":
if provider == "huggingface":
from langchain_huggingface import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
elif provider == "ollama":
if provider == "ollama":
from langchain_ollama import OllamaEmbeddings
return OllamaEmbeddings(model=model_name, **kwargs)
else:
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"

View File

@ -70,7 +70,7 @@ def resolve_pairwise_criteria(
Criteria.DEPTH,
]
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
elif isinstance(criteria, Criteria):
if isinstance(criteria, Criteria):
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, str):
if criteria in _SUPPORTED_CRITERIA:

View File

@ -186,7 +186,6 @@ class _EmbeddingDistanceChainMixin(Chain):
}
if metric in metrics:
return metrics[metric]
else:
msg = f"Invalid metric: {metric}"
raise ValueError(msg)

View File

@ -162,7 +162,6 @@ def load_evaluator(
)
raise ValueError(msg) from e
return evaluator_cls.from_llm(llm=llm, **kwargs)
else:
return evaluator_cls(**kwargs)

View File

@ -70,7 +70,7 @@ class JsonSchemaEvaluator(StringEvaluator):
def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]:
if isinstance(node, str):
return parse_json_markdown(node)
elif hasattr(node, "schema") and callable(getattr(node, "schema")):
if hasattr(node, "schema") and callable(getattr(node, "schema")):
# Pydantic model
return getattr(node, "schema")()
return node

View File

@ -24,7 +24,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
if match:
if match.group(1).upper() == "CORRECT":
return "CORRECT", 1
elif match.group(1).upper() == "INCORRECT":
if match.group(1).upper() == "INCORRECT":
return "INCORRECT", 0
try:
first_word = (
@ -32,7 +32,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
)
if first_word.upper() == "CORRECT":
return "CORRECT", 1
elif first_word.upper() == "INCORRECT":
if first_word.upper() == "INCORRECT":
return "INCORRECT", 0
last_word = (
text.strip()
@ -41,7 +41,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
)
if last_word.upper() == "CORRECT":
return "CORRECT", 1
elif last_word.upper() == "INCORRECT":
if last_word.upper() == "INCORRECT":
return "INCORRECT", 0
except IndexError:
pass

View File

@ -123,12 +123,12 @@ class _EvalArgsMixin:
if self.requires_input and input is None:
msg = f"{self.__class__.__name__} requires an input string."
raise ValueError(msg)
elif input is not None and not self.requires_input:
if input is not None and not self.requires_input:
warn(self._skip_input_warning)
if self.requires_reference and reference is None:
msg = f"{self.__class__.__name__} requires a reference string."
raise ValueError(msg)
elif reference is not None and not self.requires_reference:
if reference is not None and not self.requires_reference:
warn(self._skip_reference_warning)

View File

@ -70,7 +70,7 @@ def resolve_criteria(
Criteria.DEPTH,
]
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
elif isinstance(criteria, Criteria):
if isinstance(criteria, Criteria):
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, str):
if criteria in _SUPPORTED_CRITERIA:

View File

@ -138,7 +138,6 @@ class _RapidFuzzChainMixin(Chain):
module = module_map[distance]
if normalize_score:
return module.normalized_distance
else:
return module.distance
@property

View File

@ -21,7 +21,6 @@ def _get_client(
ls_client = LangSmithClient(api_url, api_key=api_key)
if hasattr(ls_client, "push_prompt") and hasattr(ls_client, "pull_prompt"):
return ls_client
else:
from langchainhub import Client as LangChainHubClient
return LangChainHubClient(api_url, api_key=api_key)
@ -82,14 +81,13 @@ def push(
# Then it's langchainhub
manifest_json = dumps(object)
message = client.push(
return client.push(
repo_full_name,
manifest_json,
parent_commit_hash=parent_commit_hash,
new_repo_is_public=new_repo_is_public,
new_repo_description=new_repo_description,
)
return message
def pull(
@ -113,8 +111,7 @@ def pull(
# Then it's langsmith
if hasattr(client, "pull_prompt"):
response = client.pull_prompt(owner_repo_commit, include_model=include_model)
return response
return client.pull_prompt(owner_repo_commit, include_model=include_model)
# Then it's langchainhub
if hasattr(client, "pull_repo"):

View File

@ -561,7 +561,6 @@ def __getattr__(name: str) -> Any:
k: v() for k, v in get_type_to_cls_dict().items()
}
return type_to_cls_dict
else:
return getattr(llms, name)

View File

@ -154,6 +154,7 @@ class UpstashRedisEntityStore(BaseEntityStore):
logger.debug(
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
return None
def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
@ -255,6 +256,7 @@ class RedisEntityStore(BaseEntityStore):
logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
return None
def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
@ -362,6 +364,7 @@ class SQLiteEntityStore(BaseEntityStore):
f'"{self.full_table_name}" (key, value) VALUES (?, ?)'
)
self._execute_query(query, (key, value))
return None
def delete(self, key: str) -> None:
"""Deletes a key-value pair, safely quoting the table name."""

View File

@ -34,7 +34,7 @@ class BooleanOutputParser(BaseOutputParser[bool]):
)
raise ValueError(msg)
return True
elif self.false_val.upper() in truthy:
if self.false_val.upper() in truthy:
if self.true_val.upper() in truthy:
msg = (
f"Ambiguous response. Both {self.true_val} and {self.false_val} "

View File

@ -71,7 +71,6 @@ class OutputFixingParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
@ -109,7 +108,6 @@ class OutputFixingParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(

View File

@ -64,7 +64,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
msg = f"Invalid array format in '{original_request_params}'. \
Please check the format instructions."
raise OutputParserException(msg)
elif (
if (
isinstance(parsed_array[0], int)
and parsed_array[-1] > self.dataframe.index.max()
):

View File

@ -30,11 +30,9 @@ class RegexParser(BaseOutputParser[dict[str, str]]):
match = re.search(self.regex, text)
if match:
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
else:
if self.default_output_key is None:
msg = f"Could not parse output: {text}"
raise ValueError(msg)
else:
return {
key: text if key == self.default_output_key else ""
for key in self.output_keys

View File

@ -33,14 +33,11 @@ class RegexDictParser(BaseOutputParser[dict[str, str]]):
{expected_format} on text {text}"
)
raise ValueError(msg)
elif len(matches) > 1:
if len(matches) > 1:
msg = f"Multiple matches found for output key: {output_key} with \
expected format {expected_format} on text {text}"
raise ValueError(msg)
elif (
self.no_update_value is not None and matches[0] == self.no_update_value
):
if self.no_update_value is not None and matches[0] == self.no_update_value:
continue
else:
result[output_key] = matches[0]
return result

View File

@ -107,7 +107,6 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
@ -143,7 +142,6 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
@ -234,7 +232,6 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
@ -263,7 +260,6 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(

View File

@ -89,7 +89,6 @@ class StructuredOutputParser(BaseOutputParser[dict[str, Any]]):
)
if only_json:
return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str)
else:
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
def parse(self, text: str) -> dict[str, Any]:

View File

@ -33,7 +33,6 @@ class YamlOutputParser(BaseOutputParser[T]):
json_object = yaml.safe_load(yaml_str)
if hasattr(self.pydantic_object, "model_validate"):
return self.pydantic_object.model_validate(json_object)
else:
return self.pydantic_object.parse_obj(json_object)
except (yaml.YAMLError, ValidationError) as e:

View File

@ -45,7 +45,6 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
else:
return []
async def _aget_relevant_documents(
@ -71,5 +70,4 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
else:
return []

View File

@ -174,9 +174,7 @@ class EnsembleRetriever(BaseRetriever):
"""
# Get fused result of the retrievers.
fused_documents = self.rank_fusion(query, run_manager)
return fused_documents
return self.rank_fusion(query, run_manager)
async def _aget_relevant_documents(
self,
@ -195,9 +193,7 @@ class EnsembleRetriever(BaseRetriever):
"""
# Get fused result of the retrievers.
fused_documents = await self.arank_fusion(query, run_manager)
return fused_documents
return await self.arank_fusion(query, run_manager)
def rank_fusion(
self,
@ -236,9 +232,7 @@ class EnsembleRetriever(BaseRetriever):
]
# apply rank fusion
fused_documents = self.weighted_reciprocal_rank(retriever_docs)
return fused_documents
return self.weighted_reciprocal_rank(retriever_docs)
async def arank_fusion(
self,
@ -280,9 +274,7 @@ class EnsembleRetriever(BaseRetriever):
]
# apply rank fusion
fused_documents = self.weighted_reciprocal_rank(retriever_docs)
return fused_documents
return self.weighted_reciprocal_rank(retriever_docs)
def weighted_reciprocal_rank(
self, doc_lists: list[list[Document]]
@ -318,7 +310,7 @@ class EnsembleRetriever(BaseRetriever):
# Docs are deduplicated by their contents then sorted by their scores
all_docs = chain.from_iterable(doc_lists)
sorted_docs = sorted(
return sorted(
unique_by_key(
all_docs,
lambda doc: (
@ -332,4 +324,3 @@ class EnsembleRetriever(BaseRetriever):
doc.page_content if self.id_key is None else doc.metadata[self.id_key]
],
)
return sorted_docs

View File

@ -31,9 +31,7 @@ class MergerRetriever(BaseRetriever):
"""
# Merge the results of the retrievers.
merged_documents = self.merge_documents(query, run_manager)
return merged_documents
return self.merge_documents(query, run_manager)
async def _aget_relevant_documents(
self,
@ -52,9 +50,7 @@ class MergerRetriever(BaseRetriever):
"""
# Merge the results of the retrievers.
merged_documents = await self.amerge_documents(query, run_manager)
return merged_documents
return await self.amerge_documents(query, run_manager)
def merge_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun

View File

@ -74,10 +74,9 @@ class RePhraseQueryRetriever(BaseRetriever):
query, {"callbacks": run_manager.get_child()}
)
logger.info(f"Re-phrased question: {re_phrased_question}")
docs = self.retriever.invoke(
return self.retriever.invoke(
re_phrased_question, config={"callbacks": run_manager.get_child()}
)
return docs
async def _aget_relevant_documents(
self,

View File

@ -119,18 +119,17 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
}
if isinstance(vectorstore, DatabricksVectorSearch):
return DatabricksVectorSearchTranslator()
elif isinstance(vectorstore, MyScale):
if isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
elif isinstance(vectorstore, Redis):
if isinstance(vectorstore, Redis):
return RedisTranslator.from_vectorstore(vectorstore)
elif isinstance(vectorstore, TencentVectorDB):
if isinstance(vectorstore, TencentVectorDB):
fields = [
field.name for field in (vectorstore.meta_fields or []) if field.index
]
return TencentVectorDBTranslator(fields)
elif vectorstore.__class__ in BUILTIN_TRANSLATORS:
if vectorstore.__class__ in BUILTIN_TRANSLATORS:
return BUILTIN_TRANSLATORS[vectorstore.__class__]()
else:
try:
from langchain_astradb.vectorstores import AstraDBVectorStore
except ImportError:
@ -289,14 +288,12 @@ class SelfQueryRetriever(BaseRetriever):
def _get_docs_with_query(
self, query: str, search_kwargs: dict[str, Any]
) -> list[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs
return self.vectorstore.search(query, self.search_type, **search_kwargs)
async def _aget_docs_with_query(
self, query: str, search_kwargs: dict[str, Any]
) -> list[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
return docs
return await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@ -315,8 +312,7 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
return docs
return self._get_docs_with_query(new_query, search_kwargs)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
@ -335,8 +331,7 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
return docs
return await self._aget_docs_with_query(new_query, search_kwargs)
@classmethod
def from_llm(

View File

@ -191,13 +191,13 @@ def _wrap_in_chain_factory(
)
raise ValueError(msg)
return lambda: chain
elif isinstance(llm_or_chain_factory, BaseLanguageModel):
if isinstance(llm_or_chain_factory, BaseLanguageModel):
return llm_or_chain_factory
elif isinstance(llm_or_chain_factory, Runnable):
if isinstance(llm_or_chain_factory, Runnable):
# Memory may exist here, but it's not elegant to check all those cases.
lcf = llm_or_chain_factory
return lambda: lcf
elif callable(llm_or_chain_factory):
if callable(llm_or_chain_factory):
if is_traceable_function(llm_or_chain_factory):
runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
return lambda: runnable_
@ -215,13 +215,12 @@ def _wrap_in_chain_factory(
# It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user.
return _model
elif is_traceable_function(cast(Callable, _model)):
if is_traceable_function(cast(Callable, _model)):
runnable_ = as_runnable(cast(Callable, _model))
return lambda: runnable_
elif not isinstance(_model, Runnable):
if not isinstance(_model, Runnable):
# This is unlikely to happen - a constructor for a model function
return lambda: RunnableLambda(constructor)
else:
# Typical correct case
return constructor
return llm_or_chain_factory
@ -272,7 +271,6 @@ def _get_prompt(inputs: dict[str, Any]) -> str:
raise InputFormatError(msg)
if len(prompts) == 1:
return prompts[0]
else:
msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts."
raise InputFormatError(msg)
@ -321,7 +319,6 @@ def _get_messages(inputs: dict[str, Any]) -> dict:
)
raise InputFormatError(msg)
return input_copy
else:
msg = (
f"Chat Run expects single List[dict] or List[List[dict]] 'messages'"
f" input. Got {inputs}"
@ -707,7 +704,6 @@ async def _arun_llm(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
else:
msg = (
"Input mapper returned invalid format"
f" {prompt_or_messages}"
@ -715,7 +711,6 @@ async def _arun_llm(
)
raise InputFormatError(msg)
else:
try:
prompt = _get_prompt(inputs)
llm_output: Union[str, BaseMessage] = await llm.ainvoke(

View File

@ -28,7 +28,6 @@ def _get_messages_from_run_dict(messages: list[dict]) -> list[BaseMessage]:
first_message = messages[0]
if "lc" in first_message:
return [load(dumpd(message)) for message in messages]
else:
return messages_from_dict(messages)
@ -106,14 +105,12 @@ class LLMStringRunMapper(StringRunMapper):
if run.run_type != "llm":
msg = "LLM RunMapper only supports LLM runs."
raise ValueError(msg)
elif not run.outputs:
if not run.outputs:
if run.error:
msg = f"Cannot evaluate errored LLM run {run.id}: {run.error}"
raise ValueError(msg)
else:
msg = f"Run {run.id} has no outputs. Cannot evaluate this run."
raise ValueError(msg)
else:
try:
inputs = self.serialize_inputs(run.inputs)
except Exception as e:
@ -142,9 +139,8 @@ class ChainStringRunMapper(StringRunMapper):
def _get_key(self, source: dict, key: Optional[str], which: str) -> str:
if key is not None:
return source[key]
elif len(source) == 1:
if len(source) == 1:
return next(iter(source.values()))
else:
msg = (
f"Could not map run {which} with multiple keys: "
f"{source}\nPlease manually specify a {which}_key"
@ -168,7 +164,7 @@ class ChainStringRunMapper(StringRunMapper):
f" '{self.input_key}'."
)
raise ValueError(msg)
elif self.prediction_key is not None and self.prediction_key not in run.outputs:
if self.prediction_key is not None and self.prediction_key not in run.outputs:
available_keys = ", ".join(run.outputs.keys())
msg = (
f"Run with ID {run.id} doesn't have the expected prediction key"
@ -178,7 +174,6 @@ class ChainStringRunMapper(StringRunMapper):
)
raise ValueError(msg)
else:
input_ = self._get_key(run.inputs, self.input_key, "input")
prediction = self._get_key(run.outputs, self.prediction_key, "prediction")
return {
@ -224,7 +219,6 @@ class StringExampleMapper(Serializable):
" specify a reference_key."
)
raise ValueError(msg)
else:
output = list(example.outputs.values())[0]
elif self.reference_key not in example.outputs:
msg = (

View File

@ -65,9 +65,8 @@ def _import_python_tool_PythonREPLTool() -> Any:
def __getattr__(name: str) -> Any:
if name == "PythonAstREPLTool":
return _import_python_tool_PythonAstREPLTool()
elif name == "PythonREPLTool":
if name == "PythonREPLTool":
return _import_python_tool_PythonREPLTool()
else:
from langchain_community import tools
# If not in interactive env, raise warning.

View File

@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint]
select = ["A", "C4", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
select = ["A", "C4", "D", "E", "EM", "F", "I", "PGH003", "PIE", "RET", "S", "SIM", "T201", "UP", "W"]
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true

View File

@ -139,7 +139,6 @@ async def get_state(
async def ask_for_passphrase(said_please: bool) -> dict[str, Any]:
if said_please:
return {"passphrase": f"The passphrase is {PASS_PHRASE}"}
else:
return {"passphrase": "I won't share the passphrase without saying 'please'."}
@ -153,7 +152,6 @@ async def recycle(password: SecretPassPhrase) -> dict[str, Any]:
if password.pw == PASS_PHRASE:
_ROBOT_STATE["destruct"] = True
return {"status": "Self-destruct initiated", "state": _ROBOT_STATE}
else:
_ROBOT_STATE["destruct"] = False
raise HTTPException(
status_code=400,

View File

@ -100,14 +100,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
),
]
agent = initialize_agent(
return initialize_agent(
tools,
fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
**kwargs,
)
return agent
def test_agent_bad_action() -> None:

View File

@ -71,14 +71,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
),
]
agent = initialize_agent(
return initialize_agent(
tools,
fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
**kwargs,
)
return agent
async def test_agent_bad_action() -> None:

View File

@ -11,7 +11,6 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
else:
return "Final Answer", output.return_values["output"]

View File

@ -16,7 +16,6 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = MRKLOutputParser().parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
else:
return "Final Answer", output.return_values["output"]

View File

@ -21,9 +21,8 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
elif isinstance(output, AgentFinish):
if isinstance(output, AgentFinish):
return output.return_values["output"], output.log
else:
msg = "Unexpected output type"
raise ValueError(msg)

View File

@ -58,7 +58,6 @@ class FakeChain(Chain):
) -> dict[str, str]:
if self.be_correct:
return {"bar": "baz"}
else:
return {"baz": "bar"}

View File

@ -54,7 +54,6 @@ class _FakeTrajectoryChatModel(FakeChatModel):
response = self.queries[list(self.queries.keys())[self.response_index]]
self.response_index = self.response_index + 1
return response
else:
prompt = messages[0].content
return self.queries[prompt]

View File

@ -45,7 +45,6 @@ class FakeLLM(LLM):
return self.queries[prompt]
if stop is None:
return "foo"
else:
return "bar"
@property

View File

@ -14,7 +14,6 @@ class SequentialRetriever(BaseRetriever):
) -> list[Document]:
if self.response_index >= len(self.sequential_responses):
return []
else:
self.response_index += 1
return self.sequential_responses[self.response_index - 1]