mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
Fix add callbacks to spark_sql due to depreciation of callback_manager (#9831)
Description: Due to depreciation (regarding to line 109 in [langchain/libs/langchain/langchain/chains/base.py](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/base.py) of callback_manager i replaced several parts Issue: None Dependencies: Maintainer: @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
ffa5625134
commit
a05fed9369
@ -686,12 +686,15 @@ s
|
|||||||
cls,
|
cls,
|
||||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
|
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
|
||||||
tools: Sequence[BaseTool],
|
tools: Sequence[BaseTool],
|
||||||
callback_manager: Optional[BaseCallbackManager] = None,
|
callbacks: Callbacks = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AgentExecutor:
|
) -> AgentExecutor:
|
||||||
"""Create from agent and tools."""
|
"""Create from agent and tools."""
|
||||||
return cls(
|
return cls(
|
||||||
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
|
agent=agent,
|
||||||
|
tools=tools,
|
||||||
|
callbacks=callbacks,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
|
@ -6,7 +6,7 @@ from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUF
|
|||||||
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
|
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
|
||||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
|
||||||
@ -15,6 +15,7 @@ def create_spark_sql_agent(
|
|||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
toolkit: SparkSQLToolkit,
|
toolkit: SparkSQLToolkit,
|
||||||
callback_manager: Optional[BaseCallbackManager] = None,
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
prefix: str = SQL_PREFIX,
|
prefix: str = SQL_PREFIX,
|
||||||
suffix: str = SQL_SUFFIX,
|
suffix: str = SQL_SUFFIX,
|
||||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||||
@ -41,6 +42,7 @@ def create_spark_sql_agent(
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
tool_names = [tool.name for tool in tools]
|
tool_names = [tool.name for tool in tools]
|
||||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||||
@ -48,6 +50,7 @@ def create_spark_sql_agent(
|
|||||||
agent=agent,
|
agent=agent,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
callback_manager=callback_manager,
|
callback_manager=callback_manager,
|
||||||
|
callbacks=callbacks,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
max_iterations=max_iterations,
|
max_iterations=max_iterations,
|
||||||
max_execution_time=max_execution_time,
|
max_execution_time=max_execution_time,
|
||||||
|
@ -136,6 +136,12 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
|
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
|
||||||
"""Raise deprecation warning if callback_manager is used."""
|
"""Raise deprecation warning if callback_manager is used."""
|
||||||
if values.get("callback_manager") is not None:
|
if values.get("callback_manager") is not None:
|
||||||
|
if values.get("callbacks") is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both callback_manager and callbacks. "
|
||||||
|
"callback_manager is deprecated, callbacks is the preferred "
|
||||||
|
"parameter to pass in."
|
||||||
|
)
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"callback_manager is deprecated. Please use callbacks instead.",
|
"callback_manager is deprecated. Please use callbacks instead.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
Loading…
Reference in New Issue
Block a user