Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-12 15:59:56 +00:00)
langchain: Add ruff rule RET (#31875)
All auto-fixes. See https://docs.astral.sh/ruff/rules/#flake8-return-ret
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
parent fceebbb387
commit 56bbfd9723
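For context on what the auto-fixes below actually do: the flake8-return (RET) rules mostly flag two patterns that repeat throughout this diff, a superfluous `elif`/`else` branch after a block that already returns (RET505) and a variable assigned only to be returned on the next line (RET504). A minimal before/after sketch follows; the function is hypothetical and not part of this commit, and the "after" version is roughly what `ruff check --fix` produces for these rules.

```python
# Hypothetical example (not from this commit) showing the two RET patterns.

# Before: flagged by ruff's flake8-return rules
def lookup(name: str) -> str:
    if name == "alpha":
        return "A"
    elif name == "beta":  # RET505: `elif` is unnecessary after a branch that returns
        result = "B"  # RET504: assigned only to be returned on the next line
        return result
    else:  # RET505: `else` is unnecessary after a branch that returns
        raise KeyError(name)


# After: roughly what the RET auto-fixes produce
def lookup(name: str) -> str:
    if name == "alpha":
        return "A"
    if name == "beta":
        return "B"
    raise KeyError(name)
```

That is why nearly every hunk below either dedents an `elif`/`else` into a plain `if`, or drops a temporary variable and returns the expression directly.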
@@ -48,25 +48,25 @@ def __getattr__(name: str) -> Any:
_warn_on_import(name, replacement="langchain.agents.MRKLChain")
return MRKLChain
elif name == "ReActChain":
if name == "ReActChain":
from langchain.agents import ReActChain
_warn_on_import(name, replacement="langchain.agents.ReActChain")
return ReActChain
elif name == "SelfAskWithSearchChain":
if name == "SelfAskWithSearchChain":
from langchain.agents import SelfAskWithSearchChain
_warn_on_import(name, replacement="langchain.agents.SelfAskWithSearchChain")
return SelfAskWithSearchChain
elif name == "ConversationChain":
if name == "ConversationChain":
from langchain.chains import ConversationChain
_warn_on_import(name, replacement="langchain.chains.ConversationChain")
return ConversationChain
elif name == "LLMBashChain":
if name == "LLMBashChain":
msg = (
"This module has been moved to langchain-experimental. "
"For more details: "
@@ -77,97 +77,97 @@ def __getattr__(name: str) -> Any:
)
raise ImportError(msg)
elif name == "LLMChain":
if name == "LLMChain":
from langchain.chains import LLMChain
_warn_on_import(name, replacement="langchain.chains.LLMChain")
return LLMChain
elif name == "LLMCheckerChain":
if name == "LLMCheckerChain":
from langchain.chains import LLMCheckerChain
_warn_on_import(name, replacement="langchain.chains.LLMCheckerChain")
return LLMCheckerChain
elif name == "LLMMathChain":
if name == "LLMMathChain":
from langchain.chains import LLMMathChain
_warn_on_import(name, replacement="langchain.chains.LLMMathChain")
return LLMMathChain
elif name == "QAWithSourcesChain":
if name == "QAWithSourcesChain":
from langchain.chains import QAWithSourcesChain
_warn_on_import(name, replacement="langchain.chains.QAWithSourcesChain")
return QAWithSourcesChain
elif name == "VectorDBQA":
if name == "VectorDBQA":
from langchain.chains import VectorDBQA
_warn_on_import(name, replacement="langchain.chains.VectorDBQA")
return VectorDBQA
elif name == "VectorDBQAWithSourcesChain":
if name == "VectorDBQAWithSourcesChain":
from langchain.chains import VectorDBQAWithSourcesChain
_warn_on_import(name, replacement="langchain.chains.VectorDBQAWithSourcesChain")
return VectorDBQAWithSourcesChain
elif name == "InMemoryDocstore":
if name == "InMemoryDocstore":
from langchain_community.docstore import InMemoryDocstore
_warn_on_import(name, replacement="langchain.docstore.InMemoryDocstore")
return InMemoryDocstore
elif name == "Wikipedia":
if name == "Wikipedia":
from langchain_community.docstore import Wikipedia
_warn_on_import(name, replacement="langchain.docstore.Wikipedia")
return Wikipedia
elif name == "Anthropic":
if name == "Anthropic":
from langchain_community.llms import Anthropic
_warn_on_import(name, replacement="langchain_community.llms.Anthropic")
return Anthropic
elif name == "Banana":
if name == "Banana":
from langchain_community.llms import Banana
_warn_on_import(name, replacement="langchain_community.llms.Banana")
return Banana
elif name == "CerebriumAI":
if name == "CerebriumAI":
from langchain_community.llms import CerebriumAI
_warn_on_import(name, replacement="langchain_community.llms.CerebriumAI")
return CerebriumAI
elif name == "Cohere":
if name == "Cohere":
from langchain_community.llms import Cohere
_warn_on_import(name, replacement="langchain_community.llms.Cohere")
return Cohere
elif name == "ForefrontAI":
if name == "ForefrontAI":
from langchain_community.llms import ForefrontAI
_warn_on_import(name, replacement="langchain_community.llms.ForefrontAI")
return ForefrontAI
elif name == "GooseAI":
if name == "GooseAI":
from langchain_community.llms import GooseAI
_warn_on_import(name, replacement="langchain_community.llms.GooseAI")
return GooseAI
elif name == "HuggingFaceHub":
if name == "HuggingFaceHub":
from langchain_community.llms import HuggingFaceHub
_warn_on_import(name, replacement="langchain_community.llms.HuggingFaceHub")
return HuggingFaceHub
elif name == "HuggingFaceTextGenInference":
if name == "HuggingFaceTextGenInference":
from langchain_community.llms import HuggingFaceTextGenInference
_warn_on_import(
@@ -175,55 +175,55 @@ def __getattr__(name: str) -> Any:
)
return HuggingFaceTextGenInference
elif name == "LlamaCpp":
if name == "LlamaCpp":
from langchain_community.llms import LlamaCpp
_warn_on_import(name, replacement="langchain_community.llms.LlamaCpp")
return LlamaCpp
elif name == "Modal":
if name == "Modal":
from langchain_community.llms import Modal
_warn_on_import(name, replacement="langchain_community.llms.Modal")
return Modal
elif name == "OpenAI":
if name == "OpenAI":
from langchain_community.llms import OpenAI
_warn_on_import(name, replacement="langchain_community.llms.OpenAI")
return OpenAI
elif name == "Petals":
if name == "Petals":
from langchain_community.llms import Petals
_warn_on_import(name, replacement="langchain_community.llms.Petals")
return Petals
elif name == "PipelineAI":
if name == "PipelineAI":
from langchain_community.llms import PipelineAI
_warn_on_import(name, replacement="langchain_community.llms.PipelineAI")
return PipelineAI
elif name == "SagemakerEndpoint":
if name == "SagemakerEndpoint":
from langchain_community.llms import SagemakerEndpoint
_warn_on_import(name, replacement="langchain_community.llms.SagemakerEndpoint")
return SagemakerEndpoint
elif name == "StochasticAI":
if name == "StochasticAI":
from langchain_community.llms import StochasticAI
_warn_on_import(name, replacement="langchain_community.llms.StochasticAI")
return StochasticAI
elif name == "Writer":
if name == "Writer":
from langchain_community.llms import Writer
_warn_on_import(name, replacement="langchain_community.llms.Writer")
return Writer
elif name == "HuggingFacePipeline":
if name == "HuggingFacePipeline":
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
_warn_on_import(
@@ -232,7 +232,7 @@ def __getattr__(name: str) -> Any:
)
return HuggingFacePipeline
elif name == "FewShotPromptTemplate":
if name == "FewShotPromptTemplate":
from langchain_core.prompts import FewShotPromptTemplate
_warn_on_import(
@@ -240,7 +240,7 @@ def __getattr__(name: str) -> Any:
)
return FewShotPromptTemplate
elif name == "Prompt":
if name == "Prompt":
from langchain_core.prompts import PromptTemplate
_warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")
@@ -248,19 +248,19 @@ def __getattr__(name: str) -> Any:
# it's renamed as prompt template anyways
# this is just for backwards compat
return PromptTemplate
elif name == "PromptTemplate":
if name == "PromptTemplate":
from langchain_core.prompts import PromptTemplate
_warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")
return PromptTemplate
elif name == "BasePromptTemplate":
if name == "BasePromptTemplate":
from langchain_core.prompts import BasePromptTemplate
_warn_on_import(name, replacement="langchain_core.prompts.BasePromptTemplate")
return BasePromptTemplate
elif name == "ArxivAPIWrapper":
if name == "ArxivAPIWrapper":
from langchain_community.utilities import ArxivAPIWrapper
_warn_on_import(
@@ -268,7 +268,7 @@ def __getattr__(name: str) -> Any:
)
return ArxivAPIWrapper
elif name == "GoldenQueryAPIWrapper":
if name == "GoldenQueryAPIWrapper":
from langchain_community.utilities import GoldenQueryAPIWrapper
_warn_on_import(
@@ -276,7 +276,7 @@ def __getattr__(name: str) -> Any:
)
return GoldenQueryAPIWrapper
elif name == "GoogleSearchAPIWrapper":
if name == "GoogleSearchAPIWrapper":
from langchain_community.utilities import GoogleSearchAPIWrapper
_warn_on_import(
@@ -284,7 +284,7 @@ def __getattr__(name: str) -> Any:
)
return GoogleSearchAPIWrapper
elif name == "GoogleSerperAPIWrapper":
if name == "GoogleSerperAPIWrapper":
from langchain_community.utilities import GoogleSerperAPIWrapper
_warn_on_import(
@@ -292,7 +292,7 @@ def __getattr__(name: str) -> Any:
)
return GoogleSerperAPIWrapper
elif name == "PowerBIDataset":
if name == "PowerBIDataset":
from langchain_community.utilities import PowerBIDataset
_warn_on_import(
@@ -300,7 +300,7 @@ def __getattr__(name: str) -> Any:
)
return PowerBIDataset
elif name == "SearxSearchWrapper":
if name == "SearxSearchWrapper":
from langchain_community.utilities import SearxSearchWrapper
_warn_on_import(
@@ -308,7 +308,7 @@ def __getattr__(name: str) -> Any:
)
return SearxSearchWrapper
elif name == "WikipediaAPIWrapper":
if name == "WikipediaAPIWrapper":
from langchain_community.utilities import WikipediaAPIWrapper
_warn_on_import(
@@ -316,7 +316,7 @@ def __getattr__(name: str) -> Any:
)
return WikipediaAPIWrapper
elif name == "WolframAlphaAPIWrapper":
if name == "WolframAlphaAPIWrapper":
from langchain_community.utilities import WolframAlphaAPIWrapper
_warn_on_import(
@@ -324,19 +324,19 @@ def __getattr__(name: str) -> Any:
)
return WolframAlphaAPIWrapper
elif name == "SQLDatabase":
if name == "SQLDatabase":
from langchain_community.utilities import SQLDatabase
_warn_on_import(name, replacement="langchain_community.utilities.SQLDatabase")
return SQLDatabase
elif name == "FAISS":
if name == "FAISS":
from langchain_community.vectorstores import FAISS
_warn_on_import(name, replacement="langchain_community.vectorstores.FAISS")
return FAISS
elif name == "ElasticVectorSearch":
if name == "ElasticVectorSearch":
from langchain_community.vectorstores import ElasticVectorSearch
_warn_on_import(
@@ -345,7 +345,7 @@ def __getattr__(name: str) -> Any:
return ElasticVectorSearch
# For backwards compatibility
elif name == "SerpAPIChain" or name == "SerpAPIWrapper":
if name == "SerpAPIChain" or name == "SerpAPIWrapper":
from langchain_community.utilities import SerpAPIWrapper
_warn_on_import(
@@ -353,7 +353,7 @@ def __getattr__(name: str) -> Any:
)
return SerpAPIWrapper
elif name == "verbose":
if name == "verbose":
from langchain.globals import _verbose
_warn_on_import(
@@ -364,7 +364,7 @@ def __getattr__(name: str) -> Any:
)
return _verbose
elif name == "debug":
if name == "debug":
from langchain.globals import _debug
_warn_on_import(
@@ -375,7 +375,7 @@ def __getattr__(name: str) -> Any:
)
return _debug
elif name == "llm_cache":
if name == "llm_cache":
from langchain.globals import _llm_cache
_warn_on_import(
@@ -386,7 +386,6 @@ def __getattr__(name: str) -> Any:
)
return _llm_cache
else:
msg = f"Could not find: {name}"
raise AttributeError(msg)
@@ -137,7 +137,6 @@ class BaseSingleActionAgent(BaseModel):
return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
else:
msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
raise ValueError(msg)
@@ -308,7 +307,6 @@ class BaseMultiActionAgent(BaseModel):
if early_stopping_method == "force":
# `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
else:
msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
raise ValueError(msg)
@@ -815,8 +813,7 @@ class Agent(BaseSingleActionAgent):
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
agent_output = await self.output_parser.aparse(full_output)
return agent_output
return await self.output_parser.aparse(full_output)
def get_full_inputs(
self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any
@@ -833,8 +830,7 @@ class Agent(BaseSingleActionAgent):
"""
thoughts = self._construct_scratchpad(intermediate_steps)
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
full_inputs = {**kwargs, **new_inputs}
return full_inputs
return {**kwargs, **new_inputs}
@property
def input_keys(self) -> list[str]:
@@ -970,7 +966,7 @@ class Agent(BaseSingleActionAgent):
return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
elif early_stopping_method == "generate":
if early_stopping_method == "generate":
# Generate does one final forward pass
thoughts = ""
for action, observation in intermediate_steps:
@@ -990,11 +986,9 @@ class Agent(BaseSingleActionAgent):
if isinstance(parsed_output, AgentFinish):
# If we can extract, we send the correct stuff
return parsed_output
else:
# If we can extract, but the tool is not the final tool,
# we just return the full output
return AgentFinish({"output": full_output}, full_output)
else:
msg = (
"early_stopping_method should be one of `force` or `generate`, "
f"got {early_stopping_method}"
@@ -1179,7 +1173,6 @@ class AgentExecutor(Chain):
"""
if isinstance(self.agent, Runnable):
return cast(RunnableAgentType, self.agent)
else:
return self.agent
def save(self, file_path: Union[Path, str]) -> None:
@@ -1249,7 +1242,6 @@ class AgentExecutor(Chain):
"""
if self.return_intermediate_steps:
return self._action_agent.return_values + ["intermediate_steps"]
else:
return self._action_agent.return_values
def lookup_tool(self, name: str) -> BaseTool:
@@ -1304,10 +1296,7 @@ class AgentExecutor(Chain):
msg = "Expected a single AgentFinish output, but got multiple values."
raise ValueError(msg)
return values[-1]
else:
return [
(a.action, a.observation) for a in values if isinstance(a, AgentStep)
]
return [(a.action, a.observation) for a in values if isinstance(a, AgentStep)]
def _take_next_step(
self,
@@ -1727,9 +1716,8 @@ class AgentExecutor(Chain):
and self.trim_intermediate_steps > 0
):
return intermediate_steps[-self.trim_intermediate_steps :]
elif callable(self.trim_intermediate_steps):
if callable(self.trim_intermediate_steps):
return self.trim_intermediate_steps(intermediate_steps)
else:
return intermediate_steps
@override
@@ -61,7 +61,6 @@ class ChatAgent(Agent):
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
@classmethod
@@ -39,11 +39,9 @@ class ConvoOutputParser(AgentOutputParser):
# If the action indicates a final answer, return an AgentFinish
if action == "Final Answer":
return AgentFinish({"output": action_input}, text)
else:
# Otherwise, return an AgentAction with the specified action and
# input
return AgentAction(action, action_input, text)
else:
# If the necessary keys aren't present in the response, raise an
# exception
msg = f"Missing 'action' or 'action_input' in LLM output: {text}"
@@ -23,7 +23,6 @@ def _convert_agent_action_to_messages(
return list(agent_action.message_log) + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
@@ -182,7 +182,7 @@ def create_json_chat_agent(
else:
llm_to_use = llm
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_messages(
x["intermediate_steps"], template_tool_response=template_tool_response
@@ -192,4 +192,3 @@ def create_json_chat_agent(
| llm_to_use
| JSONAgentOutputParser()
)
return agent
@@ -55,7 +55,6 @@ class MRKLOutputParser(AgentOutputParser):
return AgentFinish(
{"output": text[start_index:end_index].strip()}, text[:end_index]
)
else:
msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
raise OutputParserException(msg)
@@ -69,7 +68,7 @@ class MRKLOutputParser(AgentOutputParser):
return AgentAction(action, tool_input, text)
elif includes_answer:
if includes_answer:
return AgentFinish(
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
)
@@ -82,7 +81,7 @@ class MRKLOutputParser(AgentOutputParser):
llm_output=text,
send_to_llm=True,
)
elif not re.search(
if not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
):
msg = f"Could not parse LLM output: `{text}`"
@@ -92,7 +91,6 @@ class MRKLOutputParser(AgentOutputParser):
llm_output=text,
send_to_llm=True,
)
else:
msg = f"Could not parse LLM output: `{text}`"
raise OutputParserException(msg)
@@ -128,7 +128,6 @@ def _get_assistants_tool(
"""
if _is_assistants_builtin_tool(tool):
return tool # type: ignore[return-value]
else:
return convert_to_openai_tool(tool)
@@ -510,12 +509,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
for action, output in intermediate_steps
if action.tool_call_id in required_tool_call_ids
]
submit_tool_outputs = {
return {
"tool_outputs": tool_outputs,
"run_id": last_action.run_id,
"thread_id": last_action.thread_id,
}
return submit_tool_outputs
def _create_run(self, input_dict: dict) -> Any:
params = {
@@ -558,12 +556,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
"run_metadata",
)
}
run = self.client.beta.threads.create_and_run(
return self.client.beta.threads.create_and_run(
assistant_id=self.assistant_id,
thread=thread,
**params,
)
return run
def _get_response(self, run: Any) -> Any:
# TODO: Pagination
@@ -612,7 +609,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
run_id=run.id,
thread_id=run.thread_id,
)
elif run.status == "requires_action":
if run.status == "requires_action":
if not self.as_agent:
return run.required_action.submit_tool_outputs.tool_calls
actions = []
@@ -639,7 +636,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
)
return actions
else:
run_info = json.dumps(run.dict(), indent=2)
msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
raise ValueError(msg)
@@ -668,12 +664,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
for action, output in intermediate_steps
if action.tool_call_id in required_tool_call_ids
]
submit_tool_outputs = {
return {
"tool_outputs": tool_outputs,
"run_id": last_action.run_id,
"thread_id": last_action.thread_id,
}
return submit_tool_outputs
async def _acreate_run(self, input_dict: dict) -> Any:
params = {
@@ -716,12 +711,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
"run_metadata",
)
}
run = await self.async_client.beta.threads.create_and_run(
return await self.async_client.beta.threads.create_and_run(
assistant_id=self.assistant_id,
thread=thread,
**params,
)
return run
async def _aget_response(self, run: Any) -> Any:
# TODO: Pagination
@@ -766,7 +760,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
run_id=run.id,
thread_id=run.thread_id,
)
elif run.status == "requires_action":
if run.status == "requires_action":
if not self.as_agent:
return run.required_action.submit_tool_outputs.tool_calls
actions = []
@@ -793,7 +787,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
)
)
return actions
else:
run_info = json.dumps(run.dict(), indent=2)
msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
raise ValueError(msg)
@@ -132,8 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
messages,
callbacks=callbacks,
)
agent_decision = self.output_parser._parse_ai_message(predicted_message)
return agent_decision
return self.output_parser._parse_ai_message(predicted_message)
async def aplan(
self,
@@ -164,8 +163,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = self.output_parser._parse_ai_message(predicted_message)
return agent_decision
return self.output_parser._parse_ai_message(predicted_message)
def return_stopped_response(
self,
@@ -192,17 +190,15 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
elif early_stopping_method == "generate":
if early_stopping_method == "generate":
# Generate does one final forward pass
agent_decision = self.plan(
intermediate_steps, with_functions=False, **kwargs
)
if isinstance(agent_decision, AgentFinish):
return agent_decision
else:
msg = f"got AgentAction with no functions provided: {agent_decision}"
raise ValueError(msg)
else:
msg = (
"early_stopping_method should be one of `force` or `generate`, "
f"got {early_stopping_method}"
@@ -358,7 +354,7 @@ def create_openai_functions_agent(
)
raise ValueError(msg)
llm_with_tools = llm.bind(functions=[convert_to_openai_function(t) for t in tools])
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_openai_function_messages(
x["intermediate_steps"]
@@ -368,4 +364,3 @@ def create_openai_functions_agent(
| llm_with_tools
| OpenAIFunctionsAgentOutputParser()
)
return agent
@@ -224,8 +224,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
return _parse_ai_message(predicted_message)
async def aplan(
self,
@@ -254,8 +253,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
return _parse_ai_message(predicted_message)
@classmethod
def create_prompt(
@@ -96,7 +96,7 @@ def create_openai_tools_agent(
tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools]
)
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_openai_tool_messages(
x["intermediate_steps"]
@@ -106,4 +106,3 @@ def create_openai_tools_agent(
| llm_with_tools
| OpenAIToolsAgentOutputParser()
)
return agent
@@ -49,7 +49,6 @@ class JSONAgentOutputParser(AgentOutputParser):
response = response[0]
if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text)
else:
action_input = response.get("action_input", {})
if action_input is None:
action_input = {}
@@ -65,7 +65,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):
return AgentAction(action, tool_input, text)
elif includes_answer:
if includes_answer:
return AgentFinish(
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
)
@@ -78,7 +78,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):
llm_output=text,
send_to_llm=True,
)
elif not re.search(
if not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
):
msg = f"Could not parse LLM output: `{text}`"
@@ -88,7 +88,6 @@ class ReActSingleInputOutputParser(AgentOutputParser):
llm_output=text,
send_to_llm=True,
)
else:
msg = f"Could not parse LLM output: `{text}`"
raise OutputParserException(msg)
@@ -36,12 +36,11 @@ class XMLAgentOutputParser(AgentOutputParser):
if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0]
return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
elif "<final_answer>" in text:
if "<final_answer>" in text:
_, answer = text.split("<final_answer>")
if "</final_answer>" in answer:
answer = answer.split("</final_answer>")[0]
return AgentFinish(return_values={"output": answer}, log=text)
else:
raise ValueError
def get_format_instructions(self) -> str:
@@ -134,7 +134,7 @@ def create_react_agent(
else:
llm_with_stop = llm
output_parser = output_parser or ReActSingleInputOutputParser()
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
)
@@ -142,4 +142,3 @@ def create_react_agent(
| llm_with_stop
| output_parser
)
return agent
@@ -96,7 +96,6 @@ class DocstoreExplorer:
if isinstance(result, Document):
self.document = result
return self._summary
else:
self.document = None
return result
@@ -113,9 +112,8 @@ class DocstoreExplorer:
lookups = [p for p in self._paragraphs if self.lookup_str in p.lower()]
if len(lookups) == 0:
return "No Results"
elif self.lookup_index >= len(lookups):
if self.lookup_index >= len(lookups):
return "No More Results"
else:
result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
return f"{result_prefix} {lookups[self.lookup_index]}"
@@ -26,7 +26,6 @@ class ReActOutputParser(AgentOutputParser):
action, action_input = re_matches.group(1), re_matches.group(2)
if action == "Finish":
return AgentFinish({"output": action_input}, text)
else:
return AgentAction(action, action_input, text)
@property
@@ -195,7 +195,7 @@ def create_self_ask_with_search_agent(
raise ValueError(msg)
llm_with_stop = llm.bind(stop=["\nIntermediate answer:"])
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(
x["intermediate_steps"],
@@ -209,4 +209,3 @@ def create_self_ask_with_search_agent(
| llm_with_stop
| SelfAskOutputParser()
)
return agent
@@ -62,7 +62,6 @@ class StructuredChatAgent(Agent):
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
@classmethod
@@ -292,7 +291,7 @@ def create_structured_chat_agent(
else:
llm_with_stop = llm
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
)
@@ -300,4 +299,3 @@ def create_structured_chat_agent(
| llm_with_stop
| JSONAgentOutputParser()
)
return agent
@@ -42,11 +42,9 @@ class StructuredChatOutputParser(AgentOutputParser):
response = response[0]
if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text)
else:
return AgentAction(
response["action"], response.get("action_input", {}), text
)
else:
return AgentFinish({"output": text}, text)
except Exception as e:
msg = f"Could not parse LLM output: {text}"
@@ -93,9 +91,8 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
llm=llm, parser=base_parser
)
return cls(output_fixing_parser=output_fixing_parser)
elif base_parser is not None:
if base_parser is not None:
return cls(base_parser=base_parser)
else:
return cls()
@property
@@ -100,7 +100,7 @@ def create_tool_calling_agent(
)
llm_with_tools = llm.bind_tools(tools)
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"])
)
@@ -108,4 +108,3 @@ def create_tool_calling_agent(
| llm_with_tools
| ToolsAgentOutputParser()
)
return agent
@@ -221,7 +221,7 @@ def create_xml_agent(
else:
llm_with_stop = llm
agent = (
return (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_xml(x["intermediate_steps"]),
)
@@ -229,4 +229,3 @@ def create_xml_agent(
| llm_with_stop
| XMLAgentOutputParser()
)
return agent
@@ -24,7 +24,6 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def check_if_answer_reached(self) -> bool:
if self.strip_tokens:
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
else:
return self.last_tokens == self.answer_prefix_tokens
def __init__(
@@ -25,7 +25,6 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def check_if_answer_reached(self) -> bool:
if self.strip_tokens:
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
else:
return self.last_tokens == self.answer_prefix_tokens
def __init__(
@@ -261,7 +261,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@@ -474,7 +473,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
async def aprep_outputs(
@@ -500,7 +498,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
await self.memory.asave_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
@@ -628,7 +625,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
" but none were provided."
)
raise ValueError(msg)
else:
msg = (
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
@@ -687,7 +683,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
f"one output key. Got {self.output_keys}."
)
raise ValueError(msg)
elif args and not kwargs:
if args and not kwargs:
if len(args) != 1:
msg = "`run` supports only one positional argument."
raise ValueError(msg)
@@ -208,9 +208,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
if self.reduce_documents_chain.collapse_documents_chain:
return self.reduce_documents_chain.collapse_documents_chain
else:
return self.reduce_documents_chain.combine_documents_chain
else:
msg = (
f"`reduce_documents_chain` is of type "
f"{type(self.reduce_documents_chain)} so it does not have "
@@ -223,7 +221,6 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
"""Kept for backward compatibility."""
if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
return self.reduce_documents_chain.combine_documents_chain
else:
msg = (
f"`reduce_documents_chain` is of type "
f"{type(self.reduce_documents_chain)} so it does not have "
@@ -225,7 +225,6 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
def _collapse_chain(self) -> BaseCombineDocumentsChain:
if self.collapse_documents_chain is not None:
return self.collapse_documents_chain
else:
return self.combine_documents_chain
def combine_docs(
@@ -222,8 +222,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
base_inputs: dict = {
self.document_variable_name: self.document_prompt.format(**document_info)
}
inputs = {**base_inputs, **kwargs}
return inputs
return {**base_inputs, **kwargs}
@property
def _chain_type(self) -> str:
@@ -201,7 +201,6 @@ class ConstitutionalChain(Chain):
) -> list[ConstitutionalPrinciple]:
if names is None:
return list(PRINCIPLES.values())
else:
return [PRINCIPLES[name] for name in names]
@classmethod
@@ -80,7 +80,6 @@ class ElasticsearchDatabaseChain(Chain):
"""
if not self.return_intermediate_steps:
return [self.output_key]
else:
return [self.output_key, INTERMEDIATE_STEPS_KEY]
def _list_indices(self) -> list[str]:
@@ -47,7 +47,6 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
"""Output keys for Hyde's LLM chain."""
if isinstance(self.llm_chain, LLMChain):
return self.llm_chain.output_keys
else:
return ["text"]
def embed_documents(self, texts: list[str]) -> list[list[float]]:
@@ -116,7 +116,6 @@ class LLMChain(Chain):
"""
if self.return_final_only:
return [self.output_key]
else:
return [self.output_key, "full_generation"]
def _call(
@@ -142,7 +141,6 @@ class LLMChain(Chain):
callbacks=callbacks,
**self.llm_kwargs,
)
else:
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
cast(list, prompts), {"callbacks": callbacks}
)
@@ -169,7 +167,6 @@ class LLMChain(Chain):
callbacks=callbacks,
**self.llm_kwargs,
)
else:
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
cast(list, prompts), {"callbacks": callbacks}
)
@@ -344,7 +341,6 @@ class LLMChain(Chain):
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
async def apredict_and_parse(
@@ -358,7 +354,6 @@ class LLMChain(Chain):
result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
@@ -380,7 +375,6 @@ class LLMChain(Chain):
self.prompt.output_parser.parse(res[self.output_key])
for res in generation
]
else:
return generation
async def aapply_and_parse(
@@ -411,13 +405,12 @@ class LLMChain(Chain):
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
if isinstance(llm_like, BaseLanguageModel):
return llm_like
elif isinstance(llm_like, RunnableBinding):
if isinstance(llm_like, RunnableBinding):
return _get_language_model(llm_like.bound)
elif isinstance(llm_like, RunnableWithFallbacks):
if isinstance(llm_like, RunnableWithFallbacks):
return _get_language_model(llm_like.runnable)
elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
if isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
return _get_language_model(llm_like.default)
else:
msg = (
f"Unable to extract BaseLanguageModel from llm_like object of type "
f"{type(llm_like)}"
@@ -55,13 +55,12 @@ def _load_question_to_checked_assertions_chain(
check_assertions_chain,
revised_answer_chain,
]
question_to_checked_assertions_chain = SequentialChain(
return SequentialChain(
chains=chains, # type: ignore[arg-type]
input_variables=["question"],
output_variables=["revised_statement"],
verbose=True,
)
return question_to_checked_assertions_chain
@deprecated(
@@ -32,7 +32,7 @@ def _load_sequential_chain(
are_all_true_prompt: PromptTemplate,
verbose: bool = False,
) -> SequentialChain:
chain = SequentialChain(
return SequentialChain(
chains=[
LLMChain(
llm=llm,
@@ -63,7 +63,6 @@ def _load_sequential_chain(
output_variables=["all_true", "revised_summary"],
verbose=verbose,
)
return chain
@deprecated(
@@ -311,7 +311,6 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
prompt = load_prompt(config.pop("prompt_path"))
if llm_chain:
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
else:
return LLMMathChain(llm=llm, prompt=prompt, **config)
@@ -609,7 +608,6 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
return LLMRequestsChain(
llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config
)
else:
return LLMRequestsChain(llm_chain=llm_chain, **config)
@@ -100,7 +100,6 @@ class OpenAIModerationChain(Chain):
error_str = "Text was found that violates OpenAI's content policy."
if self.error:
raise ValueError(error_str)
else:
return error_str
return text
@@ -132,7 +132,7 @@ def create_openai_fn_chain(
}
if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
llm_chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
output_parser=output_parser,
@@ -140,7 +140,6 @@ def create_openai_fn_chain(
output_key=output_key,
**kwargs,
)
return llm_chain
@deprecated(
@@ -149,10 +149,9 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
]
prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
)
return chain
@@ -103,7 +103,7 @@ def create_extraction_chain(
extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=extraction_prompt,
llm_kwargs=llm_kwargs,
@@ -111,7 +111,6 @@ def create_extraction_chain(
tags=tags,
verbose=verbose,
)
return chain
@deprecated(
@@ -187,11 +186,10 @@ def create_extraction_chain_pydantic(
pydantic_schema=PydanticSchema, attr_name="info"
)
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=extraction_prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
verbose=verbose,
)
return chain
@@ -100,14 +100,13 @@ def create_qa_with_structure_chain(
]
prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=_output_parser,
verbose=verbose,
)
return chain
@deprecated(
@@ -91,14 +91,13 @@ def create_tagging_chain(
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = JsonOutputFunctionsParser()
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
**kwargs,
)
return chain
@deprecated(
@@ -164,11 +163,10 @@ def create_tagging_chain_pydantic(
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
return LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
**kwargs,
)
return chain
@@ -76,5 +76,4 @@ def create_extraction_chain_pydantic(
functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
tools = [{"type": "function", "function": d} for d in functions]
model = llm.bind(tools=tools)
chain = prompt | model | PydanticToolsParser(tools=pydantic_schemas)
return chain
return prompt | model | PydanticToolsParser(tools=pydantic_schemas)
@@ -91,13 +91,12 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
filter_directive = cast(
Optional[FilterDirective], get_parser().parse(raw_filter)
)
fixed = fix_filter_directive(
return fix_filter_directive(
filter_directive,
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
allowed_attributes=allowed_attributes,
)
return fixed
else:
ast_parse = get_parser(
@@ -131,13 +130,13 @@ def fix_filter_directive(
) or not filter:
return filter
elif isinstance(filter, Comparison):
if isinstance(filter, Comparison):
if allowed_comparators and filter.comparator not in allowed_comparators:
return None
if allowed_attributes and filter.attribute not in allowed_attributes:
return None
return filter
elif isinstance(filter, Operation):
if isinstance(filter, Operation):
if allowed_operators and filter.operator not in allowed_operators:
return None
args = [
@@ -155,14 +154,12 @@ def fix_filter_directive(
]
if not args:
return None
elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
if len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
return args[0]
else:
return Operation(
operator=filter.operator,
arguments=args,
)
else:
return filter
@@ -101,9 +101,8 @@ class QueryTransformer(Transformer):
)
raise ValueError(msg)
return Comparison(comparator=func, attribute=args[0], value=args[1])
elif len(args) == 1 and func in (Operator.AND, Operator.OR):
if len(args) == 1 and func in (Operator.AND, Operator.OR):
return args[0]
else:
return Operation(operator=func, arguments=args)
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
@@ -118,7 +117,7 @@ class QueryTransformer(Transformer):
)
raise ValueError(msg)
return Comparator(func_name)
elif func_name in set(Operator):
if func_name in set(Operator):
if (
self.allowed_operators is not None
and func_name not in self.allowed_operators
@@ -129,7 +128,6 @@ class QueryTransformer(Transformer):
)
raise ValueError(msg)
return Operator(func_name)
else:
msg = (
f"Received unrecognized function {func_name}. Valid functions are "
f"{list(Operator) + list(Comparator)}"
@@ -60,10 +60,8 @@ def create_retrieval_chain(
else:
retrieval_docs = (lambda x: x["input"]) | retriever
retrieval_chain = (
return (
RunnablePassthrough.assign(
context=retrieval_docs.with_config(run_name="retrieve_documents"),
).assign(answer=combine_docs_chain)
).with_config(run_name="retrieval_chain")
return retrieval_chain
@@ -157,7 +157,6 @@ class BaseRetrievalQA(Chain):
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
@abstractmethod
@@ -200,7 +199,6 @@ class BaseRetrievalQA(Chain):
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
@@ -97,13 +97,12 @@ class MultiRouteChain(Chain):
)
if not route.destination:
return self.default_chain(route.next_inputs, callbacks=callbacks)
elif route.destination in self.destination_chains:
if route.destination in self.destination_chains:
return self.destination_chains[route.destination](
route.next_inputs, callbacks=callbacks
)
elif self.silent_errors:
if self.silent_errors:
return self.default_chain(route.next_inputs, callbacks=callbacks)
else:
msg = f"Received invalid destination chain name '{route.destination}'"
raise ValueError(msg)
@@ -123,14 +122,13 @@ class MultiRouteChain(Chain):
return await self.default_chain.acall(
route.next_inputs, callbacks=callbacks
)
elif route.destination in self.destination_chains:
if route.destination in self.destination_chains:
return await self.destination_chains[route.destination].acall(
route.next_inputs, callbacks=callbacks
)
elif self.silent_errors:
if self.silent_errors:
return await self.default_chain.acall(
route.next_inputs, callbacks=callbacks
)
else:
msg = f"Received invalid destination chain name '{route.destination}'"
raise ValueError(msg)
@@ -136,11 +136,10 @@ class LLMRouterChain(RouterChain):
callbacks = _run_manager.get_child()
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
output = cast(
return cast(
dict[str, Any],
self.llm_chain.prompt.output_parser.parse(prediction),
)
return output
async def _acall(
self,
@@ -149,11 +148,10 @@ class LLMRouterChain(RouterChain):
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
output = cast(
return cast(
dict[str, Any],
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
)
return output
@classmethod
def from_llm(
@@ -141,7 +141,6 @@ def create_sql_query_chain(
f"{db.dialect}"
)
raise ValueError(msg)
else:
table_info_kwargs["get_col_comments"] = True
inputs = {
@@ -143,7 +143,6 @@ def create_openai_fn_runnable(
output_parser = output_parser or get_openai_output_parser(functions)
if prompt:
return prompt | llm.bind(**llm_kwargs_) | output_parser
else:
return llm.bind(**llm_kwargs_) | output_parser
@@ -413,7 +412,7 @@ def create_structured_output_runnable(
first_tool_only=return_single,
)
elif mode == "openai-functions":
if mode == "openai-functions":
return _create_openai_functions_structured_output_runnable(
output_schema,
llm,
@@ -422,7 +421,7 @@ def create_structured_output_runnable(
enforce_single_function_usage=force_function_usage,
**kwargs, # llm-specific kwargs
)
elif mode == "openai-json":
if mode == "openai-json":
if force_function_usage:
msg = (
"enforce_single_function_usage is not supported for mode='openai-json'."
@@ -431,7 +430,6 @@ def create_structured_output_runnable(
return _create_openai_json_runnable(
output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs
)
else:
msg = (
f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', "
f"'openai-json'."
@@ -460,7 +458,6 @@ def _create_openai_tools_runnable(
)
if prompt:
return prompt | llm.bind(**llm_kwargs) | output_parser
else:
return llm.bind(**llm_kwargs) | output_parser
@@ -535,7 +532,6 @@ def _create_openai_json_runnable(
prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2))
return prompt | llm | output_parser
else:
return llm | output_parser
@@ -77,7 +77,6 @@ class TransformChain(Chain):
) -> dict[str, Any]:
if self.atransform_cb is not None:
return await self.atransform_cb(inputs)
else:
self._log_once(
"TransformChain's atransform is not provided, falling"
" back to synchronous transform"
@@ -322,7 +322,6 @@ def init_chat_model(
return _init_chat_model_helper(
cast(str, model), model_provider=model_provider, **kwargs
)
else:
if model:
kwargs["model"] = model
if model_provider:
@@ -343,42 +342,42 @@ def _init_chat_model_helper(
from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model, **kwargs)
elif model_provider == "anthropic":
if model_provider == "anthropic":
_check_pkg("langchain_anthropic")
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
elif model_provider == "azure_openai":
if model_provider == "azure_openai":
_check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(model=model, **kwargs)
elif model_provider == "azure_ai":
if model_provider == "azure_ai":
_check_pkg("langchain_azure_ai")
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
return AzureAIChatCompletionsModel(model=model, **kwargs)
elif model_provider == "cohere":
if model_provider == "cohere":
_check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere
return ChatCohere(model=model, **kwargs)
elif model_provider == "google_vertexai":
if model_provider == "google_vertexai":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI
return ChatVertexAI(model=model, **kwargs)
elif model_provider == "google_genai":
if model_provider == "google_genai":
_check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(model=model, **kwargs)
elif model_provider == "fireworks":
if model_provider == "fireworks":
_check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks
return ChatFireworks(model=model, **kwargs)
elif model_provider == "ollama":
if model_provider == "ollama":
try:
_check_pkg("langchain_ollama")
from langchain_ollama import ChatOllama
@@ -393,72 +392,70 @@ def _init_chat_model_helper(
_check_pkg("langchain_ollama")
return ChatOllama(model=model, **kwargs)
elif model_provider == "together":
if model_provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
return ChatTogether(model=model, **kwargs)
elif model_provider == "mistralai":
if model_provider == "mistralai":
_check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
elif model_provider == "huggingface":
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
return ChatHuggingFace(model_id=model, **kwargs)
elif model_provider == "groq":
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
return ChatGroq(model=model, **kwargs)
elif model_provider == "bedrock":
if model_provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
# TODO: update to use model= once ChatBedrock supports
return ChatBedrock(model_id=model, **kwargs)
elif model_provider == "bedrock_converse":
if model_provider == "bedrock_converse":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrockConverse
return ChatBedrockConverse(model=model, **kwargs)
elif model_provider == "google_anthropic_vertex":
if model_provider == "google_anthropic_vertex":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
return ChatAnthropicVertex(model=model, **kwargs)
elif model_provider == "deepseek":
if model_provider == "deepseek":
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(model=model, **kwargs)
elif model_provider == "nvidia":
if model_provider == "nvidia":
_check_pkg("langchain_nvidia_ai_endpoints")
from langchain_nvidia_ai_endpoints import ChatNVIDIA
return ChatNVIDIA(model=model, **kwargs)
elif model_provider == "ibm":
if model_provider == "ibm":
_check_pkg("langchain_ibm")
from langchain_ibm import ChatWatsonx
return ChatWatsonx(model_id=model, **kwargs)
elif model_provider == "xai":
if model_provider == "xai":
_check_pkg("langchain_xai")
from langchain_xai import ChatXAI
return ChatXAI(model=model, **kwargs)
elif model_provider == "perplexity":
if model_provider == "perplexity":
_check_pkg("langchain_perplexity")
from langchain_perplexity import ChatPerplexity
return ChatPerplexity(model=model, **kwargs)
else:
supported = ", ".join(_SUPPORTED_PROVIDERS)
msg = (
f"Unsupported {model_provider=}.\n\nSupported model providers are: "
f"{supported}"
f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
)
raise ValueError(msg)
@@ -490,25 +487,24 @@ _SUPPORTED_PROVIDERS = {
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
return "openai"
elif model_name.startswith("claude"):
if model_name.startswith("claude"):
return "anthropic"
elif model_name.startswith("command"):
if model_name.startswith("command"):
return "cohere"
elif model_name.startswith("accounts/fireworks"):
if model_name.startswith("accounts/fireworks"):
return "fireworks"
elif model_name.startswith("gemini"):
if model_name.startswith("gemini"):
return "google_vertexai"
elif model_name.startswith("amazon."):
if model_name.startswith("amazon."):
return "bedrock"
elif model_name.startswith("mistral"):
if model_name.startswith("mistral"):
return "mistralai"
elif model_name.startswith("deepseek"):
if model_name.startswith("deepseek"):
return "deepseek"
elif model_name.startswith("grok"):
if model_name.startswith("grok"):
return "xai"
elif model_name.startswith("sonar"):
if model_name.startswith("sonar"):
return "perplexity"
else:
return None
@@ -595,9 +591,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
)
|
||||
|
||||
return queue
|
||||
elif self._default_config and (model := self._model()) and hasattr(model, name):
|
||||
if self._default_config and (model := self._model()) and hasattr(model, name):
|
||||
return getattr(model, name)
|
||||
else:
|
||||
msg = f"{name} is not a BaseChatModel attribute"
|
||||
if self._default_config:
|
||||
msg += " and is not implemented on the default model"
|
||||
@ -728,7 +723,6 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
)
|
||||
# If multiple configs default to Runnable.batch which uses executor to invoke
|
||||
# in parallel.
|
||||
else:
|
||||
return super().batch(
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
@ -751,7 +745,6 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
)
|
||||
# If multiple configs default to Runnable.batch which uses executor to invoke
|
||||
# in parallel.
|
||||
else:
|
||||
return await super().abatch(
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
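For orientation, a minimal usage sketch of the init_chat_model entry point whose helpers are edited above (the model names, the temperature kwarg, and the import path are illustrative assumptions, not taken from this diff):

from langchain.chat_models import init_chat_model

# Provider is inferred from the model-name prefix via _attempt_infer_model_provider.
llm = init_chat_model("gpt-4o-mini", temperature=0)

# Or pin the provider explicitly; remaining kwargs are forwarded to the provider class.
claude = init_chat_model("claude-3-5-sonnet-latest", model_provider="anthropic")
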
@ -192,35 +192,34 @@ def init_embeddings(
from langchain_openai import OpenAIEmbeddings

return OpenAIEmbeddings(model=model_name, **kwargs)
elif provider == "azure_openai":
if provider == "azure_openai":
from langchain_openai import AzureOpenAIEmbeddings

return AzureOpenAIEmbeddings(model=model_name, **kwargs)
elif provider == "google_vertexai":
if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings

return VertexAIEmbeddings(model=model_name, **kwargs)
elif provider == "bedrock":
if provider == "bedrock":
from langchain_aws import BedrockEmbeddings

return BedrockEmbeddings(model_id=model_name, **kwargs)
elif provider == "cohere":
if provider == "cohere":
from langchain_cohere import CohereEmbeddings

return CohereEmbeddings(model=model_name, **kwargs)
elif provider == "mistralai":
if provider == "mistralai":
from langchain_mistralai import MistralAIEmbeddings

return MistralAIEmbeddings(model=model_name, **kwargs)
elif provider == "huggingface":
if provider == "huggingface":
from langchain_huggingface import HuggingFaceEmbeddings

return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
elif provider == "ollama":
if provider == "ollama":
from langchain_ollama import OllamaEmbeddings

return OllamaEmbeddings(model=model_name, **kwargs)
else:
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"

@ -70,7 +70,7 @@ def resolve_pairwise_criteria(
Criteria.DEPTH,
]
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
elif isinstance(criteria, Criteria):
if isinstance(criteria, Criteria):
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, str):
if criteria in _SUPPORTED_CRITERIA:

@ -186,7 +186,6 @@ class _EmbeddingDistanceChainMixin(Chain):
}
if metric in metrics:
return metrics[metric]
else:
msg = f"Invalid metric: {metric}"
raise ValueError(msg)

@ -162,7 +162,6 @@ def load_evaluator(
)
raise ValueError(msg) from e
return evaluator_cls.from_llm(llm=llm, **kwargs)
else:
return evaluator_cls(**kwargs)

@ -70,7 +70,7 @@ class JsonSchemaEvaluator(StringEvaluator):
def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]:
if isinstance(node, str):
return parse_json_markdown(node)
elif hasattr(node, "schema") and callable(getattr(node, "schema")):
if hasattr(node, "schema") and callable(getattr(node, "schema")):
# Pydantic model
return getattr(node, "schema")()
return node

@ -24,7 +24,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
if match:
if match.group(1).upper() == "CORRECT":
return "CORRECT", 1
elif match.group(1).upper() == "INCORRECT":
if match.group(1).upper() == "INCORRECT":
return "INCORRECT", 0
try:
first_word = (
@ -32,7 +32,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
)
if first_word.upper() == "CORRECT":
return "CORRECT", 1
elif first_word.upper() == "INCORRECT":
if first_word.upper() == "INCORRECT":
return "INCORRECT", 0
last_word = (
text.strip()
@ -41,7 +41,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
)
if last_word.upper() == "CORRECT":
return "CORRECT", 1
elif last_word.upper() == "INCORRECT":
if last_word.upper() == "INCORRECT":
return "INCORRECT", 0
except IndexError:
pass

@ -123,12 +123,12 @@ class _EvalArgsMixin:
if self.requires_input and input is None:
msg = f"{self.__class__.__name__} requires an input string."
raise ValueError(msg)
elif input is not None and not self.requires_input:
if input is not None and not self.requires_input:
warn(self._skip_input_warning)
if self.requires_reference and reference is None:
msg = f"{self.__class__.__name__} requires a reference string."
raise ValueError(msg)
elif reference is not None and not self.requires_reference:
if reference is not None and not self.requires_reference:
warn(self._skip_reference_warning)

@ -70,7 +70,7 @@ def resolve_criteria(
Criteria.DEPTH,
]
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
elif isinstance(criteria, Criteria):
if isinstance(criteria, Criteria):
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, str):
if criteria in _SUPPORTED_CRITERIA:

@ -138,7 +138,6 @@ class _RapidFuzzChainMixin(Chain):
module = module_map[distance]
if normalize_score:
return module.normalized_distance
else:
return module.distance

@property

@ -21,7 +21,6 @@ def _get_client(
ls_client = LangSmithClient(api_url, api_key=api_key)
if hasattr(ls_client, "push_prompt") and hasattr(ls_client, "pull_prompt"):
return ls_client
else:
from langchainhub import Client as LangChainHubClient

return LangChainHubClient(api_url, api_key=api_key)
@ -82,14 +81,13 @@ def push(

# Then it's langchainhub
manifest_json = dumps(object)
message = client.push(
return client.push(
repo_full_name,
manifest_json,
parent_commit_hash=parent_commit_hash,
new_repo_is_public=new_repo_is_public,
new_repo_description=new_repo_description,
)
return message

def pull(
@ -113,8 +111,7 @@ def pull(

# Then it's langsmith
if hasattr(client, "pull_prompt"):
response = client.pull_prompt(owner_repo_commit, include_model=include_model)
return response
return client.pull_prompt(owner_repo_commit, include_model=include_model)

# Then it's langchainhub
if hasattr(client, "pull_repo"):

@ -561,7 +561,6 @@ def __getattr__(name: str) -> Any:
k: v() for k, v in get_type_to_cls_dict().items()
}
return type_to_cls_dict
else:
return getattr(llms, name)

@ -154,6 +154,7 @@ class UpstashRedisEntityStore(BaseEntityStore):
logger.debug(
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
return None

def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
@ -255,6 +256,7 @@ class RedisEntityStore(BaseEntityStore):
logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
return None

def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
@ -362,6 +364,7 @@ class SQLiteEntityStore(BaseEntityStore):
f'"{self.full_table_name}" (key, value) VALUES (?, ?)'
)
self._execute_query(query, (key, value))
return None

def delete(self, key: str) -> None:
"""Deletes a key-value pair, safely quoting the table name."""

@ -34,7 +34,7 @@ class BooleanOutputParser(BaseOutputParser[bool]):
)
raise ValueError(msg)
return True
elif self.false_val.upper() in truthy:
if self.false_val.upper() in truthy:
if self.true_val.upper() in truthy:
msg = (
f"Ambiguous response. Both {self.true_val} and {self.false_val} "

@ -71,7 +71,6 @@ class OutputFixingParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
@ -109,7 +108,6 @@ class OutputFixingParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(

@ -64,7 +64,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
msg = f"Invalid array format in '{original_request_params}'. \
Please check the format instructions."
raise OutputParserException(msg)
elif (
if (
isinstance(parsed_array[0], int)
and parsed_array[-1] > self.dataframe.index.max()
):

@ -30,11 +30,9 @@ class RegexParser(BaseOutputParser[dict[str, str]]):
match = re.search(self.regex, text)
if match:
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
else:
if self.default_output_key is None:
msg = f"Could not parse output: {text}"
raise ValueError(msg)
else:
return {
key: text if key == self.default_output_key else ""
for key in self.output_keys

@ -33,14 +33,11 @@ class RegexDictParser(BaseOutputParser[dict[str, str]]):
{expected_format} on text {text}"
)
raise ValueError(msg)
elif len(matches) > 1:
if len(matches) > 1:
msg = f"Multiple matches found for output key: {output_key} with \
expected format {expected_format} on text {text}"
raise ValueError(msg)
elif (
self.no_update_value is not None and matches[0] == self.no_update_value
):
if self.no_update_value is not None and matches[0] == self.no_update_value:
continue
else:
result[output_key] = matches[0]
return result

@ -107,7 +107,6 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
@ -143,7 +142,6 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
@ -234,7 +232,6 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
@ -263,7 +260,6 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(

@ -89,7 +89,6 @@ class StructuredOutputParser(BaseOutputParser[dict[str, Any]]):
)
if only_json:
return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str)
else:
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)

def parse(self, text: str) -> dict[str, Any]:

@ -33,7 +33,6 @@ class YamlOutputParser(BaseOutputParser[T]):
json_object = yaml.safe_load(yaml_str)
if hasattr(self.pydantic_object, "model_validate"):
return self.pydantic_object.model_validate(json_object)
else:
return self.pydantic_object.parse_obj(json_object)

except (yaml.YAMLError, ValidationError) as e:

@ -45,7 +45,6 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
else:
return []

async def _aget_relevant_documents(
@ -71,5 +70,4 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
else:
return []

@ -174,9 +174,7 @@ class EnsembleRetriever(BaseRetriever):
"""

# Get fused result of the retrievers.
fused_documents = self.rank_fusion(query, run_manager)

return fused_documents
return self.rank_fusion(query, run_manager)

async def _aget_relevant_documents(
self,
@ -195,9 +193,7 @@ class EnsembleRetriever(BaseRetriever):
"""

# Get fused result of the retrievers.
fused_documents = await self.arank_fusion(query, run_manager)

return fused_documents
return await self.arank_fusion(query, run_manager)

def rank_fusion(
self,
@ -236,9 +232,7 @@ class EnsembleRetriever(BaseRetriever):
]

# apply rank fusion
fused_documents = self.weighted_reciprocal_rank(retriever_docs)

return fused_documents
return self.weighted_reciprocal_rank(retriever_docs)

async def arank_fusion(
self,
@ -280,9 +274,7 @@ class EnsembleRetriever(BaseRetriever):
]

# apply rank fusion
fused_documents = self.weighted_reciprocal_rank(retriever_docs)

return fused_documents
return self.weighted_reciprocal_rank(retriever_docs)

def weighted_reciprocal_rank(
self, doc_lists: list[list[Document]]
@ -318,7 +310,7 @@ class EnsembleRetriever(BaseRetriever):

# Docs are deduplicated by their contents then sorted by their scores
all_docs = chain.from_iterable(doc_lists)
sorted_docs = sorted(
return sorted(
unique_by_key(
all_docs,
lambda doc: (
@ -332,4 +324,3 @@ class EnsembleRetriever(BaseRetriever):
doc.page_content if self.id_key is None else doc.metadata[self.id_key]
],
)
return sorted_docs

@ -31,9 +31,7 @@ class MergerRetriever(BaseRetriever):
"""

# Merge the results of the retrievers.
merged_documents = self.merge_documents(query, run_manager)

return merged_documents
return self.merge_documents(query, run_manager)

async def _aget_relevant_documents(
self,
@ -52,9 +50,7 @@ class MergerRetriever(BaseRetriever):
"""

# Merge the results of the retrievers.
merged_documents = await self.amerge_documents(query, run_manager)

return merged_documents
return await self.amerge_documents(query, run_manager)

def merge_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun

@ -74,10 +74,9 @@ class RePhraseQueryRetriever(BaseRetriever):
query, {"callbacks": run_manager.get_child()}
)
logger.info(f"Re-phrased question: {re_phrased_question}")
docs = self.retriever.invoke(
return self.retriever.invoke(
re_phrased_question, config={"callbacks": run_manager.get_child()}
)
return docs

async def _aget_relevant_documents(
self,

@ -119,18 +119,17 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
}
if isinstance(vectorstore, DatabricksVectorSearch):
return DatabricksVectorSearchTranslator()
elif isinstance(vectorstore, MyScale):
if isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
elif isinstance(vectorstore, Redis):
if isinstance(vectorstore, Redis):
return RedisTranslator.from_vectorstore(vectorstore)
elif isinstance(vectorstore, TencentVectorDB):
if isinstance(vectorstore, TencentVectorDB):
fields = [
field.name for field in (vectorstore.meta_fields or []) if field.index
]
return TencentVectorDBTranslator(fields)
elif vectorstore.__class__ in BUILTIN_TRANSLATORS:
if vectorstore.__class__ in BUILTIN_TRANSLATORS:
return BUILTIN_TRANSLATORS[vectorstore.__class__]()
else:
try:
from langchain_astradb.vectorstores import AstraDBVectorStore
except ImportError:
@ -289,14 +288,12 @@ class SelfQueryRetriever(BaseRetriever):
def _get_docs_with_query(
self, query: str, search_kwargs: dict[str, Any]
) -> list[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs
return self.vectorstore.search(query, self.search_type, **search_kwargs)

async def _aget_docs_with_query(
self, query: str, search_kwargs: dict[str, Any]
) -> list[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
return docs
return await self.vectorstore.asearch(query, self.search_type, **search_kwargs)

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@ -315,8 +312,7 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
return docs
return self._get_docs_with_query(new_query, search_kwargs)

async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
@ -335,8 +331,7 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
return docs
return await self._aget_docs_with_query(new_query, search_kwargs)

@classmethod
def from_llm(

@ -191,13 +191,13 @@ def _wrap_in_chain_factory(
)
raise ValueError(msg)
return lambda: chain
elif isinstance(llm_or_chain_factory, BaseLanguageModel):
if isinstance(llm_or_chain_factory, BaseLanguageModel):
return llm_or_chain_factory
elif isinstance(llm_or_chain_factory, Runnable):
if isinstance(llm_or_chain_factory, Runnable):
# Memory may exist here, but it's not elegant to check all those cases.
lcf = llm_or_chain_factory
return lambda: lcf
elif callable(llm_or_chain_factory):
if callable(llm_or_chain_factory):
if is_traceable_function(llm_or_chain_factory):
runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
return lambda: runnable_
@ -215,13 +215,12 @@ def _wrap_in_chain_factory(
# It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user.
return _model
elif is_traceable_function(cast(Callable, _model)):
if is_traceable_function(cast(Callable, _model)):
runnable_ = as_runnable(cast(Callable, _model))
return lambda: runnable_
elif not isinstance(_model, Runnable):
if not isinstance(_model, Runnable):
# This is unlikely to happen - a constructor for a model function
return lambda: RunnableLambda(constructor)
else:
# Typical correct case
return constructor
return llm_or_chain_factory
@ -272,7 +271,6 @@ def _get_prompt(inputs: dict[str, Any]) -> str:
raise InputFormatError(msg)
if len(prompts) == 1:
return prompts[0]
else:
msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts."
raise InputFormatError(msg)

@ -321,7 +319,6 @@ def _get_messages(inputs: dict[str, Any]) -> dict:
)
raise InputFormatError(msg)
return input_copy
else:
msg = (
f"Chat Run expects single List[dict] or List[List[dict]] 'messages'"
f" input. Got {inputs}"
@ -707,7 +704,6 @@ async def _arun_llm(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
else:
msg = (
"Input mapper returned invalid format"
f" {prompt_or_messages}"
@ -715,7 +711,6 @@ async def _arun_llm(
)
raise InputFormatError(msg)

else:
try:
prompt = _get_prompt(inputs)
llm_output: Union[str, BaseMessage] = await llm.ainvoke(

@ -28,7 +28,6 @@ def _get_messages_from_run_dict(messages: list[dict]) -> list[BaseMessage]:
first_message = messages[0]
if "lc" in first_message:
return [load(dumpd(message)) for message in messages]
else:
return messages_from_dict(messages)

@ -106,14 +105,12 @@ class LLMStringRunMapper(StringRunMapper):
if run.run_type != "llm":
msg = "LLM RunMapper only supports LLM runs."
raise ValueError(msg)
elif not run.outputs:
if not run.outputs:
if run.error:
msg = f"Cannot evaluate errored LLM run {run.id}: {run.error}"
raise ValueError(msg)
else:
msg = f"Run {run.id} has no outputs. Cannot evaluate this run."
raise ValueError(msg)
else:
try:
inputs = self.serialize_inputs(run.inputs)
except Exception as e:
@ -142,9 +139,8 @@ class ChainStringRunMapper(StringRunMapper):
def _get_key(self, source: dict, key: Optional[str], which: str) -> str:
if key is not None:
return source[key]
elif len(source) == 1:
if len(source) == 1:
return next(iter(source.values()))
else:
msg = (
f"Could not map run {which} with multiple keys: "
f"{source}\nPlease manually specify a {which}_key"
@ -168,7 +164,7 @@ class ChainStringRunMapper(StringRunMapper):
f" '{self.input_key}'."
)
raise ValueError(msg)
elif self.prediction_key is not None and self.prediction_key not in run.outputs:
if self.prediction_key is not None and self.prediction_key not in run.outputs:
available_keys = ", ".join(run.outputs.keys())
msg = (
f"Run with ID {run.id} doesn't have the expected prediction key"
@ -178,7 +174,6 @@ class ChainStringRunMapper(StringRunMapper):
)
raise ValueError(msg)

else:
input_ = self._get_key(run.inputs, self.input_key, "input")
prediction = self._get_key(run.outputs, self.prediction_key, "prediction")
return {
@ -224,7 +219,6 @@ class StringExampleMapper(Serializable):
" specify a reference_key."
)
raise ValueError(msg)
else:
output = list(example.outputs.values())[0]
elif self.reference_key not in example.outputs:
msg = (

@ -65,9 +65,8 @@ def _import_python_tool_PythonREPLTool() -> Any:
def __getattr__(name: str) -> Any:
if name == "PythonAstREPLTool":
return _import_python_tool_PythonAstREPLTool()
elif name == "PythonREPLTool":
if name == "PythonREPLTool":
return _import_python_tool_PythonREPLTool()
else:
from langchain_community import tools

# If not in interactive env, raise warning.

@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"

[tool.ruff.lint]
select = ["A", "C4", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
select = ["A", "C4", "D", "E", "EM", "F", "I", "PGH003", "PIE", "RET", "S", "SIM", "T201", "UP", "W"]
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true

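Most hunks in this commit are the RET505 fix (superfluous else/elif after return); others apply RET504 (unnecessary assignment before return) or RET503 (make an implicit return None explicit). A schematic before/after sketch of the RET505 auto-fix, on a hypothetical function rather than code from this diff:

def provider_for(model_name: str) -> str:
    if model_name.startswith("gpt-"):
        return "openai"
    elif model_name.startswith("claude"):  # RET505: superfluous elif after return
        return "anthropic"
    else:  # RET505: superfluous else after return
        return "unknown"

# After `ruff check --fix` the branches are flattened:
def provider_for(model_name: str) -> str:
    if model_name.startswith("gpt-"):
        return "openai"
    if model_name.startswith("claude"):
        return "anthropic"
    return "unknown"

The fix does not change behavior: every branch still ends in a return, so dropping the else/elif keeps the control flow identical.
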
@ -139,7 +139,6 @@ async def get_state(
async def ask_for_passphrase(said_please: bool) -> dict[str, Any]:
if said_please:
return {"passphrase": f"The passphrase is {PASS_PHRASE}"}
else:
return {"passphrase": "I won't share the passphrase without saying 'please'."}

@ -153,7 +152,6 @@ async def recycle(password: SecretPassPhrase) -> dict[str, Any]:
if password.pw == PASS_PHRASE:
_ROBOT_STATE["destruct"] = True
return {"status": "Self-destruct initiated", "state": _ROBOT_STATE}
else:
_ROBOT_STATE["destruct"] = False
raise HTTPException(
status_code=400,

@ -100,14 +100,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
),
]

agent = initialize_agent(
return initialize_agent(
tools,
fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
**kwargs,
)
return agent

def test_agent_bad_action() -> None:

@ -71,14 +71,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
),
]

agent = initialize_agent(
return initialize_agent(
tools,
fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
**kwargs,
)
return agent

async def test_agent_bad_action() -> None:

@ -11,7 +11,6 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
else:
return "Final Answer", output.return_values["output"]

@ -16,7 +16,6 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = MRKLOutputParser().parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
else:
return "Final Answer", output.return_values["output"]

@ -21,9 +21,8 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
elif isinstance(output, AgentFinish):
if isinstance(output, AgentFinish):
return output.return_values["output"], output.log
else:
msg = "Unexpected output type"
raise ValueError(msg)

@ -58,7 +58,6 @@ class FakeChain(Chain):
) -> dict[str, str]:
if self.be_correct:
return {"bar": "baz"}
else:
return {"baz": "bar"}

@ -54,7 +54,6 @@ class _FakeTrajectoryChatModel(FakeChatModel):
response = self.queries[list(self.queries.keys())[self.response_index]]
self.response_index = self.response_index + 1
return response
else:
prompt = messages[0].content
return self.queries[prompt]

@ -45,7 +45,6 @@ class FakeLLM(LLM):
return self.queries[prompt]
if stop is None:
return "foo"
else:
return "bar"

@property

@ -14,7 +14,6 @@ class SequentialRetriever(BaseRetriever):
) -> list[Document]:
if self.response_index >= len(self.sequential_responses):
return []
else:
self.response_index += 1
return self.sequential_responses[self.response_index - 1]