langchain: Add ruff rule RET (#31875)

All auto-fixes
See https://docs.astral.sh/ruff/rules/#flake8-return-ret

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Christophe Bornet 2025-07-07 17:33:18 +02:00 committed by GitHub
parent fceebbb387
commit 56bbfd9723
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
91 changed files with 663 additions and 835 deletions

View File

@ -48,25 +48,25 @@ def __getattr__(name: str) -> Any:
_warn_on_import(name, replacement="langchain.agents.MRKLChain") _warn_on_import(name, replacement="langchain.agents.MRKLChain")
return MRKLChain return MRKLChain
elif name == "ReActChain": if name == "ReActChain":
from langchain.agents import ReActChain from langchain.agents import ReActChain
_warn_on_import(name, replacement="langchain.agents.ReActChain") _warn_on_import(name, replacement="langchain.agents.ReActChain")
return ReActChain return ReActChain
elif name == "SelfAskWithSearchChain": if name == "SelfAskWithSearchChain":
from langchain.agents import SelfAskWithSearchChain from langchain.agents import SelfAskWithSearchChain
_warn_on_import(name, replacement="langchain.agents.SelfAskWithSearchChain") _warn_on_import(name, replacement="langchain.agents.SelfAskWithSearchChain")
return SelfAskWithSearchChain return SelfAskWithSearchChain
elif name == "ConversationChain": if name == "ConversationChain":
from langchain.chains import ConversationChain from langchain.chains import ConversationChain
_warn_on_import(name, replacement="langchain.chains.ConversationChain") _warn_on_import(name, replacement="langchain.chains.ConversationChain")
return ConversationChain return ConversationChain
elif name == "LLMBashChain": if name == "LLMBashChain":
msg = ( msg = (
"This module has been moved to langchain-experimental. " "This module has been moved to langchain-experimental. "
"For more details: " "For more details: "
@ -77,97 +77,97 @@ def __getattr__(name: str) -> Any:
) )
raise ImportError(msg) raise ImportError(msg)
elif name == "LLMChain": if name == "LLMChain":
from langchain.chains import LLMChain from langchain.chains import LLMChain
_warn_on_import(name, replacement="langchain.chains.LLMChain") _warn_on_import(name, replacement="langchain.chains.LLMChain")
return LLMChain return LLMChain
elif name == "LLMCheckerChain": if name == "LLMCheckerChain":
from langchain.chains import LLMCheckerChain from langchain.chains import LLMCheckerChain
_warn_on_import(name, replacement="langchain.chains.LLMCheckerChain") _warn_on_import(name, replacement="langchain.chains.LLMCheckerChain")
return LLMCheckerChain return LLMCheckerChain
elif name == "LLMMathChain": if name == "LLMMathChain":
from langchain.chains import LLMMathChain from langchain.chains import LLMMathChain
_warn_on_import(name, replacement="langchain.chains.LLMMathChain") _warn_on_import(name, replacement="langchain.chains.LLMMathChain")
return LLMMathChain return LLMMathChain
elif name == "QAWithSourcesChain": if name == "QAWithSourcesChain":
from langchain.chains import QAWithSourcesChain from langchain.chains import QAWithSourcesChain
_warn_on_import(name, replacement="langchain.chains.QAWithSourcesChain") _warn_on_import(name, replacement="langchain.chains.QAWithSourcesChain")
return QAWithSourcesChain return QAWithSourcesChain
elif name == "VectorDBQA": if name == "VectorDBQA":
from langchain.chains import VectorDBQA from langchain.chains import VectorDBQA
_warn_on_import(name, replacement="langchain.chains.VectorDBQA") _warn_on_import(name, replacement="langchain.chains.VectorDBQA")
return VectorDBQA return VectorDBQA
elif name == "VectorDBQAWithSourcesChain": if name == "VectorDBQAWithSourcesChain":
from langchain.chains import VectorDBQAWithSourcesChain from langchain.chains import VectorDBQAWithSourcesChain
_warn_on_import(name, replacement="langchain.chains.VectorDBQAWithSourcesChain") _warn_on_import(name, replacement="langchain.chains.VectorDBQAWithSourcesChain")
return VectorDBQAWithSourcesChain return VectorDBQAWithSourcesChain
elif name == "InMemoryDocstore": if name == "InMemoryDocstore":
from langchain_community.docstore import InMemoryDocstore from langchain_community.docstore import InMemoryDocstore
_warn_on_import(name, replacement="langchain.docstore.InMemoryDocstore") _warn_on_import(name, replacement="langchain.docstore.InMemoryDocstore")
return InMemoryDocstore return InMemoryDocstore
elif name == "Wikipedia": if name == "Wikipedia":
from langchain_community.docstore import Wikipedia from langchain_community.docstore import Wikipedia
_warn_on_import(name, replacement="langchain.docstore.Wikipedia") _warn_on_import(name, replacement="langchain.docstore.Wikipedia")
return Wikipedia return Wikipedia
elif name == "Anthropic": if name == "Anthropic":
from langchain_community.llms import Anthropic from langchain_community.llms import Anthropic
_warn_on_import(name, replacement="langchain_community.llms.Anthropic") _warn_on_import(name, replacement="langchain_community.llms.Anthropic")
return Anthropic return Anthropic
elif name == "Banana": if name == "Banana":
from langchain_community.llms import Banana from langchain_community.llms import Banana
_warn_on_import(name, replacement="langchain_community.llms.Banana") _warn_on_import(name, replacement="langchain_community.llms.Banana")
return Banana return Banana
elif name == "CerebriumAI": if name == "CerebriumAI":
from langchain_community.llms import CerebriumAI from langchain_community.llms import CerebriumAI
_warn_on_import(name, replacement="langchain_community.llms.CerebriumAI") _warn_on_import(name, replacement="langchain_community.llms.CerebriumAI")
return CerebriumAI return CerebriumAI
elif name == "Cohere": if name == "Cohere":
from langchain_community.llms import Cohere from langchain_community.llms import Cohere
_warn_on_import(name, replacement="langchain_community.llms.Cohere") _warn_on_import(name, replacement="langchain_community.llms.Cohere")
return Cohere return Cohere
elif name == "ForefrontAI": if name == "ForefrontAI":
from langchain_community.llms import ForefrontAI from langchain_community.llms import ForefrontAI
_warn_on_import(name, replacement="langchain_community.llms.ForefrontAI") _warn_on_import(name, replacement="langchain_community.llms.ForefrontAI")
return ForefrontAI return ForefrontAI
elif name == "GooseAI": if name == "GooseAI":
from langchain_community.llms import GooseAI from langchain_community.llms import GooseAI
_warn_on_import(name, replacement="langchain_community.llms.GooseAI") _warn_on_import(name, replacement="langchain_community.llms.GooseAI")
return GooseAI return GooseAI
elif name == "HuggingFaceHub": if name == "HuggingFaceHub":
from langchain_community.llms import HuggingFaceHub from langchain_community.llms import HuggingFaceHub
_warn_on_import(name, replacement="langchain_community.llms.HuggingFaceHub") _warn_on_import(name, replacement="langchain_community.llms.HuggingFaceHub")
return HuggingFaceHub return HuggingFaceHub
elif name == "HuggingFaceTextGenInference": if name == "HuggingFaceTextGenInference":
from langchain_community.llms import HuggingFaceTextGenInference from langchain_community.llms import HuggingFaceTextGenInference
_warn_on_import( _warn_on_import(
@ -175,55 +175,55 @@ def __getattr__(name: str) -> Any:
) )
return HuggingFaceTextGenInference return HuggingFaceTextGenInference
elif name == "LlamaCpp": if name == "LlamaCpp":
from langchain_community.llms import LlamaCpp from langchain_community.llms import LlamaCpp
_warn_on_import(name, replacement="langchain_community.llms.LlamaCpp") _warn_on_import(name, replacement="langchain_community.llms.LlamaCpp")
return LlamaCpp return LlamaCpp
elif name == "Modal": if name == "Modal":
from langchain_community.llms import Modal from langchain_community.llms import Modal
_warn_on_import(name, replacement="langchain_community.llms.Modal") _warn_on_import(name, replacement="langchain_community.llms.Modal")
return Modal return Modal
elif name == "OpenAI": if name == "OpenAI":
from langchain_community.llms import OpenAI from langchain_community.llms import OpenAI
_warn_on_import(name, replacement="langchain_community.llms.OpenAI") _warn_on_import(name, replacement="langchain_community.llms.OpenAI")
return OpenAI return OpenAI
elif name == "Petals": if name == "Petals":
from langchain_community.llms import Petals from langchain_community.llms import Petals
_warn_on_import(name, replacement="langchain_community.llms.Petals") _warn_on_import(name, replacement="langchain_community.llms.Petals")
return Petals return Petals
elif name == "PipelineAI": if name == "PipelineAI":
from langchain_community.llms import PipelineAI from langchain_community.llms import PipelineAI
_warn_on_import(name, replacement="langchain_community.llms.PipelineAI") _warn_on_import(name, replacement="langchain_community.llms.PipelineAI")
return PipelineAI return PipelineAI
elif name == "SagemakerEndpoint": if name == "SagemakerEndpoint":
from langchain_community.llms import SagemakerEndpoint from langchain_community.llms import SagemakerEndpoint
_warn_on_import(name, replacement="langchain_community.llms.SagemakerEndpoint") _warn_on_import(name, replacement="langchain_community.llms.SagemakerEndpoint")
return SagemakerEndpoint return SagemakerEndpoint
elif name == "StochasticAI": if name == "StochasticAI":
from langchain_community.llms import StochasticAI from langchain_community.llms import StochasticAI
_warn_on_import(name, replacement="langchain_community.llms.StochasticAI") _warn_on_import(name, replacement="langchain_community.llms.StochasticAI")
return StochasticAI return StochasticAI
elif name == "Writer": if name == "Writer":
from langchain_community.llms import Writer from langchain_community.llms import Writer
_warn_on_import(name, replacement="langchain_community.llms.Writer") _warn_on_import(name, replacement="langchain_community.llms.Writer")
return Writer return Writer
elif name == "HuggingFacePipeline": if name == "HuggingFacePipeline":
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
_warn_on_import( _warn_on_import(
@ -232,7 +232,7 @@ def __getattr__(name: str) -> Any:
) )
return HuggingFacePipeline return HuggingFacePipeline
elif name == "FewShotPromptTemplate": if name == "FewShotPromptTemplate":
from langchain_core.prompts import FewShotPromptTemplate from langchain_core.prompts import FewShotPromptTemplate
_warn_on_import( _warn_on_import(
@ -240,7 +240,7 @@ def __getattr__(name: str) -> Any:
) )
return FewShotPromptTemplate return FewShotPromptTemplate
elif name == "Prompt": if name == "Prompt":
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
_warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate") _warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")
@ -248,19 +248,19 @@ def __getattr__(name: str) -> Any:
# it's renamed as prompt template anyways # it's renamed as prompt template anyways
# this is just for backwards compat # this is just for backwards compat
return PromptTemplate return PromptTemplate
elif name == "PromptTemplate": if name == "PromptTemplate":
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
_warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate") _warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")
return PromptTemplate return PromptTemplate
elif name == "BasePromptTemplate": if name == "BasePromptTemplate":
from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts import BasePromptTemplate
_warn_on_import(name, replacement="langchain_core.prompts.BasePromptTemplate") _warn_on_import(name, replacement="langchain_core.prompts.BasePromptTemplate")
return BasePromptTemplate return BasePromptTemplate
elif name == "ArxivAPIWrapper": if name == "ArxivAPIWrapper":
from langchain_community.utilities import ArxivAPIWrapper from langchain_community.utilities import ArxivAPIWrapper
_warn_on_import( _warn_on_import(
@ -268,7 +268,7 @@ def __getattr__(name: str) -> Any:
) )
return ArxivAPIWrapper return ArxivAPIWrapper
elif name == "GoldenQueryAPIWrapper": if name == "GoldenQueryAPIWrapper":
from langchain_community.utilities import GoldenQueryAPIWrapper from langchain_community.utilities import GoldenQueryAPIWrapper
_warn_on_import( _warn_on_import(
@ -276,7 +276,7 @@ def __getattr__(name: str) -> Any:
) )
return GoldenQueryAPIWrapper return GoldenQueryAPIWrapper
elif name == "GoogleSearchAPIWrapper": if name == "GoogleSearchAPIWrapper":
from langchain_community.utilities import GoogleSearchAPIWrapper from langchain_community.utilities import GoogleSearchAPIWrapper
_warn_on_import( _warn_on_import(
@ -284,7 +284,7 @@ def __getattr__(name: str) -> Any:
) )
return GoogleSearchAPIWrapper return GoogleSearchAPIWrapper
elif name == "GoogleSerperAPIWrapper": if name == "GoogleSerperAPIWrapper":
from langchain_community.utilities import GoogleSerperAPIWrapper from langchain_community.utilities import GoogleSerperAPIWrapper
_warn_on_import( _warn_on_import(
@ -292,7 +292,7 @@ def __getattr__(name: str) -> Any:
) )
return GoogleSerperAPIWrapper return GoogleSerperAPIWrapper
elif name == "PowerBIDataset": if name == "PowerBIDataset":
from langchain_community.utilities import PowerBIDataset from langchain_community.utilities import PowerBIDataset
_warn_on_import( _warn_on_import(
@ -300,7 +300,7 @@ def __getattr__(name: str) -> Any:
) )
return PowerBIDataset return PowerBIDataset
elif name == "SearxSearchWrapper": if name == "SearxSearchWrapper":
from langchain_community.utilities import SearxSearchWrapper from langchain_community.utilities import SearxSearchWrapper
_warn_on_import( _warn_on_import(
@ -308,7 +308,7 @@ def __getattr__(name: str) -> Any:
) )
return SearxSearchWrapper return SearxSearchWrapper
elif name == "WikipediaAPIWrapper": if name == "WikipediaAPIWrapper":
from langchain_community.utilities import WikipediaAPIWrapper from langchain_community.utilities import WikipediaAPIWrapper
_warn_on_import( _warn_on_import(
@ -316,7 +316,7 @@ def __getattr__(name: str) -> Any:
) )
return WikipediaAPIWrapper return WikipediaAPIWrapper
elif name == "WolframAlphaAPIWrapper": if name == "WolframAlphaAPIWrapper":
from langchain_community.utilities import WolframAlphaAPIWrapper from langchain_community.utilities import WolframAlphaAPIWrapper
_warn_on_import( _warn_on_import(
@ -324,19 +324,19 @@ def __getattr__(name: str) -> Any:
) )
return WolframAlphaAPIWrapper return WolframAlphaAPIWrapper
elif name == "SQLDatabase": if name == "SQLDatabase":
from langchain_community.utilities import SQLDatabase from langchain_community.utilities import SQLDatabase
_warn_on_import(name, replacement="langchain_community.utilities.SQLDatabase") _warn_on_import(name, replacement="langchain_community.utilities.SQLDatabase")
return SQLDatabase return SQLDatabase
elif name == "FAISS": if name == "FAISS":
from langchain_community.vectorstores import FAISS from langchain_community.vectorstores import FAISS
_warn_on_import(name, replacement="langchain_community.vectorstores.FAISS") _warn_on_import(name, replacement="langchain_community.vectorstores.FAISS")
return FAISS return FAISS
elif name == "ElasticVectorSearch": if name == "ElasticVectorSearch":
from langchain_community.vectorstores import ElasticVectorSearch from langchain_community.vectorstores import ElasticVectorSearch
_warn_on_import( _warn_on_import(
@ -345,7 +345,7 @@ def __getattr__(name: str) -> Any:
return ElasticVectorSearch return ElasticVectorSearch
# For backwards compatibility # For backwards compatibility
elif name == "SerpAPIChain" or name == "SerpAPIWrapper": if name == "SerpAPIChain" or name == "SerpAPIWrapper":
from langchain_community.utilities import SerpAPIWrapper from langchain_community.utilities import SerpAPIWrapper
_warn_on_import( _warn_on_import(
@ -353,7 +353,7 @@ def __getattr__(name: str) -> Any:
) )
return SerpAPIWrapper return SerpAPIWrapper
elif name == "verbose": if name == "verbose":
from langchain.globals import _verbose from langchain.globals import _verbose
_warn_on_import( _warn_on_import(
@ -364,7 +364,7 @@ def __getattr__(name: str) -> Any:
) )
return _verbose return _verbose
elif name == "debug": if name == "debug":
from langchain.globals import _debug from langchain.globals import _debug
_warn_on_import( _warn_on_import(
@ -375,7 +375,7 @@ def __getattr__(name: str) -> Any:
) )
return _debug return _debug
elif name == "llm_cache": if name == "llm_cache":
from langchain.globals import _llm_cache from langchain.globals import _llm_cache
_warn_on_import( _warn_on_import(
@ -386,9 +386,8 @@ def __getattr__(name: str) -> Any:
) )
return _llm_cache return _llm_cache
else: msg = f"Could not find: {name}"
msg = f"Could not find: {name}" raise AttributeError(msg)
raise AttributeError(msg)
__all__ = [ __all__ = [

View File

@ -137,9 +137,8 @@ class BaseSingleActionAgent(BaseModel):
return AgentFinish( return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, "" {"output": "Agent stopped due to iteration limit or time limit."}, ""
) )
else: msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
msg = f"Got unsupported early_stopping_method `{early_stopping_method}`" raise ValueError(msg)
raise ValueError(msg)
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
@ -308,9 +307,8 @@ class BaseMultiActionAgent(BaseModel):
if early_stopping_method == "force": if early_stopping_method == "force":
# `force` just returns a constant string # `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "") return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
else: msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
msg = f"Got unsupported early_stopping_method `{early_stopping_method}`" raise ValueError(msg)
raise ValueError(msg)
@property @property
def _agent_type(self) -> str: def _agent_type(self) -> str:
@ -815,8 +813,7 @@ class Agent(BaseSingleActionAgent):
""" """
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs) full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
agent_output = await self.output_parser.aparse(full_output) return await self.output_parser.aparse(full_output)
return agent_output
def get_full_inputs( def get_full_inputs(
self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any
@ -833,8 +830,7 @@ class Agent(BaseSingleActionAgent):
""" """
thoughts = self._construct_scratchpad(intermediate_steps) thoughts = self._construct_scratchpad(intermediate_steps)
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
full_inputs = {**kwargs, **new_inputs} return {**kwargs, **new_inputs}
return full_inputs
@property @property
def input_keys(self) -> list[str]: def input_keys(self) -> list[str]:
@ -970,7 +966,7 @@ class Agent(BaseSingleActionAgent):
return AgentFinish( return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, "" {"output": "Agent stopped due to iteration limit or time limit."}, ""
) )
elif early_stopping_method == "generate": if early_stopping_method == "generate":
# Generate does one final forward pass # Generate does one final forward pass
thoughts = "" thoughts = ""
for action, observation in intermediate_steps: for action, observation in intermediate_steps:
@ -990,16 +986,14 @@ class Agent(BaseSingleActionAgent):
if isinstance(parsed_output, AgentFinish): if isinstance(parsed_output, AgentFinish):
# If we can extract, we send the correct stuff # If we can extract, we send the correct stuff
return parsed_output return parsed_output
else: # If we can extract, but the tool is not the final tool,
# If we can extract, but the tool is not the final tool, # we just return the full output
# we just return the full output return AgentFinish({"output": full_output}, full_output)
return AgentFinish({"output": full_output}, full_output) msg = (
else: "early_stopping_method should be one of `force` or `generate`, "
msg = ( f"got {early_stopping_method}"
"early_stopping_method should be one of `force` or `generate`, " )
f"got {early_stopping_method}" raise ValueError(msg)
)
raise ValueError(msg)
def tool_run_logging_kwargs(self) -> builtins.dict: def tool_run_logging_kwargs(self) -> builtins.dict:
"""Return logging kwargs for tool run.""" """Return logging kwargs for tool run."""
@ -1179,8 +1173,7 @@ class AgentExecutor(Chain):
""" """
if isinstance(self.agent, Runnable): if isinstance(self.agent, Runnable):
return cast(RunnableAgentType, self.agent) return cast(RunnableAgentType, self.agent)
else: return self.agent
return self.agent
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
"""Raise error - saving not supported for Agent Executors. """Raise error - saving not supported for Agent Executors.
@ -1249,8 +1242,7 @@ class AgentExecutor(Chain):
""" """
if self.return_intermediate_steps: if self.return_intermediate_steps:
return self._action_agent.return_values + ["intermediate_steps"] return self._action_agent.return_values + ["intermediate_steps"]
else: return self._action_agent.return_values
return self._action_agent.return_values
def lookup_tool(self, name: str) -> BaseTool: def lookup_tool(self, name: str) -> BaseTool:
"""Lookup tool by name. """Lookup tool by name.
@ -1304,10 +1296,7 @@ class AgentExecutor(Chain):
msg = "Expected a single AgentFinish output, but got multiple values." msg = "Expected a single AgentFinish output, but got multiple values."
raise ValueError(msg) raise ValueError(msg)
return values[-1] return values[-1]
else: return [(a.action, a.observation) for a in values if isinstance(a, AgentStep)]
return [
(a.action, a.observation) for a in values if isinstance(a, AgentStep)
]
def _take_next_step( def _take_next_step(
self, self,
@ -1727,10 +1716,9 @@ class AgentExecutor(Chain):
and self.trim_intermediate_steps > 0 and self.trim_intermediate_steps > 0
): ):
return intermediate_steps[-self.trim_intermediate_steps :] return intermediate_steps[-self.trim_intermediate_steps :]
elif callable(self.trim_intermediate_steps): if callable(self.trim_intermediate_steps):
return self.trim_intermediate_steps(intermediate_steps) return self.trim_intermediate_steps(intermediate_steps)
else: return intermediate_steps
return intermediate_steps
@override @override
def stream( def stream(

View File

@ -61,8 +61,7 @@ class ChatAgent(Agent):
f"(but I haven't seen any of it! I only see what " f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}" f"you return as final answer):\n{agent_scratchpad}"
) )
else: return agent_scratchpad
return agent_scratchpad
@classmethod @classmethod
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:

View File

@ -39,15 +39,13 @@ class ConvoOutputParser(AgentOutputParser):
# If the action indicates a final answer, return an AgentFinish # If the action indicates a final answer, return an AgentFinish
if action == "Final Answer": if action == "Final Answer":
return AgentFinish({"output": action_input}, text) return AgentFinish({"output": action_input}, text)
else: # Otherwise, return an AgentAction with the specified action and
# Otherwise, return an AgentAction with the specified action and # input
# input return AgentAction(action, action_input, text)
return AgentAction(action, action_input, text) # If the necessary keys aren't present in the response, raise an
else: # exception
# If the necessary keys aren't present in the response, raise an msg = f"Missing 'action' or 'action_input' in LLM output: {text}"
# exception raise OutputParserException(msg)
msg = f"Missing 'action' or 'action_input' in LLM output: {text}"
raise OutputParserException(msg)
except Exception as e: except Exception as e:
# If any other exception is raised during parsing, also raise an # If any other exception is raised during parsing, also raise an
# OutputParserException # OutputParserException

View File

@ -23,8 +23,7 @@ def _convert_agent_action_to_messages(
return list(agent_action.message_log) + [ return list(agent_action.message_log) + [
_create_function_message(agent_action, observation) _create_function_message(agent_action, observation)
] ]
else: return [AIMessage(content=agent_action.log)]
return [AIMessage(content=agent_action.log)]
def _create_function_message( def _create_function_message(

View File

@ -182,7 +182,7 @@ def create_json_chat_agent(
else: else:
llm_to_use = llm llm_to_use = llm
agent = ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_messages( agent_scratchpad=lambda x: format_log_to_messages(
x["intermediate_steps"], template_tool_response=template_tool_response x["intermediate_steps"], template_tool_response=template_tool_response
@ -192,4 +192,3 @@ def create_json_chat_agent(
| llm_to_use | llm_to_use
| JSONAgentOutputParser() | JSONAgentOutputParser()
) )
return agent

View File

@ -55,9 +55,8 @@ class MRKLOutputParser(AgentOutputParser):
return AgentFinish( return AgentFinish(
{"output": text[start_index:end_index].strip()}, text[:end_index] {"output": text[start_index:end_index].strip()}, text[:end_index]
) )
else: msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}" raise OutputParserException(msg)
raise OutputParserException(msg)
if action_match: if action_match:
action = action_match.group(1).strip() action = action_match.group(1).strip()
@ -69,7 +68,7 @@ class MRKLOutputParser(AgentOutputParser):
return AgentAction(action, tool_input, text) return AgentAction(action, tool_input, text)
elif includes_answer: if includes_answer:
return AgentFinish( return AgentFinish(
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
) )
@ -82,7 +81,7 @@ class MRKLOutputParser(AgentOutputParser):
llm_output=text, llm_output=text,
send_to_llm=True, send_to_llm=True,
) )
elif not re.search( if not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
): ):
msg = f"Could not parse LLM output: `{text}`" msg = f"Could not parse LLM output: `{text}`"
@ -92,9 +91,8 @@ class MRKLOutputParser(AgentOutputParser):
llm_output=text, llm_output=text,
send_to_llm=True, send_to_llm=True,
) )
else: msg = f"Could not parse LLM output: `{text}`"
msg = f"Could not parse LLM output: `{text}`" raise OutputParserException(msg)
raise OutputParserException(msg)
@property @property
def _type(self) -> str: def _type(self) -> str:

