diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 1294b1718f8..bc266d7c1e5 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -686,12 +686,15 @@ s cls, agent: Union[BaseSingleActionAgent, BaseMultiActionAgent], tools: Sequence[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, **kwargs: Any, ) -> AgentExecutor: """Create from agent and tools.""" return cls( - agent=agent, tools=tools, callback_manager=callback_manager, **kwargs + agent=agent, + tools=tools, + callbacks=callbacks, + **kwargs, ) @root_validator() diff --git a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py index f6e2e209ea6..e70dc7f0509 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py @@ -6,7 +6,7 @@ from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUF from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS -from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.chains.llm import LLMChain from langchain.schema.language_model import BaseLanguageModel @@ -15,6 +15,7 @@ def create_spark_sql_agent( llm: BaseLanguageModel, toolkit: SparkSQLToolkit, callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, prefix: str = SQL_PREFIX, suffix: str = SQL_SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, @@ -41,6 +42,7 @@ def create_spark_sql_agent( llm=llm, prompt=prompt, callback_manager=callback_manager, + callbacks=callbacks, ) tool_names = [tool.name for tool in tools] agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) @@ -48,6 +50,7 @@ def create_spark_sql_agent( agent=agent, tools=tools, callback_manager=callback_manager, + callbacks=callbacks, verbose=verbose, max_iterations=max_iterations, max_execution_time=max_execution_time, diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 5a21dc6a661..701da6c375b 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -136,6 +136,12 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): def raise_callback_manager_deprecation(cls, values: Dict) -> Dict: """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: + if values.get("callbacks") is not None: + raise ValueError( + "Cannot specify both callback_manager and callbacks. " + "callback_manager is deprecated, callbacks is the preferred " + "parameter to pass in." + ) warnings.warn( "callback_manager is deprecated. Please use callbacks instead.", DeprecationWarning,