View File

@ -128,8 +128,7 @@ def _get_assistants_tool(
""" """
if _is_assistants_builtin_tool(tool): if _is_assistants_builtin_tool(tool):
return tool # type: ignore[return-value] return tool # type: ignore[return-value]
else: return convert_to_openai_tool(tool)
return convert_to_openai_tool(tool)
OutputType = Union[ OutputType = Union[
@ -510,12 +509,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
for action, output in intermediate_steps for action, output in intermediate_steps
if action.tool_call_id in required_tool_call_ids if action.tool_call_id in required_tool_call_ids
] ]
submit_tool_outputs = { return {
"tool_outputs": tool_outputs, "tool_outputs": tool_outputs,
"run_id": last_action.run_id, "run_id": last_action.run_id,
"thread_id": last_action.thread_id, "thread_id": last_action.thread_id,
} }
return submit_tool_outputs
def _create_run(self, input_dict: dict) -> Any: def _create_run(self, input_dict: dict) -> Any:
params = { params = {
@ -558,12 +556,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
"run_metadata", "run_metadata",
) )
} }
run = self.client.beta.threads.create_and_run( return self.client.beta.threads.create_and_run(
assistant_id=self.assistant_id, assistant_id=self.assistant_id,
thread=thread, thread=thread,
**params, **params,
) )
return run
def _get_response(self, run: Any) -> Any: def _get_response(self, run: Any) -> Any:
# TODO: Pagination # TODO: Pagination
@ -612,7 +609,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
run_id=run.id, run_id=run.id,
thread_id=run.thread_id, thread_id=run.thread_id,
) )
elif run.status == "requires_action": if run.status == "requires_action":
if not self.as_agent: if not self.as_agent:
return run.required_action.submit_tool_outputs.tool_calls return run.required_action.submit_tool_outputs.tool_calls
actions = [] actions = []
@ -639,10 +636,9 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
) )
) )
return actions return actions
else: run_info = json.dumps(run.dict(), indent=2)
run_info = json.dumps(run.dict(), indent=2) msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})" raise ValueError(msg)
raise ValueError(msg)
def _wait_for_run(self, run_id: str, thread_id: str) -> Any: def _wait_for_run(self, run_id: str, thread_id: str) -> Any:
in_progress = True in_progress = True
@ -668,12 +664,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
for action, output in intermediate_steps for action, output in intermediate_steps
if action.tool_call_id in required_tool_call_ids if action.tool_call_id in required_tool_call_ids
] ]
submit_tool_outputs = { return {
"tool_outputs": tool_outputs, "tool_outputs": tool_outputs,
"run_id": last_action.run_id, "run_id": last_action.run_id,
"thread_id": last_action.thread_id, "thread_id": last_action.thread_id,
} }
return submit_tool_outputs
async def _acreate_run(self, input_dict: dict) -> Any: async def _acreate_run(self, input_dict: dict) -> Any:
params = { params = {
@ -716,12 +711,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
"run_metadata", "run_metadata",
) )
} }
run = await self.async_client.beta.threads.create_and_run( return await self.async_client.beta.threads.create_and_run(
assistant_id=self.assistant_id, assistant_id=self.assistant_id,
thread=thread, thread=thread,
**params, **params,
) )
return run
async def _aget_response(self, run: Any) -> Any: async def _aget_response(self, run: Any) -> Any:
# TODO: Pagination # TODO: Pagination
@ -766,7 +760,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
run_id=run.id, run_id=run.id,
thread_id=run.thread_id, thread_id=run.thread_id,
) )
elif run.status == "requires_action": if run.status == "requires_action":
if not self.as_agent: if not self.as_agent:
return run.required_action.submit_tool_outputs.tool_calls return run.required_action.submit_tool_outputs.tool_calls
actions = [] actions = []
@ -793,10 +787,9 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
) )
) )
return actions return actions
else: run_info = json.dumps(run.dict(), indent=2)
run_info = json.dumps(run.dict(), indent=2) msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})" raise ValueError(msg)
raise ValueError(msg)
async def _await_for_run(self, run_id: str, thread_id: str) -> Any: async def _await_for_run(self, run_id: str, thread_id: str) -> Any:
in_progress = True in_progress = True

View File

@ -132,8 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
messages, messages,
callbacks=callbacks, callbacks=callbacks,
) )
agent_decision = self.output_parser._parse_ai_message(predicted_message) return self.output_parser._parse_ai_message(predicted_message)
return agent_decision
async def aplan( async def aplan(
self, self,
@ -164,8 +163,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
predicted_message = await self.llm.apredict_messages( predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks messages, functions=self.functions, callbacks=callbacks
) )
agent_decision = self.output_parser._parse_ai_message(predicted_message) return self.output_parser._parse_ai_message(predicted_message)
return agent_decision
def return_stopped_response( def return_stopped_response(
self, self,
@ -192,22 +190,20 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
return AgentFinish( return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, "" {"output": "Agent stopped due to iteration limit or time limit."}, ""
) )
elif early_stopping_method == "generate": if early_stopping_method == "generate":
# Generate does one final forward pass # Generate does one final forward pass
agent_decision = self.plan( agent_decision = self.plan(
intermediate_steps, with_functions=False, **kwargs intermediate_steps, with_functions=False, **kwargs
) )
if isinstance(agent_decision, AgentFinish): if isinstance(agent_decision, AgentFinish):
return agent_decision return agent_decision
else: msg = f"got AgentAction with no functions provided: {agent_decision}"
msg = f"got AgentAction with no functions provided: {agent_decision}"
raise ValueError(msg)
else:
msg = (
"early_stopping_method should be one of `force` or `generate`, "
f"got {early_stopping_method}"
)
raise ValueError(msg) raise ValueError(msg)
msg = (
"early_stopping_method should be one of `force` or `generate`, "
f"got {early_stopping_method}"
)
raise ValueError(msg)
@classmethod @classmethod
def create_prompt( def create_prompt(
@ -358,7 +354,7 @@ def create_openai_functions_agent(
) )
raise ValueError(msg) raise ValueError(msg)
llm_with_tools = llm.bind(functions=[convert_to_openai_function(t) for t in tools]) llm_with_tools = llm.bind(functions=[convert_to_openai_function(t) for t in tools])
agent = ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_openai_function_messages( agent_scratchpad=lambda x: format_to_openai_function_messages(
x["intermediate_steps"] x["intermediate_steps"]
@ -368,4 +364,3 @@ def create_openai_functions_agent(
| llm_with_tools | llm_with_tools
| OpenAIFunctionsAgentOutputParser() | OpenAIFunctionsAgentOutputParser()
) )
return agent

View File

@ -224,8 +224,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
predicted_message = self.llm.predict_messages( predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks messages, functions=self.functions, callbacks=callbacks
) )
agent_decision = _parse_ai_message(predicted_message) return _parse_ai_message(predicted_message)
return agent_decision
async def aplan( async def aplan(
self, self,
@ -254,8 +253,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
predicted_message = await self.llm.apredict_messages( predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks messages, functions=self.functions, callbacks=callbacks
) )
agent_decision = _parse_ai_message(predicted_message) return _parse_ai_message(predicted_message)
return agent_decision
@classmethod @classmethod
def create_prompt( def create_prompt(

View File

@ -96,7 +96,7 @@ def create_openai_tools_agent(
tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools] tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools]
) )
agent = ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_openai_tool_messages( agent_scratchpad=lambda x: format_to_openai_tool_messages(
x["intermediate_steps"] x["intermediate_steps"]
@ -106,4 +106,3 @@ def create_openai_tools_agent(
| llm_with_tools | llm_with_tools
| OpenAIToolsAgentOutputParser() | OpenAIToolsAgentOutputParser()
) )
return agent

View File

@ -49,11 +49,10 @@ class JSONAgentOutputParser(AgentOutputParser):
response = response[0] response = response[0]
if response["action"] == "Final Answer": if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text) return AgentFinish({"output": response["action_input"]}, text)
else: action_input = response.get("action_input", {})
action_input = response.get("action_input", {}) if action_input is None:
if action_input is None: action_input = {}
action_input = {} return AgentAction(response["action"], action_input, text)
return AgentAction(response["action"], action_input, text)
except Exception as e: except Exception as e:
msg = f"Could not parse LLM output: {text}" msg = f"Could not parse LLM output: {text}"
raise OutputParserException(msg) from e raise OutputParserException(msg) from e

View File

@ -65,7 +65,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):
return AgentAction(action, tool_input, text) return AgentAction(action, tool_input, text)
elif includes_answer: if includes_answer:
return AgentFinish( return AgentFinish(
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
) )
@ -78,7 +78,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):
llm_output=text, llm_output=text,
send_to_llm=True, send_to_llm=True,
) )
elif not re.search( if not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
): ):
msg = f"Could not parse LLM output: `{text}`" msg = f"Could not parse LLM output: `{text}`"
@ -88,9 +88,8 @@ class ReActSingleInputOutputParser(AgentOutputParser):
llm_output=text, llm_output=text,
send_to_llm=True, send_to_llm=True,
) )
else: msg = f"Could not parse LLM output: `{text}`"
msg = f"Could not parse LLM output: `{text}`" raise OutputParserException(msg)
raise OutputParserException(msg)
@property @property
def _type(self) -> str: def _type(self) -> str:

View File

@ -36,13 +36,12 @@ class XMLAgentOutputParser(AgentOutputParser):
if "</tool_input>" in _tool_input: if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0] _tool_input = _tool_input.split("</tool_input>")[0]
return AgentAction(tool=_tool, tool_input=_tool_input, log=text) return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
elif "<final_answer>" in text: if "<final_answer>" in text:
_, answer = text.split("<final_answer>") _, answer = text.split("<final_answer>")
if "</final_answer>" in answer: if "</final_answer>" in answer:
answer = answer.split("</final_answer>")[0] answer = answer.split("</final_answer>")[0]
return AgentFinish(return_values={"output": answer}, log=text) return AgentFinish(return_values={"output": answer}, log=text)
else: raise ValueError
raise ValueError
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
raise NotImplementedError raise NotImplementedError

View File

@ -134,7 +134,7 @@ def create_react_agent(
else: else:
llm_with_stop = llm llm_with_stop = llm
output_parser = output_parser or ReActSingleInputOutputParser() output_parser = output_parser or ReActSingleInputOutputParser()
agent = ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]), agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
) )
@ -142,4 +142,3 @@ def create_react_agent(
| llm_with_stop | llm_with_stop
| output_parser | output_parser
) )
return agent

View File

@ -96,9 +96,8 @@ class DocstoreExplorer:
if isinstance(result, Document): if isinstance(result, Document):
self.document = result self.document = result
return self._summary return self._summary
else: self.document = None
self.document = None return result
return result
def lookup(self, term: str) -> str: def lookup(self, term: str) -> str:
"""Lookup a term in document (if saved).""" """Lookup a term in document (if saved)."""
@ -113,11 +112,10 @@ class DocstoreExplorer:
lookups = [p for p in self._paragraphs if self.lookup_str in p.lower()] lookups = [p for p in self._paragraphs if self.lookup_str in p.lower()]
if len(lookups) == 0: if len(lookups) == 0:
return "No Results" return "No Results"
elif self.lookup_index >= len(lookups): if self.lookup_index >= len(lookups):
return "No More Results" return "No More Results"
else: result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})" return f"{result_prefix} {lookups[self.lookup_index]}"
return f"{result_prefix} {lookups[self.lookup_index]}"
@property @property
def _summary(self) -> str: def _summary(self) -> str:

View File

@ -26,8 +26,7 @@ class ReActOutputParser(AgentOutputParser):
action, action_input = re_matches.group(1), re_matches.group(2) action, action_input = re_matches.group(1), re_matches.group(2)
if action == "Finish": if action == "Finish":
return AgentFinish({"output": action_input}, text) return AgentFinish({"output": action_input}, text)
else: return AgentAction(action, action_input, text)
return AgentAction(action, action_input, text)
@property @property
def _type(self) -> str: def _type(self) -> str:

View File

@ -195,7 +195,7 @@ def create_self_ask_with_search_agent(
raise ValueError(msg) raise ValueError(msg)
llm_with_stop = llm.bind(stop=["\nIntermediate answer:"]) llm_with_stop = llm.bind(stop=["\nIntermediate answer:"])
agent = ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str( agent_scratchpad=lambda x: format_log_to_str(
x["intermediate_steps"], x["intermediate_steps"],
@ -209,4 +209,3 @@ def create_self_ask_with_search_agent(
| llm_with_stop | llm_with_stop
| SelfAskOutputParser() | SelfAskOutputParser()
) )
return agent

View File

@ -62,8 +62,7 @@ class StructuredChatAgent(Agent):
f"(but I haven't seen any of it! I only see what " f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}" f"you return as final answer):\n{agent_scratchpad}"
) )
else: return agent_scratchpad
return agent_scratchpad
@classmethod @classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
@ -292,7 +291,7 @@ def create_structured_chat_agent(
else: else:
llm_with_stop = llm llm_with_stop = llm
agent = ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]), agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
) )
@ -300,4 +299,3 @@ def create_structured_chat_agent(
| llm_with_stop | llm_with_stop
| JSONAgentOutputParser() | JSONAgentOutputParser()
) )
return agent

View File

@ -42,12 +42,10 @@ class StructuredChatOutputParser(AgentOutputParser):
response = response[0] response = response[0]
if response["action"] == "Final Answer": if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text) return AgentFinish({"output": response["action_input"]}, text)
else: return AgentAction(
return AgentAction( response["action"], response.get("action_input", {}), text
response["action"], response.get("action_input", {}), text )
) return AgentFinish({"output": text}, text)
else:
return AgentFinish({"output": text}, text)
except Exception as e: except Exception as e:
msg = f"Could not parse LLM output: {text}" msg = f"Could not parse LLM output: {text}"
raise OutputParserException(msg) from e raise OutputParserException(msg) from e
@ -93,10 +91,9 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
llm=llm, parser=base_parser llm=llm, parser=base_parser
) )
return cls(output_fixing_parser=output_fixing_parser) return cls(output_fixing_parser=output_fixing_parser)
elif base_parser is not None: if base_parser is not None:
return cls(base_parser=base_parser) return cls(base_parser=base_parser)
else: return cls()
return cls()
@property @property
def _type(self) -> str: def _type(self) -> str:

View File

@ -100,7 +100,7 @@ def create_tool_calling_agent(
) )
llm_with_tools = llm.bind_tools(tools) llm_with_tools = llm.bind_tools(tools)
agent = ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"]) agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"])
) )
@ -108,4 +108,3 @@ def create_tool_calling_agent(
| llm_with_tools | llm_with_tools
| ToolsAgentOutputParser() | ToolsAgentOutputParser()
) )
return agent

View File

@ -221,7 +221,7 @@ def create_xml_agent(
else: else:
llm_with_stop = llm llm_with_stop = llm
agent = ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_xml(x["intermediate_steps"]), agent_scratchpad=lambda x: format_xml(x["intermediate_steps"]),
) )
@ -229,4 +229,3 @@ def create_xml_agent(
| llm_with_stop | llm_with_stop
| XMLAgentOutputParser() | XMLAgentOutputParser()
) )
return agent

View File

@ -24,8 +24,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def check_if_answer_reached(self) -> bool: def check_if_answer_reached(self) -> bool:
if self.strip_tokens: if self.strip_tokens:
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
else: return self.last_tokens == self.answer_prefix_tokens
return self.last_tokens == self.answer_prefix_tokens
def __init__( def __init__(
self, self,

View File

@ -25,8 +25,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def check_if_answer_reached(self) -> bool: def check_if_answer_reached(self) -> bool:
if self.strip_tokens: if self.strip_tokens:
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
else: return self.last_tokens == self.answer_prefix_tokens
return self.last_tokens == self.answer_prefix_tokens
def __init__( def __init__(
self, self,

View File

@ -261,8 +261,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
""" """
if verbose is None: if verbose is None:
return _get_verbosity() return _get_verbosity()
else: return verbose
return verbose
@property @property
@abstractmethod @abstractmethod
@ -474,8 +473,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
self.memory.save_context(inputs, outputs) self.memory.save_context(inputs, outputs)
if return_only_outputs: if return_only_outputs:
return outputs return outputs
else: return {**inputs, **outputs}
return {**inputs, **outputs}
async def aprep_outputs( async def aprep_outputs(
self, self,
@ -500,8 +498,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
await self.memory.asave_context(inputs, outputs) await self.memory.asave_context(inputs, outputs)
if return_only_outputs: if return_only_outputs:
return outputs return outputs
else: return {**inputs, **outputs}
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]: def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory. """Prepare chain inputs, including adding inputs from memory.
@ -628,12 +625,11 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
" but none were provided." " but none were provided."
) )
raise ValueError(msg) raise ValueError(msg)
else: msg = (
msg = ( f"`run` supported with either positional arguments or keyword arguments"
f"`run` supported with either positional arguments or keyword arguments" f" but not both. Got args: {args} and kwargs: {kwargs}."
f" but not both. Got args: {args} and kwargs: {kwargs}." )
) raise ValueError(msg)
raise ValueError(msg)
@deprecated("0.1.0", alternative="ainvoke", removal="1.0") @deprecated("0.1.0", alternative="ainvoke", removal="1.0")
async def arun( async def arun(
@ -687,7 +683,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
f"one output key. Got {self.output_keys}." f"one output key. Got {self.output_keys}."
) )
raise ValueError(msg) raise ValueError(msg)
elif args and not kwargs: if args and not kwargs:
if len(args) != 1: if len(args) != 1:
msg = "`run` supports only one positional argument." msg = "`run` supports only one positional argument."
raise ValueError(msg) raise ValueError(msg)

View File

@ -208,28 +208,25 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
if isinstance(self.reduce_documents_chain, ReduceDocumentsChain): if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
if self.reduce_documents_chain.collapse_documents_chain: if self.reduce_documents_chain.collapse_documents_chain:
return self.reduce_documents_chain.collapse_documents_chain return self.reduce_documents_chain.collapse_documents_chain
else: return self.reduce_documents_chain.combine_documents_chain
return self.reduce_documents_chain.combine_documents_chain msg = (
else: f"`reduce_documents_chain` is of type "
msg = ( f"{type(self.reduce_documents_chain)} so it does not have "
f"`reduce_documents_chain` is of type " f"this attribute."
f"{type(self.reduce_documents_chain)} so it does not have " )
f"this attribute." raise ValueError(msg)
)
raise ValueError(msg)
@property @property
def combine_document_chain(self) -> BaseCombineDocumentsChain: def combine_document_chain(self) -> BaseCombineDocumentsChain:
"""Kept for backward compatibility.""" """Kept for backward compatibility."""
if isinstance(self.reduce_documents_chain, ReduceDocumentsChain): if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
return self.reduce_documents_chain.combine_documents_chain return self.reduce_documents_chain.combine_documents_chain
else: msg = (
msg = ( f"`reduce_documents_chain` is of type "
f"`reduce_documents_chain` is of type " f"{type(self.reduce_documents_chain)} so it does not have "
f"{type(self.reduce_documents_chain)} so it does not have " f"this attribute."
f"this attribute." )
) raise ValueError(msg)
raise ValueError(msg)
def combine_docs( def combine_docs(
self, self,

View File

@ -225,8 +225,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
def _collapse_chain(self) -> BaseCombineDocumentsChain: def _collapse_chain(self) -> BaseCombineDocumentsChain:
if self.collapse_documents_chain is not None: if self.collapse_documents_chain is not None:
return self.collapse_documents_chain return self.collapse_documents_chain
else: return self.combine_documents_chain
return self.combine_documents_chain
def combine_docs( def combine_docs(
self, self,

View File

@ -222,8 +222,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
base_inputs: dict = { base_inputs: dict = {
self.document_variable_name: self.document_prompt.format(**document_info) self.document_variable_name: self.document_prompt.format(**document_info)
} }
inputs = {**base_inputs, **kwargs} return {**base_inputs, **kwargs}
return inputs
@property @property
def _chain_type(self) -> str: def _chain_type(self) -> str:

View File

@ -201,8 +201,7 @@ class ConstitutionalChain(Chain):
) -> list[ConstitutionalPrinciple]: ) -> list[ConstitutionalPrinciple]:
if names is None: if names is None:
return list(PRINCIPLES.values()) return list(PRINCIPLES.values())
else: return [PRINCIPLES[name] for name in names]
return [PRINCIPLES[name] for name in names]
@classmethod @classmethod
def from_llm( def from_llm(

View File

@ -80,8 +80,7 @@ class ElasticsearchDatabaseChain(Chain):
""" """
if not self.return_intermediate_steps: if not self.return_intermediate_steps:
return [self.output_key] return [self.output_key]
else: return [self.output_key, INTERMEDIATE_STEPS_KEY]
return [self.output_key, INTERMEDIATE_STEPS_KEY]
def _list_indices(self) -> list[str]: def _list_indices(self) -> list[str]:
all_indices = [ all_indices = [

View File

@ -47,8 +47,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
"""Output keys for Hyde's LLM chain.""" """Output keys for Hyde's LLM chain."""
if isinstance(self.llm_chain, LLMChain): if isinstance(self.llm_chain, LLMChain):
return self.llm_chain.output_keys return self.llm_chain.output_keys
else: return ["text"]
return ["text"]
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Call the base embeddings.""" """Call the base embeddings."""

View File

@ -116,8 +116,7 @@ class LLMChain(Chain):
""" """
if self.return_final_only: if self.return_final_only:
return [self.output_key] return [self.output_key]
else: return [self.output_key, "full_generation"]
return [self.output_key, "full_generation"]
def _call( def _call(
self, self,
@ -142,17 +141,16 @@ class LLMChain(Chain):
callbacks=callbacks, callbacks=callbacks,
**self.llm_kwargs, **self.llm_kwargs,
) )
else: results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch( cast(list, prompts), {"callbacks": callbacks}
cast(list, prompts), {"callbacks": callbacks} )
) generations: list[list[Generation]] = []
generations: list[list[Generation]] = [] for res in results:
for res in results: if isinstance(res, BaseMessage):
if isinstance(res, BaseMessage): generations.append([ChatGeneration(message=res)])
generations.append([ChatGeneration(message=res)]) else:
else: generations.append([Generation(text=res)])
generations.append([Generation(text=res)]) return LLMResult(generations=generations)
return LLMResult(generations=generations)
async def agenerate( async def agenerate(
self, self,
@ -169,17 +167,16 @@ class LLMChain(Chain):
callbacks=callbacks, callbacks=callbacks,
**self.llm_kwargs, **self.llm_kwargs,
) )
else: results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch( cast(list, prompts), {"callbacks": callbacks}
cast(list, prompts), {"callbacks": callbacks} )
) generations: list[list[Generation]] = []
generations: list[list[Generation]] = [] for res in results:
for res in results: if isinstance(res, BaseMessage):
if isinstance(res, BaseMessage): generations.append([ChatGeneration(message=res)])
generations.append([ChatGeneration(message=res)]) else:
else: generations.append([Generation(text=res)])
generations.append([Generation(text=res)]) return LLMResult(generations=generations)
return LLMResult(generations=generations)
def prep_prompts( def prep_prompts(
self, self,
@ -344,8 +341,7 @@ class LLMChain(Chain):
result = self.predict(callbacks=callbacks, **kwargs) result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None: if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result) return self.prompt.output_parser.parse(result)
else: return result
return result
async def apredict_and_parse( async def apredict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any self, callbacks: Callbacks = None, **kwargs: Any
@ -358,8 +354,7 @@ class LLMChain(Chain):
result = await self.apredict(callbacks=callbacks, **kwargs) result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None: if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result) return self.prompt.output_parser.parse(result)
else: return result
return result
def apply_and_parse( def apply_and_parse(
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
@ -380,8 +375,7 @@ class LLMChain(Chain):
self.prompt.output_parser.parse(res[self.output_key]) self.prompt.output_parser.parse(res[self.output_key])
for res in generation for res in generation
] ]
else: return generation
return generation
async def aapply_and_parse( async def aapply_and_parse(
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
@ -411,15 +405,14 @@ class LLMChain(Chain):
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel: def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
if isinstance(llm_like, BaseLanguageModel): if isinstance(llm_like, BaseLanguageModel):
return llm_like return llm_like
elif isinstance(llm_like, RunnableBinding): if isinstance(llm_like, RunnableBinding):
return _get_language_model(llm_like.bound) return _get_language_model(llm_like.bound)
elif isinstance(llm_like, RunnableWithFallbacks): if isinstance(llm_like, RunnableWithFallbacks):
return _get_language_model(llm_like.runnable) return _get_language_model(llm_like.runnable)
elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)): if isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
return _get_language_model(llm_like.default) return _get_language_model(llm_like.default)
else: msg = (
msg = ( f"Unable to extract BaseLanguageModel from llm_like object of type "
f"Unable to extract BaseLanguageModel from llm_like object of type " f"{type(llm_like)}"
f"{type(llm_like)}" )
) raise ValueError(msg)
raise ValueError(msg)

View File

@ -55,13 +55,12 @@ def _load_question_to_checked_assertions_chain(
check_assertions_chain, check_assertions_chain,
revised_answer_chain, revised_answer_chain,
] ]
question_to_checked_assertions_chain = SequentialChain( return SequentialChain(
chains=chains, # type: ignore[arg-type] chains=chains, # type: ignore[arg-type]
input_variables=["question"], input_variables=["question"],
output_variables=["revised_statement"], output_variables=["revised_statement"],
verbose=True, verbose=True,
) )
return question_to_checked_assertions_chain
@deprecated( @deprecated(

View File

@ -32,7 +32,7 @@ def _load_sequential_chain(
are_all_true_prompt: PromptTemplate, are_all_true_prompt: PromptTemplate,
verbose: bool = False, verbose: bool = False,
) -> SequentialChain: ) -> SequentialChain:
chain = SequentialChain( return SequentialChain(
chains=[ chains=[
LLMChain( LLMChain(
llm=llm, llm=llm,
@ -63,7 +63,6 @@ def _load_sequential_chain(
output_variables=["all_true", "revised_summary"], output_variables=["all_true", "revised_summary"],
verbose=verbose, verbose=verbose,
) )
return chain
@deprecated( @deprecated(

View File

@ -311,8 +311,7 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
prompt = load_prompt(config.pop("prompt_path")) prompt = load_prompt(config.pop("prompt_path"))
if llm_chain: if llm_chain:
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type] return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
else: return LLMMathChain(llm=llm, prompt=prompt, **config)
return LLMMathChain(llm=llm, prompt=prompt, **config)
def _load_map_rerank_documents_chain( def _load_map_rerank_documents_chain(
@ -609,8 +608,7 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
return LLMRequestsChain( return LLMRequestsChain(
llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config
) )
else: return LLMRequestsChain(llm_chain=llm_chain, **config)
return LLMRequestsChain(llm_chain=llm_chain, **config)
type_to_loader_dict = { type_to_loader_dict = {

View File

@ -100,8 +100,7 @@ class OpenAIModerationChain(Chain):
error_str = "Text was found that violates OpenAI's content policy." error_str = "Text was found that violates OpenAI's content policy."
if self.error: if self.error:
raise ValueError(error_str) raise ValueError(error_str)
else: return error_str
return error_str
return text return text
def _call( def _call(

View File

@ -132,7 +132,7 @@ def create_openai_fn_chain(
} }
if len(openai_functions) == 1 and enforce_single_function_usage: if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]} llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
llm_chain = LLMChain( return LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
output_parser=output_parser, output_parser=output_parser,
@ -140,7 +140,6 @@ def create_openai_fn_chain(
output_key=output_key, output_key=output_key,
**kwargs, **kwargs,
) )
return llm_chain
@deprecated( @deprecated(

View File

@ -149,10 +149,9 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
] ]
prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type] prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
chain = LLMChain( return LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
llm_kwargs=llm_kwargs, llm_kwargs=llm_kwargs,
output_parser=output_parser, output_parser=output_parser,
) )
return chain

View File

@ -103,7 +103,7 @@ def create_extraction_chain(
extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = JsonKeyOutputFunctionsParser(key_name="info") output_parser = JsonKeyOutputFunctionsParser(key_name="info")
llm_kwargs = get_llm_kwargs(function) llm_kwargs = get_llm_kwargs(function)
chain = LLMChain( return LLMChain(
llm=llm, llm=llm,
prompt=extraction_prompt, prompt=extraction_prompt,
llm_kwargs=llm_kwargs, llm_kwargs=llm_kwargs,
@ -111,7 +111,6 @@ def create_extraction_chain(
tags=tags, tags=tags,
verbose=verbose, verbose=verbose,
) )
return chain
@deprecated( @deprecated(
@ -187,11 +186,10 @@ def create_extraction_chain_pydantic(
pydantic_schema=PydanticSchema, attr_name="info" pydantic_schema=PydanticSchema, attr_name="info"
) )
llm_kwargs = get_llm_kwargs(function) llm_kwargs = get_llm_kwargs(function)
chain = LLMChain( return LLMChain(
llm=llm, llm=llm,
prompt=extraction_prompt, prompt=extraction_prompt,
llm_kwargs=llm_kwargs, llm_kwargs=llm_kwargs,
output_parser=output_parser, output_parser=output_parser,
verbose=verbose, verbose=verbose,
) )
return chain

View File

@ -100,14 +100,13 @@ def create_qa_with_structure_chain(
] ]
prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type] prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
chain = LLMChain( return LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
llm_kwargs=llm_kwargs, llm_kwargs=llm_kwargs,
output_parser=_output_parser, output_parser=_output_parser,
verbose=verbose, verbose=verbose,
) )
return chain
@deprecated( @deprecated(

View File

@ -91,14 +91,13 @@ def create_tagging_chain(
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = JsonOutputFunctionsParser() output_parser = JsonOutputFunctionsParser()
llm_kwargs = get_llm_kwargs(function) llm_kwargs = get_llm_kwargs(function)
chain = LLMChain( return LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
llm_kwargs=llm_kwargs, llm_kwargs=llm_kwargs,
output_parser=output_parser, output_parser=output_parser,
**kwargs, **kwargs,
) )
return chain
@deprecated( @deprecated(
@ -164,11 +163,10 @@ def create_tagging_chain_pydantic(
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema) output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
llm_kwargs = get_llm_kwargs(function) llm_kwargs = get_llm_kwargs(function)
chain = LLMChain( return LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
llm_kwargs=llm_kwargs, llm_kwargs=llm_kwargs,
output_parser=output_parser, output_parser=output_parser,
**kwargs, **kwargs,
) )
return chain

View File

@ -76,5 +76,4 @@ def create_extraction_chain_pydantic(
functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas] functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
tools = [{"type": "function", "function": d} for d in functions] tools = [{"type": "function", "function": d} for d in functions]
model = llm.bind(tools=tools) model = llm.bind(tools=tools)
chain = prompt | model | PydanticToolsParser(tools=pydantic_schemas) return prompt | model | PydanticToolsParser(tools=pydantic_schemas)
return chain

View File

@ -91,13 +91,12 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
filter_directive = cast( filter_directive = cast(
Optional[FilterDirective], get_parser().parse(raw_filter) Optional[FilterDirective], get_parser().parse(raw_filter)
) )
fixed = fix_filter_directive( return fix_filter_directive(
filter_directive, filter_directive,
allowed_comparators=allowed_comparators, allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators, allowed_operators=allowed_operators,
allowed_attributes=allowed_attributes, allowed_attributes=allowed_attributes,
) )
return fixed
else: else:
ast_parse = get_parser( ast_parse = get_parser(
@ -131,13 +130,13 @@ def fix_filter_directive(
) or not filter: ) or not filter:
return filter return filter
elif isinstance(filter, Comparison): if isinstance(filter, Comparison):
if allowed_comparators and filter.comparator not in allowed_comparators: if allowed_comparators and filter.comparator not in allowed_comparators:
return None return None
if allowed_attributes and filter.attribute not in allowed_attributes: if allowed_attributes and filter.attribute not in allowed_attributes:
return None return None
return filter return filter
elif isinstance(filter, Operation): if isinstance(filter, Operation):
if allowed_operators and filter.operator not in allowed_operators: if allowed_operators and filter.operator not in allowed_operators:
return None return None
args = [ args = [
@ -155,15 +154,13 @@ def fix_filter_directive(
] ]
if not args: if not args:
return None return None
elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR): if len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
return args[0] return args[0]
else: return Operation(
return Operation( operator=filter.operator,
operator=filter.operator, arguments=args,
arguments=args, )
) return filter
else:
return filter
def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str: def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:

View File

@ -101,10 +101,9 @@ class QueryTransformer(Transformer):
) )
raise ValueError(msg) raise ValueError(msg)
return Comparison(comparator=func, attribute=args[0], value=args[1]) return Comparison(comparator=func, attribute=args[0], value=args[1])
elif len(args) == 1 and func in (Operator.AND, Operator.OR): if len(args) == 1 and func in (Operator.AND, Operator.OR):
return args[0] return args[0]
else: return Operation(operator=func, arguments=args)
return Operation(operator=func, arguments=args)
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]: def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
if func_name in set(Comparator): if func_name in set(Comparator):
@ -118,7 +117,7 @@ class QueryTransformer(Transformer):
) )
raise ValueError(msg) raise ValueError(msg)
return Comparator(func_name) return Comparator(func_name)
elif func_name in set(Operator): if func_name in set(Operator):
if ( if (
self.allowed_operators is not None self.allowed_operators is not None
and func_name not in self.allowed_operators and func_name not in self.allowed_operators
@ -129,12 +128,11 @@ class QueryTransformer(Transformer):
) )
raise ValueError(msg) raise ValueError(msg)
return Operator(func_name) return Operator(func_name)
else: msg = (
msg = ( f"Received unrecognized function {func_name}. Valid functions are "
f"Received unrecognized function {func_name}. Valid functions are " f"{list(Operator) + list(Comparator)}"
f"{list(Operator) + list(Comparator)}" )
) raise ValueError(msg)
raise ValueError(msg)
def args(self, *items: Any) -> tuple: def args(self, *items: Any) -> tuple:
return items return items

View File

@ -60,10 +60,8 @@ def create_retrieval_chain(
else: else:
retrieval_docs = (lambda x: x["input"]) | retriever retrieval_docs = (lambda x: x["input"]) | retriever
retrieval_chain = ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
context=retrieval_docs.with_config(run_name="retrieve_documents"), context=retrieval_docs.with_config(run_name="retrieve_documents"),
).assign(answer=combine_docs_chain) ).assign(answer=combine_docs_chain)
).with_config(run_name="retrieval_chain") ).with_config(run_name="retrieval_chain")
return retrieval_chain

View File

@ -157,8 +157,7 @@ class BaseRetrievalQA(Chain):
if self.return_source_documents: if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs} return {self.output_key: answer, "source_documents": docs}
else: return {self.output_key: answer}
return {self.output_key: answer}
@abstractmethod @abstractmethod
async def _aget_docs( async def _aget_docs(
@ -200,8 +199,7 @@ class BaseRetrievalQA(Chain):
if self.return_source_documents: if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs} return {self.output_key: answer, "source_documents": docs}
else: return {self.output_key: answer}
return {self.output_key: answer}
@deprecated( @deprecated(

View File

@ -97,15 +97,14 @@ class MultiRouteChain(Chain):
) )
if not route.destination: if not route.destination:
return self.default_chain(route.next_inputs, callbacks=callbacks) return self.default_chain(route.next_inputs, callbacks=callbacks)
elif route.destination in self.destination_chains: if route.destination in self.destination_chains:
return self.destination_chains[route.destination]( return self.destination_chains[route.destination](
route.next_inputs, callbacks=callbacks route.next_inputs, callbacks=callbacks
) )
elif self.silent_errors: if self.silent_errors:
return self.default_chain(route.next_inputs, callbacks=callbacks) return self.default_chain(route.next_inputs, callbacks=callbacks)
else: msg = f"Received invalid destination chain name '{route.destination}'"
msg = f"Received invalid destination chain name '{route.destination}'" raise ValueError(msg)
raise ValueError(msg)
async def _acall( async def _acall(
self, self,
@ -123,14 +122,13 @@ class MultiRouteChain(Chain):
return await self.default_chain.acall( return await self.default_chain.acall(
route.next_inputs, callbacks=callbacks route.next_inputs, callbacks=callbacks
) )
elif route.destination in self.destination_chains: if route.destination in self.destination_chains:
return await self.destination_chains[route.destination].acall( return await self.destination_chains[route.destination].acall(
route.next_inputs, callbacks=callbacks route.next_inputs, callbacks=callbacks
) )
elif self.silent_errors: if self.silent_errors:
return await self.default_chain.acall( return await self.default_chain.acall(
route.next_inputs, callbacks=callbacks route.next_inputs, callbacks=callbacks
) )
else: msg = f"Received invalid destination chain name '{route.destination}'"
msg = f"Received invalid destination chain name '{route.destination}'" raise ValueError(msg)
raise ValueError(msg)

View File

@ -136,11 +136,10 @@ class LLMRouterChain(RouterChain):
callbacks = _run_manager.get_child() callbacks = _run_manager.get_child()
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs) prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
output = cast( return cast(
dict[str, Any], dict[str, Any],
self.llm_chain.prompt.output_parser.parse(prediction), self.llm_chain.prompt.output_parser.parse(prediction),
) )
return output
async def _acall( async def _acall(
self, self,
@ -149,11 +148,10 @@ class LLMRouterChain(RouterChain):
) -> dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child() callbacks = _run_manager.get_child()
output = cast( return cast(
dict[str, Any], dict[str, Any],
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs), await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
) )
return output
@classmethod @classmethod
def from_llm( def from_llm(

View File

@ -141,8 +141,7 @@ def create_sql_query_chain(
f"{db.dialect}" f"{db.dialect}"
) )
raise ValueError(msg) raise ValueError(msg)
else: table_info_kwargs["get_col_comments"] = True
table_info_kwargs["get_col_comments"] = True
inputs = { inputs = {
"input": lambda x: x["question"] + "\nSQLQuery: ", "input": lambda x: x["question"] + "\nSQLQuery: ",

View File

@ -143,8 +143,7 @@ def create_openai_fn_runnable(
output_parser = output_parser or get_openai_output_parser(functions) output_parser = output_parser or get_openai_output_parser(functions)
if prompt: if prompt:
return prompt | llm.bind(**llm_kwargs_) | output_parser return prompt | llm.bind(**llm_kwargs_) | output_parser
else: return llm.bind(**llm_kwargs_) | output_parser
return llm.bind(**llm_kwargs_) | output_parser
@deprecated( @deprecated(
@ -413,7 +412,7 @@ def create_structured_output_runnable(
first_tool_only=return_single, first_tool_only=return_single,
) )
elif mode == "openai-functions": if mode == "openai-functions":
return _create_openai_functions_structured_output_runnable( return _create_openai_functions_structured_output_runnable(
output_schema, output_schema,
llm, llm,
@ -422,7 +421,7 @@ def create_structured_output_runnable(
enforce_single_function_usage=force_function_usage, enforce_single_function_usage=force_function_usage,
**kwargs, # llm-specific kwargs **kwargs, # llm-specific kwargs
) )
elif mode == "openai-json": if mode == "openai-json":
if force_function_usage: if force_function_usage:
msg = ( msg = (
"enforce_single_function_usage is not supported for mode='openai-json'." "enforce_single_function_usage is not supported for mode='openai-json'."
@ -431,12 +430,11 @@ def create_structured_output_runnable(
return _create_openai_json_runnable( return _create_openai_json_runnable(
output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs
) )
else: msg = (
msg = ( f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', "
f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', " f"'openai-json'."
f"'openai-json'." )
) raise ValueError(msg)
raise ValueError(msg)
def _create_openai_tools_runnable( def _create_openai_tools_runnable(
@ -460,8 +458,7 @@ def _create_openai_tools_runnable(
) )
if prompt: if prompt:
return prompt | llm.bind(**llm_kwargs) | output_parser return prompt | llm.bind(**llm_kwargs) | output_parser
else: return llm.bind(**llm_kwargs) | output_parser
return llm.bind(**llm_kwargs) | output_parser
def _get_openai_tool_output_parser( def _get_openai_tool_output_parser(
@ -535,8 +532,7 @@ def _create_openai_json_runnable(
prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2)) prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2))
return prompt | llm | output_parser return prompt | llm | output_parser
else: return llm | output_parser
return llm | output_parser
def _create_openai_functions_structured_output_runnable( def _create_openai_functions_structured_output_runnable(

View File

@ -77,9 +77,8 @@ class TransformChain(Chain):
) -> dict[str, Any]: ) -> dict[str, Any]:
if self.atransform_cb is not None: if self.atransform_cb is not None:
return await self.atransform_cb(inputs) return await self.atransform_cb(inputs)
else: self._log_once(
self._log_once( "TransformChain's atransform is not provided, falling"
"TransformChain's atransform is not provided, falling" " back to synchronous transform"
" back to synchronous transform" )
) return self.transform_cb(inputs)
return self.transform_cb(inputs)

View File

@ -322,16 +322,15 @@ def init_chat_model(
return _init_chat_model_helper( return _init_chat_model_helper(
cast(str, model), model_provider=model_provider, **kwargs cast(str, model), model_provider=model_provider, **kwargs
) )
else: if model:
if model: kwargs["model"] = model
kwargs["model"] = model if model_provider:
if model_provider: kwargs["model_provider"] = model_provider
kwargs["model_provider"] = model_provider return _ConfigurableModel(
return _ConfigurableModel( default_config=kwargs,
default_config=kwargs, config_prefix=config_prefix,
config_prefix=config_prefix, configurable_fields=configurable_fields,
configurable_fields=configurable_fields, )
)
def _init_chat_model_helper( def _init_chat_model_helper(
@ -343,42 +342,42 @@ def _init_chat_model_helper(
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model, **kwargs) return ChatOpenAI(model=model, **kwargs)
elif model_provider == "anthropic": if model_provider == "anthropic":
_check_pkg("langchain_anthropic") _check_pkg("langchain_anthropic")
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore] return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
elif model_provider == "azure_openai": if model_provider == "azure_openai":
_check_pkg("langchain_openai") _check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(model=model, **kwargs) return AzureChatOpenAI(model=model, **kwargs)
elif model_provider == "azure_ai": if model_provider == "azure_ai":
_check_pkg("langchain_azure_ai") _check_pkg("langchain_azure_ai")
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
return AzureAIChatCompletionsModel(model=model, **kwargs) return AzureAIChatCompletionsModel(model=model, **kwargs)
elif model_provider == "cohere": if model_provider == "cohere":
_check_pkg("langchain_cohere") _check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere from langchain_cohere import ChatCohere
return ChatCohere(model=model, **kwargs) return ChatCohere(model=model, **kwargs)
elif model_provider == "google_vertexai": if model_provider == "google_vertexai":
_check_pkg("langchain_google_vertexai") _check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI from langchain_google_vertexai import ChatVertexAI
return ChatVertexAI(model=model, **kwargs) return ChatVertexAI(model=model, **kwargs)
elif model_provider == "google_genai": if model_provider == "google_genai":
_check_pkg("langchain_google_genai") _check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(model=model, **kwargs) return ChatGoogleGenerativeAI(model=model, **kwargs)
elif model_provider == "fireworks": if model_provider == "fireworks":
_check_pkg("langchain_fireworks") _check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks from langchain_fireworks import ChatFireworks
return ChatFireworks(model=model, **kwargs) return ChatFireworks(model=model, **kwargs)
elif model_provider == "ollama": if model_provider == "ollama":
try: try:
_check_pkg("langchain_ollama") _check_pkg("langchain_ollama")
from langchain_ollama import ChatOllama from langchain_ollama import ChatOllama
@ -393,74 +392,72 @@ def _init_chat_model_helper(
_check_pkg("langchain_ollama") _check_pkg("langchain_ollama")
return ChatOllama(model=model, **kwargs) return ChatOllama(model=model, **kwargs)
elif model_provider == "together": if model_provider == "together":
_check_pkg("langchain_together") _check_pkg("langchain_together")
from langchain_together import ChatTogether from langchain_together import ChatTogether
return ChatTogether(model=model, **kwargs) return ChatTogether(model=model, **kwargs)
elif model_provider == "mistralai": if model_provider == "mistralai":
_check_pkg("langchain_mistralai") _check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI from langchain_mistralai import ChatMistralAI
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore] return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
elif model_provider == "huggingface": if model_provider == "huggingface":
_check_pkg("langchain_huggingface") _check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace from langchain_huggingface import ChatHuggingFace
return ChatHuggingFace(model_id=model, **kwargs) return ChatHuggingFace(model_id=model, **kwargs)
elif model_provider == "groq": if model_provider == "groq":
_check_pkg("langchain_groq") _check_pkg("langchain_groq")
from langchain_groq import ChatGroq from langchain_groq import ChatGroq
return ChatGroq(model=model, **kwargs) return ChatGroq(model=model, **kwargs)
elif model_provider == "bedrock": if model_provider == "bedrock":
_check_pkg("langchain_aws") _check_pkg("langchain_aws")
from langchain_aws import ChatBedrock from langchain_aws import ChatBedrock
# TODO: update to use model= once ChatBedrock supports # TODO: update to use model= once ChatBedrock supports
return ChatBedrock(model_id=model, **kwargs) return ChatBedrock(model_id=model, **kwargs)
elif model_provider == "bedrock_converse": if model_provider == "bedrock_converse":
_check_pkg("langchain_aws") _check_pkg("langchain_aws")
from langchain_aws import ChatBedrockConverse from langchain_aws import ChatBedrockConverse
return ChatBedrockConverse(model=model, **kwargs) return ChatBedrockConverse(model=model, **kwargs)
elif model_provider == "google_anthropic_vertex": if model_provider == "google_anthropic_vertex":
_check_pkg("langchain_google_vertexai") _check_pkg("langchain_google_vertexai")
from langchain_google_vertexai.model_garden import ChatAnthropicVertex from langchain_google_vertexai.model_garden import ChatAnthropicVertex
return ChatAnthropicVertex(model=model, **kwargs) return ChatAnthropicVertex(model=model, **kwargs)
elif model_provider == "deepseek": if model_provider == "deepseek":
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek") _check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
from langchain_deepseek import ChatDeepSeek from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(model=model, **kwargs) return ChatDeepSeek(model=model, **kwargs)
elif model_provider == "nvidia": if model_provider == "nvidia":
_check_pkg("langchain_nvidia_ai_endpoints") _check_pkg("langchain_nvidia_ai_endpoints")
from langchain_nvidia_ai_endpoints import ChatNVIDIA from langchain_nvidia_ai_endpoints import ChatNVIDIA
return ChatNVIDIA(model=model, **kwargs) return ChatNVIDIA(model=model, **kwargs)
elif model_provider == "ibm": if model_provider == "ibm":
_check_pkg("langchain_ibm") _check_pkg("langchain_ibm")
from langchain_ibm import ChatWatsonx from langchain_ibm import ChatWatsonx
return ChatWatsonx(model_id=model, **kwargs) return ChatWatsonx(model_id=model, **kwargs)
elif model_provider == "xai": if model_provider == "xai":
_check_pkg("langchain_xai") _check_pkg("langchain_xai")
from langchain_xai import ChatXAI from langchain_xai import ChatXAI
return ChatXAI(model=model, **kwargs) return ChatXAI(model=model, **kwargs)
elif model_provider == "perplexity": if model_provider == "perplexity":
_check_pkg("langchain_perplexity") _check_pkg("langchain_perplexity")
from langchain_perplexity import ChatPerplexity from langchain_perplexity import ChatPerplexity
return ChatPerplexity(model=model, **kwargs) return ChatPerplexity(model=model, **kwargs)
else: supported = ", ".join(_SUPPORTED_PROVIDERS)
supported = ", ".join(_SUPPORTED_PROVIDERS) msg = (
msg = ( f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
f"Unsupported {model_provider=}.\n\nSupported model providers are: " )
f"{supported}" raise ValueError(msg)
)
raise ValueError(msg)
_SUPPORTED_PROVIDERS = { _SUPPORTED_PROVIDERS = {
@ -490,26 +487,25 @@ _SUPPORTED_PROVIDERS = {
def _attempt_infer_model_provider(model_name: str) -> Optional[str]: def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")): if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
return "openai" return "openai"
elif model_name.startswith("claude"): if model_name.startswith("claude"):
return "anthropic" return "anthropic"
elif model_name.startswith("command"): if model_name.startswith("command"):
return "cohere" return "cohere"
elif model_name.startswith("accounts/fireworks"): if model_name.startswith("accounts/fireworks"):
return "fireworks" return "fireworks"
elif model_name.startswith("gemini"): if model_name.startswith("gemini"):
return "google_vertexai" return "google_vertexai"
elif model_name.startswith("amazon."): if model_name.startswith("amazon."):
return "bedrock" return "bedrock"
elif model_name.startswith("mistral"): if model_name.startswith("mistral"):
return "mistralai" return "mistralai"
elif model_name.startswith("deepseek"): if model_name.startswith("deepseek"):
return "deepseek" return "deepseek"
elif model_name.startswith("grok"): if model_name.startswith("grok"):
return "xai" return "xai"
elif model_name.startswith("sonar"): if model_name.startswith("sonar"):
return "perplexity" return "perplexity"
else: return None
return None
def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]: def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
@ -595,14 +591,13 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) )
return queue return queue
elif self._default_config and (model := self._model()) and hasattr(model, name): if self._default_config and (model := self._model()) and hasattr(model, name):
return getattr(model, name) return getattr(model, name)
else: msg = f"{name} is not a BaseChatModel attribute"
msg = f"{name} is not a BaseChatModel attribute" if self._default_config:
if self._default_config: msg += " and is not implemented on the default model"
msg += " and is not implemented on the default model" msg += "."
msg += "." raise AttributeError(msg)
raise AttributeError(msg)
def _model(self, config: Optional[RunnableConfig] = None) -> Runnable: def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
params = {**self._default_config, **self._model_params(config)} params = {**self._default_config, **self._model_params(config)}
@ -728,10 +723,9 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) )
# If multiple configs default to Runnable.batch which uses executor to invoke # If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel. # in parallel.
else: return super().batch(
return super().batch( inputs, config=config, return_exceptions=return_exceptions, **kwargs
inputs, config=config, return_exceptions=return_exceptions, **kwargs )
)
async def abatch( async def abatch(
self, self,
@ -751,10 +745,9 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
) )
# If multiple configs default to Runnable.batch which uses executor to invoke # If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel. # in parallel.
else: return await super().abatch(
return await super().abatch( inputs, config=config, return_exceptions=return_exceptions, **kwargs
inputs, config=config, return_exceptions=return_exceptions, **kwargs )
)
def batch_as_completed( def batch_as_completed(
self, self,

View File

@ -192,41 +192,40 @@ def init_embeddings(
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings(model=model_name, **kwargs) return OpenAIEmbeddings(model=model_name, **kwargs)
elif provider == "azure_openai": if provider == "azure_openai":
from langchain_openai import AzureOpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings(model=model_name, **kwargs) return AzureOpenAIEmbeddings(model=model_name, **kwargs)
elif provider == "google_vertexai": if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings from langchain_google_vertexai import VertexAIEmbeddings
return VertexAIEmbeddings(model=model_name, **kwargs) return VertexAIEmbeddings(model=model_name, **kwargs)
elif provider == "bedrock": if provider == "bedrock":
from langchain_aws import BedrockEmbeddings from langchain_aws import BedrockEmbeddings
return BedrockEmbeddings(model_id=model_name, **kwargs) return BedrockEmbeddings(model_id=model_name, **kwargs)
elif provider == "cohere": if provider == "cohere":
from langchain_cohere import CohereEmbeddings from langchain_cohere import CohereEmbeddings
return CohereEmbeddings(model=model_name, **kwargs) return CohereEmbeddings(model=model_name, **kwargs)
elif provider == "mistralai": if provider == "mistralai":
from langchain_mistralai import MistralAIEmbeddings from langchain_mistralai import MistralAIEmbeddings
return MistralAIEmbeddings(model=model_name, **kwargs) return MistralAIEmbeddings(model=model_name, **kwargs)
elif provider == "huggingface": if provider == "huggingface":
from langchain_huggingface import HuggingFaceEmbeddings from langchain_huggingface import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(model_name=model_name, **kwargs) return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
elif provider == "ollama": if provider == "ollama":
from langchain_ollama import OllamaEmbeddings from langchain_ollama import OllamaEmbeddings
return OllamaEmbeddings(model=model_name, **kwargs) return OllamaEmbeddings(model=model_name, **kwargs)
else: msg = (
msg = ( f"Provider '{provider}' is not supported.\n"
f"Provider '{provider}' is not supported.\n" f"Supported providers and their required packages:\n"
f"Supported providers and their required packages:\n" f"{_get_provider_list()}"
f"{_get_provider_list()}" )
) raise ValueError(msg)
raise ValueError(msg)
__all__ = [ __all__ = [

View File

@ -70,7 +70,7 @@ def resolve_pairwise_criteria(
Criteria.DEPTH, Criteria.DEPTH,
] ]
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria} return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
elif isinstance(criteria, Criteria): if isinstance(criteria, Criteria):
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]} criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, str): elif isinstance(criteria, str):
if criteria in _SUPPORTED_CRITERIA: if criteria in _SUPPORTED_CRITERIA:

View File

@ -186,9 +186,8 @@ class _EmbeddingDistanceChainMixin(Chain):
} }
if metric in metrics: if metric in metrics:
return metrics[metric] return metrics[metric]
else: msg = f"Invalid metric: {metric}"
msg = f"Invalid metric: {metric}" raise ValueError(msg)
raise ValueError(msg)
@staticmethod @staticmethod
def _cosine_distance(a: Any, b: Any) -> Any: def _cosine_distance(a: Any, b: Any) -> Any:

View File

@ -162,8 +162,7 @@ def load_evaluator(
) )
raise ValueError(msg) from e raise ValueError(msg) from e
return evaluator_cls.from_llm(llm=llm, **kwargs) return evaluator_cls.from_llm(llm=llm, **kwargs)
else: return evaluator_cls(**kwargs)
return evaluator_cls(**kwargs)
def load_evaluators( def load_evaluators(

View File

@ -70,7 +70,7 @@ class JsonSchemaEvaluator(StringEvaluator):
def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]: def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]:
if isinstance(node, str): if isinstance(node, str):
return parse_json_markdown(node) return parse_json_markdown(node)
elif hasattr(node, "schema") and callable(getattr(node, "schema")): if hasattr(node, "schema") and callable(getattr(node, "schema")):
# Pydantic model # Pydantic model
return getattr(node, "schema")() return getattr(node, "schema")()
return node return node

View File

@ -24,7 +24,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
if match: if match:
if match.group(1).upper() == "CORRECT": if match.group(1).upper() == "CORRECT":
return "CORRECT", 1 return "CORRECT", 1
elif match.group(1).upper() == "INCORRECT": if match.group(1).upper() == "INCORRECT":
return "INCORRECT", 0 return "INCORRECT", 0
try: try:
first_word = ( first_word = (
@ -32,7 +32,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
) )
if first_word.upper() == "CORRECT": if first_word.upper() == "CORRECT":
return "CORRECT", 1 return "CORRECT", 1
elif first_word.upper() == "INCORRECT": if first_word.upper() == "INCORRECT":
return "INCORRECT", 0 return "INCORRECT", 0
last_word = ( last_word = (
text.strip() text.strip()
@ -41,7 +41,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
) )
if last_word.upper() == "CORRECT": if last_word.upper() == "CORRECT":
return "CORRECT", 1 return "CORRECT", 1
elif last_word.upper() == "INCORRECT": if last_word.upper() == "INCORRECT":
return "INCORRECT", 0 return "INCORRECT", 0
except IndexError: except IndexError:
pass pass

View File

@ -123,12 +123,12 @@ class _EvalArgsMixin:
if self.requires_input and input is None: if self.requires_input and input is None:
msg = f"{self.__class__.__name__} requires an input string." msg = f"{self.__class__.__name__} requires an input string."
raise ValueError(msg) raise ValueError(msg)
elif input is not None and not self.requires_input: if input is not None and not self.requires_input:
warn(self._skip_input_warning) warn(self._skip_input_warning)
if self.requires_reference and reference is None: if self.requires_reference and reference is None:
msg = f"{self.__class__.__name__} requires a reference string." msg = f"{self.__class__.__name__} requires a reference string."
raise ValueError(msg) raise ValueError(msg)
elif reference is not None and not self.requires_reference: if reference is not None and not self.requires_reference:
warn(self._skip_reference_warning) warn(self._skip_reference_warning)

View File

@ -70,7 +70,7 @@ def resolve_criteria(
Criteria.DEPTH, Criteria.DEPTH,
] ]
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria} return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
elif isinstance(criteria, Criteria): if isinstance(criteria, Criteria):
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]} criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, str): elif isinstance(criteria, str):
if criteria in _SUPPORTED_CRITERIA: if criteria in _SUPPORTED_CRITERIA:

View File

@ -138,8 +138,7 @@ class _RapidFuzzChainMixin(Chain):
module = module_map[distance] module = module_map[distance]
if normalize_score: if normalize_score:
return module.normalized_distance return module.normalized_distance
else: return module.distance
return module.distance
@property @property
def metric(self) -> Callable: def metric(self) -> Callable:

View File

@ -21,10 +21,9 @@ def _get_client(
ls_client = LangSmithClient(api_url, api_key=api_key) ls_client = LangSmithClient(api_url, api_key=api_key)
if hasattr(ls_client, "push_prompt") and hasattr(ls_client, "pull_prompt"): if hasattr(ls_client, "push_prompt") and hasattr(ls_client, "pull_prompt"):
return ls_client return ls_client
else: from langchainhub import Client as LangChainHubClient
from langchainhub import Client as LangChainHubClient
return LangChainHubClient(api_url, api_key=api_key) return LangChainHubClient(api_url, api_key=api_key)
except ImportError: except ImportError:
try: try:
from langchainhub import Client as LangChainHubClient from langchainhub import Client as LangChainHubClient
@ -82,14 +81,13 @@ def push(
# Then it's langchainhub # Then it's langchainhub
manifest_json = dumps(object) manifest_json = dumps(object)
message = client.push( return client.push(
repo_full_name, repo_full_name,
manifest_json, manifest_json,
parent_commit_hash=parent_commit_hash, parent_commit_hash=parent_commit_hash,
new_repo_is_public=new_repo_is_public, new_repo_is_public=new_repo_is_public,
new_repo_description=new_repo_description, new_repo_description=new_repo_description,
) )
return message
def pull( def pull(
@ -113,8 +111,7 @@ def pull(
# Then it's langsmith # Then it's langsmith
if hasattr(client, "pull_prompt"): if hasattr(client, "pull_prompt"):
response = client.pull_prompt(owner_repo_commit, include_model=include_model) return client.pull_prompt(owner_repo_commit, include_model=include_model)
return response
# Then it's langchainhub # Then it's langchainhub
if hasattr(client, "pull_repo"): if hasattr(client, "pull_repo"):

View File

@ -561,8 +561,7 @@ def __getattr__(name: str) -> Any:
k: v() for k, v in get_type_to_cls_dict().items() k: v() for k, v in get_type_to_cls_dict().items()
} }
return type_to_cls_dict return type_to_cls_dict
else: return getattr(llms, name)
return getattr(llms, name)
__all__ = [ __all__ = [

View File

@ -154,6 +154,7 @@ class UpstashRedisEntityStore(BaseEntityStore):
logger.debug( logger.debug(
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}" f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
) )
return None
def delete(self, key: str) -> None: def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}") self.redis_client.delete(f"{self.full_key_prefix}:{key}")
@ -255,6 +256,7 @@ class RedisEntityStore(BaseEntityStore):
logger.debug( logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}" f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
) )
return None
def delete(self, key: str) -> None: def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}") self.redis_client.delete(f"{self.full_key_prefix}:{key}")
@ -362,6 +364,7 @@ class SQLiteEntityStore(BaseEntityStore):
f'"{self.full_table_name}" (key, value) VALUES (?, ?)' f'"{self.full_table_name}" (key, value) VALUES (?, ?)'
) )
self._execute_query(query, (key, value)) self._execute_query(query, (key, value))
return None
def delete(self, key: str) -> None: def delete(self, key: str) -> None:
"""Deletes a key-value pair, safely quoting the table name.""" """Deletes a key-value pair, safely quoting the table name."""

View File

@ -34,7 +34,7 @@ class BooleanOutputParser(BaseOutputParser[bool]):
) )
raise ValueError(msg) raise ValueError(msg)
return True return True
elif self.false_val.upper() in truthy: if self.false_val.upper() in truthy:
if self.true_val.upper() in truthy: if self.true_val.upper() in truthy:
msg = ( msg = (
f"Ambiguous response. Both {self.true_val} and {self.false_val} " f"Ambiguous response. Both {self.true_val} and {self.false_val} "

View File

@ -71,31 +71,30 @@ class OutputFixingParser(BaseOutputParser[T]):
except OutputParserException as e: except OutputParserException as e:
if retries == self.max_retries: if retries == self.max_retries:
raise e raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
else: else:
retries += 1 try:
if self.legacy and hasattr(self.retry_chain, "run"): completion = self.retry_chain.invoke(
completion = self.retry_chain.run( {
instructions=self.parser.get_format_instructions(), "instructions": self.parser.get_format_instructions(), # noqa: E501
completion=completion, "completion": completion,
error=repr(e), "error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = self.retry_chain.invoke(
{
"completion": completion,
"error": repr(e),
}
) )
else:
try:
completion = self.retry_chain.invoke(
{
"instructions": self.parser.get_format_instructions(), # noqa: E501
"completion": completion,
"error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = self.retry_chain.invoke(
{
"completion": completion,
"error": repr(e),
}
)
msg = "Failed to parse" msg = "Failed to parse"
raise OutputParserException(msg) raise OutputParserException(msg)
@ -109,31 +108,30 @@ class OutputFixingParser(BaseOutputParser[T]):
except OutputParserException as e: except OutputParserException as e:
if retries == self.max_retries: if retries == self.max_retries:
raise e raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
else: else:
retries += 1 try:
if self.legacy and hasattr(self.retry_chain, "arun"): completion = await self.retry_chain.ainvoke(
completion = await self.retry_chain.arun( {
instructions=self.parser.get_format_instructions(), "instructions": self.parser.get_format_instructions(), # noqa: E501
completion=completion, "completion": completion,
error=repr(e), "error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = await self.retry_chain.ainvoke(
{
"completion": completion,
"error": repr(e),
}
) )
else:
try:
completion = await self.retry_chain.ainvoke(
{
"instructions": self.parser.get_format_instructions(), # noqa: E501
"completion": completion,
"error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = await self.retry_chain.ainvoke(
{
"completion": completion,
"error": repr(e),
}
)
msg = "Failed to parse" msg = "Failed to parse"
raise OutputParserException(msg) raise OutputParserException(msg)

View File

@ -64,7 +64,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
msg = f"Invalid array format in '{original_request_params}'. \ msg = f"Invalid array format in '{original_request_params}'. \
Please check the format instructions." Please check the format instructions."
raise OutputParserException(msg) raise OutputParserException(msg)
elif ( if (
isinstance(parsed_array[0], int) isinstance(parsed_array[0], int)
and parsed_array[-1] > self.dataframe.index.max() and parsed_array[-1] > self.dataframe.index.max()
): ):

View File

@ -30,12 +30,10 @@ class RegexParser(BaseOutputParser[dict[str, str]]):
match = re.search(self.regex, text) match = re.search(self.regex, text)
if match: if match:
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)} return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
else: if self.default_output_key is None:
if self.default_output_key is None: msg = f"Could not parse output: {text}"
msg = f"Could not parse output: {text}" raise ValueError(msg)
raise ValueError(msg) return {
else: key: text if key == self.default_output_key else ""
return { for key in self.output_keys
key: text if key == self.default_output_key else "" }
for key in self.output_keys
}

View File

@ -33,14 +33,11 @@ class RegexDictParser(BaseOutputParser[dict[str, str]]):
{expected_format} on text {text}" {expected_format} on text {text}"
) )
raise ValueError(msg) raise ValueError(msg)
elif len(matches) > 1: if len(matches) > 1:
msg = f"Multiple matches found for output key: {output_key} with \ msg = f"Multiple matches found for output key: {output_key} with \
expected format {expected_format} on text {text}" expected format {expected_format} on text {text}"
raise ValueError(msg) raise ValueError(msg)
elif ( if self.no_update_value is not None and matches[0] == self.no_update_value:
self.no_update_value is not None and matches[0] == self.no_update_value
):
continue continue
else: result[output_key] = matches[0]
result[output_key] = matches[0]
return result return result

View File

@ -107,20 +107,19 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e: except OutputParserException as e:
if retries == self.max_retries: if retries == self.max_retries:
raise e raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
)
else: else:
retries += 1 completion = self.retry_chain.invoke(
if self.legacy and hasattr(self.retry_chain, "run"): {
completion = self.retry_chain.run( "prompt": prompt_value.to_string(),
prompt=prompt_value.to_string(), "completion": completion,
completion=completion, }
) )
else:
completion = self.retry_chain.invoke(
{
"prompt": prompt_value.to_string(),
"completion": completion,
}
)
msg = "Failed to parse" msg = "Failed to parse"
raise OutputParserException(msg) raise OutputParserException(msg)
@ -143,21 +142,20 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e: except OutputParserException as e:
if retries == self.max_retries: if retries == self.max_retries:
raise e raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else: else:
retries += 1 completion = await self.retry_chain.ainvoke(
if self.legacy and hasattr(self.retry_chain, "arun"): {
completion = await self.retry_chain.arun( "prompt": prompt_value.to_string(),
prompt=prompt_value.to_string(), "completion": completion,
completion=completion, }
error=repr(e), )
)
else:
completion = await self.retry_chain.ainvoke(
{
"prompt": prompt_value.to_string(),
"completion": completion,
}
)
msg = "Failed to parse" msg = "Failed to parse"
raise OutputParserException(msg) raise OutputParserException(msg)
@ -234,22 +232,21 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e: except OutputParserException as e:
if retries == self.max_retries: if retries == self.max_retries:
raise e raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else: else:
retries += 1 completion = self.retry_chain.invoke(
if self.legacy and hasattr(self.retry_chain, "run"): {
completion = self.retry_chain.run( "completion": completion,
prompt=prompt_value.to_string(), "prompt": prompt_value.to_string(),
completion=completion, "error": repr(e),
error=repr(e), }
) )
else:
completion = self.retry_chain.invoke(
{
"completion": completion,
"prompt": prompt_value.to_string(),
"error": repr(e),
}
)
msg = "Failed to parse" msg = "Failed to parse"
raise OutputParserException(msg) raise OutputParserException(msg)
@ -263,22 +260,21 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e: except OutputParserException as e:
if retries == self.max_retries: if retries == self.max_retries:
raise e raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else: else:
retries += 1 completion = await self.retry_chain.ainvoke(
if self.legacy and hasattr(self.retry_chain, "arun"): {
completion = await self.retry_chain.arun( "prompt": prompt_value.to_string(),
prompt=prompt_value.to_string(), "completion": completion,
completion=completion, "error": repr(e),
error=repr(e), }
) )
else:
completion = await self.retry_chain.ainvoke(
{
"prompt": prompt_value.to_string(),
"completion": completion,
"error": repr(e),
}
)
msg = "Failed to parse" msg = "Failed to parse"
raise OutputParserException(msg) raise OutputParserException(msg)

View File

@ -89,8 +89,7 @@ class StructuredOutputParser(BaseOutputParser[dict[str, Any]]):
) )
if only_json: if only_json:
return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str) return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str)
else: return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
def parse(self, text: str) -> dict[str, Any]: def parse(self, text: str) -> dict[str, Any]:
expected_keys = [rs.name for rs in self.response_schemas] expected_keys = [rs.name for rs in self.response_schemas]

View File

@ -33,8 +33,7 @@ class YamlOutputParser(BaseOutputParser[T]):
json_object = yaml.safe_load(yaml_str) json_object = yaml.safe_load(yaml_str)
if hasattr(self.pydantic_object, "model_validate"): if hasattr(self.pydantic_object, "model_validate"):
return self.pydantic_object.model_validate(json_object) return self.pydantic_object.model_validate(json_object)
else: return self.pydantic_object.parse_obj(json_object)
return self.pydantic_object.parse_obj(json_object)
except (yaml.YAMLError, ValidationError) as e: except (yaml.YAMLError, ValidationError) as e:
name = self.pydantic_object.__name__ name = self.pydantic_object.__name__

View File

@ -45,8 +45,7 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child() docs, query, callbacks=run_manager.get_child()
) )
return list(compressed_docs) return list(compressed_docs)
else: return []
return []
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
@ -71,5 +70,4 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child() docs, query, callbacks=run_manager.get_child()
) )
return list(compressed_docs) return list(compressed_docs)
else: return []
return []

View File

@ -174,9 +174,7 @@ class EnsembleRetriever(BaseRetriever):
""" """
# Get fused result of the retrievers. # Get fused result of the retrievers.
fused_documents = self.rank_fusion(query, run_manager) return self.rank_fusion(query, run_manager)
return fused_documents
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
@ -195,9 +193,7 @@ class EnsembleRetriever(BaseRetriever):
""" """
# Get fused result of the retrievers. # Get fused result of the retrievers.
fused_documents = await self.arank_fusion(query, run_manager) return await self.arank_fusion(query, run_manager)
return fused_documents
def rank_fusion( def rank_fusion(
self, self,
@ -236,9 +232,7 @@ class EnsembleRetriever(BaseRetriever):
] ]
# apply rank fusion # apply rank fusion
fused_documents = self.weighted_reciprocal_rank(retriever_docs) return self.weighted_reciprocal_rank(retriever_docs)
return fused_documents
async def arank_fusion( async def arank_fusion(
self, self,
@ -280,9 +274,7 @@ class EnsembleRetriever(BaseRetriever):
] ]
# apply rank fusion # apply rank fusion
fused_documents = self.weighted_reciprocal_rank(retriever_docs) return self.weighted_reciprocal_rank(retriever_docs)
return fused_documents
def weighted_reciprocal_rank( def weighted_reciprocal_rank(
self, doc_lists: list[list[Document]] self, doc_lists: list[list[Document]]
@ -318,7 +310,7 @@ class EnsembleRetriever(BaseRetriever):
# Docs are deduplicated by their contents then sorted by their scores # Docs are deduplicated by their contents then sorted by their scores
all_docs = chain.from_iterable(doc_lists) all_docs = chain.from_iterable(doc_lists)
sorted_docs = sorted( return sorted(
unique_by_key( unique_by_key(
all_docs, all_docs,
lambda doc: ( lambda doc: (
@ -332,4 +324,3 @@ class EnsembleRetriever(BaseRetriever):
doc.page_content if self.id_key is None else doc.metadata[self.id_key] doc.page_content if self.id_key is None else doc.metadata[self.id_key]
], ],
) )
return sorted_docs

View File

@ -31,9 +31,7 @@ class MergerRetriever(BaseRetriever):
""" """
# Merge the results of the retrievers. # Merge the results of the retrievers.
merged_documents = self.merge_documents(query, run_manager) return self.merge_documents(query, run_manager)
return merged_documents
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
@ -52,9 +50,7 @@ class MergerRetriever(BaseRetriever):
""" """
# Merge the results of the retrievers. # Merge the results of the retrievers.
merged_documents = await self.amerge_documents(query, run_manager) return await self.amerge_documents(query, run_manager)
return merged_documents
def merge_documents( def merge_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun self, query: str, run_manager: CallbackManagerForRetrieverRun

View File

@ -74,10 +74,9 @@ class RePhraseQueryRetriever(BaseRetriever):
query, {"callbacks": run_manager.get_child()} query, {"callbacks": run_manager.get_child()}
) )
logger.info(f"Re-phrased question: {re_phrased_question}") logger.info(f"Re-phrased question: {re_phrased_question}")
docs = self.retriever.invoke( return self.retriever.invoke(
re_phrased_question, config={"callbacks": run_manager.get_child()} re_phrased_question, config={"callbacks": run_manager.get_child()}
) )
return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,

View File

@ -119,117 +119,116 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
} }
if isinstance(vectorstore, DatabricksVectorSearch): if isinstance(vectorstore, DatabricksVectorSearch):
return DatabricksVectorSearchTranslator() return DatabricksVectorSearchTranslator()
elif isinstance(vectorstore, MyScale): if isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column) return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
elif isinstance(vectorstore, Redis): if isinstance(vectorstore, Redis):
return RedisTranslator.from_vectorstore(vectorstore) return RedisTranslator.from_vectorstore(vectorstore)
elif isinstance(vectorstore, TencentVectorDB): if isinstance(vectorstore, TencentVectorDB):
fields = [ fields = [
field.name for field in (vectorstore.meta_fields or []) if field.index field.name for field in (vectorstore.meta_fields or []) if field.index
] ]
return TencentVectorDBTranslator(fields) return TencentVectorDBTranslator(fields)
elif vectorstore.__class__ in BUILTIN_TRANSLATORS: if vectorstore.__class__ in BUILTIN_TRANSLATORS:
return BUILTIN_TRANSLATORS[vectorstore.__class__]() return BUILTIN_TRANSLATORS[vectorstore.__class__]()
try:
from langchain_astradb.vectorstores import AstraDBVectorStore
except ImportError:
pass
else: else:
try: if isinstance(vectorstore, AstraDBVectorStore):
from langchain_astradb.vectorstores import AstraDBVectorStore return AstraDBTranslator()
except ImportError:
pass
else:
if isinstance(vectorstore, AstraDBVectorStore):
return AstraDBTranslator()
try: try:
from langchain_elasticsearch.vectorstores import ElasticsearchStore from langchain_elasticsearch.vectorstores import ElasticsearchStore
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, ElasticsearchStore): if isinstance(vectorstore, ElasticsearchStore):
return ElasticsearchTranslator() return ElasticsearchTranslator()
try: try:
from langchain_pinecone import PineconeVectorStore from langchain_pinecone import PineconeVectorStore
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, PineconeVectorStore): if isinstance(vectorstore, PineconeVectorStore):
return PineconeTranslator() return PineconeTranslator()
try: try:
from langchain_milvus import Milvus from langchain_milvus import Milvus
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, Milvus): if isinstance(vectorstore, Milvus):
return MilvusTranslator() return MilvusTranslator()
try: try:
from langchain_mongodb import MongoDBAtlasVectorSearch from langchain_mongodb import MongoDBAtlasVectorSearch
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, MongoDBAtlasVectorSearch): if isinstance(vectorstore, MongoDBAtlasVectorSearch):
return MongoDBAtlasTranslator() return MongoDBAtlasTranslator()
try: try:
from langchain_neo4j import Neo4jVector from langchain_neo4j import Neo4jVector
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, Neo4jVector): if isinstance(vectorstore, Neo4jVector):
return Neo4jTranslator() return Neo4jTranslator()
try: try:
# Trying langchain_chroma import if exists # Trying langchain_chroma import if exists
from langchain_chroma import Chroma from langchain_chroma import Chroma
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, Chroma): if isinstance(vectorstore, Chroma):
return ChromaTranslator() return ChromaTranslator()
try: try:
from langchain_postgres import PGVector from langchain_postgres import PGVector
from langchain_postgres import PGVectorTranslator as NewPGVectorTranslator from langchain_postgres import PGVectorTranslator as NewPGVectorTranslator
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, PGVector): if isinstance(vectorstore, PGVector):
return NewPGVectorTranslator() return NewPGVectorTranslator()
try: try:
from langchain_qdrant import QdrantVectorStore from langchain_qdrant import QdrantVectorStore
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, QdrantVectorStore): if isinstance(vectorstore, QdrantVectorStore):
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key) return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
try: try:
# Added in langchain-community==0.2.11 # Added in langchain-community==0.2.11
from langchain_community.query_constructors.hanavector import HanaTranslator from langchain_community.query_constructors.hanavector import HanaTranslator
from langchain_community.vectorstores import HanaDB from langchain_community.vectorstores import HanaDB
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, HanaDB): if isinstance(vectorstore, HanaDB):
return HanaTranslator() return HanaTranslator()
try: try:
# Trying langchain_weaviate (weaviate v4) import if exists # Trying langchain_weaviate (weaviate v4) import if exists
from langchain_weaviate.vectorstores import WeaviateVectorStore from langchain_weaviate.vectorstores import WeaviateVectorStore
except ImportError: except ImportError:
pass pass
else: else:
if isinstance(vectorstore, WeaviateVectorStore): if isinstance(vectorstore, WeaviateVectorStore):
return WeaviateTranslator() return WeaviateTranslator()
msg = ( msg = (
f"Self query retriever with Vector Store type {vectorstore.__class__}" f"Self query retriever with Vector Store type {vectorstore.__class__}"
f" not supported." f" not supported."
) )
raise ValueError(msg) raise ValueError(msg)
class SelfQueryRetriever(BaseRetriever): class SelfQueryRetriever(BaseRetriever):
@ -289,14 +288,12 @@ class SelfQueryRetriever(BaseRetriever):
def _get_docs_with_query( def _get_docs_with_query(
self, query: str, search_kwargs: dict[str, Any] self, query: str, search_kwargs: dict[str, Any]
) -> list[Document]: ) -> list[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs) return self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs
async def _aget_docs_with_query( async def _aget_docs_with_query(
self, query: str, search_kwargs: dict[str, Any] self, query: str, search_kwargs: dict[str, Any]
) -> list[Document]: ) -> list[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs) return await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
return docs
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@ -315,8 +312,7 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose: if self.verbose:
logger.info(f"Generated Query: {structured_query}") logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query) new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs) return self._get_docs_with_query(new_query, search_kwargs)
return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
@ -335,8 +331,7 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose: if self.verbose:
logger.info(f"Generated Query: {structured_query}") logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query) new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs) return await self._aget_docs_with_query(new_query, search_kwargs)
return docs
@classmethod @classmethod
def from_llm( def from_llm(

View File

@ -191,13 +191,13 @@ def _wrap_in_chain_factory(
) )
raise ValueError(msg) raise ValueError(msg)
return lambda: chain return lambda: chain
elif isinstance(llm_or_chain_factory, BaseLanguageModel): if isinstance(llm_or_chain_factory, BaseLanguageModel):
return llm_or_chain_factory return llm_or_chain_factory
elif isinstance(llm_or_chain_factory, Runnable): if isinstance(llm_or_chain_factory, Runnable):
# Memory may exist here, but it's not elegant to check all those cases. # Memory may exist here, but it's not elegant to check all those cases.
lcf = llm_or_chain_factory lcf = llm_or_chain_factory
return lambda: lcf return lambda: lcf
elif callable(llm_or_chain_factory): if callable(llm_or_chain_factory):
if is_traceable_function(llm_or_chain_factory): if is_traceable_function(llm_or_chain_factory):
runnable_ = as_runnable(cast(Callable, llm_or_chain_factory)) runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
return lambda: runnable_ return lambda: runnable_
@ -215,15 +215,14 @@ def _wrap_in_chain_factory(
# It's not uncommon to do an LLM constructor instead of raw LLM, # It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user. # so we'll unpack it for the user.
return _model return _model
elif is_traceable_function(cast(Callable, _model)): if is_traceable_function(cast(Callable, _model)):
runnable_ = as_runnable(cast(Callable, _model)) runnable_ = as_runnable(cast(Callable, _model))
return lambda: runnable_ return lambda: runnable_
elif not isinstance(_model, Runnable): if not isinstance(_model, Runnable):
# This is unlikely to happen - a constructor for a model function # This is unlikely to happen - a constructor for a model function
return lambda: RunnableLambda(constructor) return lambda: RunnableLambda(constructor)
else: # Typical correct case
# Typical correct case return constructor
return constructor
return llm_or_chain_factory return llm_or_chain_factory
@ -272,9 +271,8 @@ def _get_prompt(inputs: dict[str, Any]) -> str:
raise InputFormatError(msg) raise InputFormatError(msg)
if len(prompts) == 1: if len(prompts) == 1:
return prompts[0] return prompts[0]
else: msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts."
msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts." raise InputFormatError(msg)
raise InputFormatError(msg)
class ChatModelInput(TypedDict): class ChatModelInput(TypedDict):
@ -321,12 +319,11 @@ def _get_messages(inputs: dict[str, Any]) -> dict:
) )
raise InputFormatError(msg) raise InputFormatError(msg)
return input_copy return input_copy
else: msg = (
msg = ( f"Chat Run expects single List[dict] or List[List[dict]] 'messages'"
f"Chat Run expects single List[dict] or List[List[dict]] 'messages'" f" input. Got {inputs}"
f" input. Got {inputs}" )
) raise InputFormatError(msg)
raise InputFormatError(msg)
## Shared data validation utilities ## Shared data validation utilities
@ -707,31 +704,29 @@ async def _arun_llm(
callbacks=callbacks, tags=tags or [], metadata=metadata or {} callbacks=callbacks, tags=tags or [], metadata=metadata or {}
), ),
) )
else: msg = (
msg = ( "Input mapper returned invalid format"
"Input mapper returned invalid format" f" {prompt_or_messages}"
f" {prompt_or_messages}" "\nExpected a single string or list of chat messages."
"\nExpected a single string or list of chat messages." )
) raise InputFormatError(msg)
raise InputFormatError(msg)
else: try:
try: prompt = _get_prompt(inputs)
prompt = _get_prompt(inputs) llm_output: Union[str, BaseMessage] = await llm.ainvoke(
llm_output: Union[str, BaseMessage] = await llm.ainvoke( prompt,
prompt, config=RunnableConfig(
config=RunnableConfig( callbacks=callbacks, tags=tags or [], metadata=metadata or {}
callbacks=callbacks, tags=tags or [], metadata=metadata or {} ),
), )
) except InputFormatError:
except InputFormatError: llm_inputs = _get_messages(inputs)
llm_inputs = _get_messages(inputs) llm_output = await llm.ainvoke(
llm_output = await llm.ainvoke( **llm_inputs,
**llm_inputs, config=RunnableConfig(
config=RunnableConfig( callbacks=callbacks, tags=tags or [], metadata=metadata or {}
callbacks=callbacks, tags=tags or [], metadata=metadata or {} ),
), )
)
return llm_output return llm_output

View File

@ -28,8 +28,7 @@ def _get_messages_from_run_dict(messages: list[dict]) -> list[BaseMessage]:
first_message = messages[0] first_message = messages[0]
if "lc" in first_message: if "lc" in first_message:
return [load(dumpd(message)) for message in messages] return [load(dumpd(message)) for message in messages]
else: return messages_from_dict(messages)
return messages_from_dict(messages)
class StringRunMapper(Serializable): class StringRunMapper(Serializable):
@ -106,25 +105,23 @@ class LLMStringRunMapper(StringRunMapper):
if run.run_type != "llm": if run.run_type != "llm":
msg = "LLM RunMapper only supports LLM runs." msg = "LLM RunMapper only supports LLM runs."
raise ValueError(msg) raise ValueError(msg)
elif not run.outputs: if not run.outputs:
if run.error: if run.error:
msg = f"Cannot evaluate errored LLM run {run.id}: {run.error}" msg = f"Cannot evaluate errored LLM run {run.id}: {run.error}"
raise ValueError(msg) raise ValueError(msg)
else: msg = f"Run {run.id} has no outputs. Cannot evaluate this run."
msg = f"Run {run.id} has no outputs. Cannot evaluate this run." raise ValueError(msg)
raise ValueError(msg) try:
else: inputs = self.serialize_inputs(run.inputs)
try: except Exception as e:
inputs = self.serialize_inputs(run.inputs) msg = f"Could not parse LM input from run inputs {run.inputs}"
except Exception as e: raise ValueError(msg) from e
msg = f"Could not parse LM input from run inputs {run.inputs}" try:
raise ValueError(msg) from e output_ = self.serialize_outputs(run.outputs)
try: except Exception as e:
output_ = self.serialize_outputs(run.outputs) msg = f"Could not parse LM prediction from run outputs {run.outputs}"
except Exception as e: raise ValueError(msg) from e
msg = f"Could not parse LM prediction from run outputs {run.outputs}" return {"input": inputs, "prediction": output_}
raise ValueError(msg) from e
return {"input": inputs, "prediction": output_}
class ChainStringRunMapper(StringRunMapper): class ChainStringRunMapper(StringRunMapper):
@ -142,14 +139,13 @@ class ChainStringRunMapper(StringRunMapper):
def _get_key(self, source: dict, key: Optional[str], which: str) -> str: def _get_key(self, source: dict, key: Optional[str], which: str) -> str:
if key is not None: if key is not None:
return source[key] return source[key]
elif len(source) == 1: if len(source) == 1:
return next(iter(source.values())) return next(iter(source.values()))
else: msg = (
msg = ( f"Could not map run {which} with multiple keys: "
f"Could not map run {which} with multiple keys: " f"{source}\nPlease manually specify a {which}_key"
f"{source}\nPlease manually specify a {which}_key" )
) raise ValueError(msg)
raise ValueError(msg)
def map(self, run: Run) -> dict[str, str]: def map(self, run: Run) -> dict[str, str]:
"""Maps the Run to a dictionary.""" """Maps the Run to a dictionary."""
@ -168,7 +164,7 @@ class ChainStringRunMapper(StringRunMapper):
f" '{self.input_key}'." f" '{self.input_key}'."
) )
raise ValueError(msg) raise ValueError(msg)
elif self.prediction_key is not None and self.prediction_key not in run.outputs: if self.prediction_key is not None and self.prediction_key not in run.outputs:
available_keys = ", ".join(run.outputs.keys()) available_keys = ", ".join(run.outputs.keys())
msg = ( msg = (
f"Run with ID {run.id} doesn't have the expected prediction key" f"Run with ID {run.id} doesn't have the expected prediction key"
@ -178,13 +174,12 @@ class ChainStringRunMapper(StringRunMapper):
) )
raise ValueError(msg) raise ValueError(msg)
else: input_ = self._get_key(run.inputs, self.input_key, "input")
input_ = self._get_key(run.inputs, self.input_key, "input") prediction = self._get_key(run.outputs, self.prediction_key, "prediction")
prediction = self._get_key(run.outputs, self.prediction_key, "prediction") return {
return { "input": input_,
"input": input_, "prediction": prediction,
"prediction": prediction, }
}
class ToolStringRunMapper(StringRunMapper): class ToolStringRunMapper(StringRunMapper):
@ -224,8 +219,7 @@ class StringExampleMapper(Serializable):
" specify a reference_key." " specify a reference_key."
) )
raise ValueError(msg) raise ValueError(msg)
else: output = list(example.outputs.values())[0]
output = list(example.outputs.values())[0]
elif self.reference_key not in example.outputs: elif self.reference_key not in example.outputs:
msg = ( msg = (
f"Example {example.id} does not have reference key" f"Example {example.id} does not have reference key"

View File

@ -65,24 +65,23 @@ def _import_python_tool_PythonREPLTool() -> Any:
def __getattr__(name: str) -> Any: def __getattr__(name: str) -> Any:
if name == "PythonAstREPLTool": if name == "PythonAstREPLTool":
return _import_python_tool_PythonAstREPLTool() return _import_python_tool_PythonAstREPLTool()
elif name == "PythonREPLTool": if name == "PythonREPLTool":
return _import_python_tool_PythonREPLTool() return _import_python_tool_PythonREPLTool()
else: from langchain_community import tools
from langchain_community import tools
# If not in interactive env, raise warning. # If not in interactive env, raise warning.
if not is_interactive_env(): if not is_interactive_env():
warnings.warn( warnings.warn(
"Importing tools from langchain is deprecated. Importing from " "Importing tools from langchain is deprecated. Importing from "
"langchain will no longer be supported as of langchain==0.2.0. " "langchain will no longer be supported as of langchain==0.2.0. "
"Please import from langchain-community instead:\n\n" "Please import from langchain-community instead:\n\n"
f"`from langchain_community.tools import {name}`.\n\n" f"`from langchain_community.tools import {name}`.\n\n"
"To install langchain-community run " "To install langchain-community run "
"`pip install -U langchain-community`.", "`pip install -U langchain-community`.",
category=LangChainDeprecationWarning, category=LangChainDeprecationWarning,
) )
return getattr(tools, name) return getattr(tools, name)
__all__ = [ __all__ = [

View File

@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin" ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint] [tool.ruff.lint]
select = ["A", "C4", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"] select = ["A", "C4", "D", "E", "EM", "F", "I", "PGH003", "PIE", "RET", "S", "SIM", "T201", "UP", "W"]
pydocstyle.convention = "google" pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true pyupgrade.keep-runtime-typing = true

View File

@ -139,8 +139,7 @@ async def get_state(
async def ask_for_passphrase(said_please: bool) -> dict[str, Any]: async def ask_for_passphrase(said_please: bool) -> dict[str, Any]:
if said_please: if said_please:
return {"passphrase": f"The passphrase is {PASS_PHRASE}"} return {"passphrase": f"The passphrase is {PASS_PHRASE}"}
else: return {"passphrase": "I won't share the passphrase without saying 'please'."}
return {"passphrase": "I won't share the passphrase without saying 'please'."}
@app.delete( @app.delete(
@ -153,12 +152,11 @@ async def recycle(password: SecretPassPhrase) -> dict[str, Any]:
if password.pw == PASS_PHRASE: if password.pw == PASS_PHRASE:
_ROBOT_STATE["destruct"] = True _ROBOT_STATE["destruct"] = True
return {"status": "Self-destruct initiated", "state": _ROBOT_STATE} return {"status": "Self-destruct initiated", "state": _ROBOT_STATE}
else: _ROBOT_STATE["destruct"] = False
_ROBOT_STATE["destruct"] = False raise HTTPException(
raise HTTPException( status_code=400,
status_code=400, detail="Pass phrase required. You should have thought to ask for it.",
detail="Pass phrase required. You should have thought to ask for it.", )
)
@app.post( @app.post(

View File

@ -100,14 +100,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
), ),
] ]
agent = initialize_agent( return initialize_agent(
tools, tools,
fake_llm, fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True, verbose=True,
**kwargs, **kwargs,
) )
return agent
def test_agent_bad_action() -> None: def test_agent_bad_action() -> None:

View File

@ -71,14 +71,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
), ),
] ]
agent = initialize_agent( return initialize_agent(
tools, tools,
fake_llm, fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True, verbose=True,
**kwargs, **kwargs,
) )
return agent
async def test_agent_bad_action() -> None: async def test_agent_bad_action() -> None:

View File

@ -11,8 +11,7 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text) output = output_parser.parse(text)
if isinstance(output, AgentAction): if isinstance(output, AgentAction):
return output.tool, str(output.tool_input) return output.tool, str(output.tool_input)
else: return "Final Answer", output.return_values["output"]
return "Final Answer", output.return_values["output"]
def test_parse_with_language() -> None: def test_parse_with_language() -> None:

View File

@ -16,8 +16,7 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = MRKLOutputParser().parse(text) output = MRKLOutputParser().parse(text)
if isinstance(output, AgentAction): if isinstance(output, AgentAction):
return output.tool, str(output.tool_input) return output.tool, str(output.tool_input)
else: return "Final Answer", output.return_values["output"]
return "Final Answer", output.return_values["output"]
def test_get_action_and_input() -> None: def test_get_action_and_input() -> None:

View File

@ -21,11 +21,10 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text) output = output_parser.parse(text)
if isinstance(output, AgentAction): if isinstance(output, AgentAction):
return output.tool, str(output.tool_input) return output.tool, str(output.tool_input)
elif isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
return output.return_values["output"], output.log return output.return_values["output"], output.log
else: msg = "Unexpected output type"
msg = "Unexpected output type" raise ValueError(msg)
raise ValueError(msg)
def test_parse_with_language() -> None: def test_parse_with_language() -> None:

View File

@ -58,8 +58,7 @@ class FakeChain(Chain):
) -> dict[str, str]: ) -> dict[str, str]:
if self.be_correct: if self.be_correct:
return {"bar": "baz"} return {"bar": "baz"}
else: return {"baz": "bar"}
return {"baz": "bar"}
def test_bad_inputs() -> None: def test_bad_inputs() -> None:

View File

@ -54,9 +54,8 @@ class _FakeTrajectoryChatModel(FakeChatModel):
response = self.queries[list(self.queries.keys())[self.response_index]] response = self.queries[list(self.queries.keys())[self.response_index]]
self.response_index = self.response_index + 1 self.response_index = self.response_index + 1
return response return response
else: prompt = messages[0].content
prompt = messages[0].content return self.queries[prompt]
return self.queries[prompt]
def test_trajectory_output_parser_parse() -> None: def test_trajectory_output_parser_parse() -> None:

View File

@ -45,8 +45,7 @@ class FakeLLM(LLM):
return self.queries[prompt] return self.queries[prompt]
if stop is None: if stop is None:
return "foo" return "foo"
else: return "bar"
return "bar"
@property @property
def _identifying_params(self) -> dict[str, Any]: def _identifying_params(self) -> dict[str, Any]:

View File

@ -14,9 +14,8 @@ class SequentialRetriever(BaseRetriever):
) -> list[Document]: ) -> list[Document]:
if self.response_index >= len(self.sequential_responses): if self.response_index >= len(self.sequential_responses):
return [] return []
else: self.response_index += 1
self.response_index += 1 return self.sequential_responses[self.response_index - 1]
return self.sequential_responses[self.response_index - 1]
async def _aget_relevant_documents( # type: ignore[override] async def _aget_relevant_documents( # type: ignore[override]
self, self,