Compare commits

...

1 Commit

Author        SHA1        Message                Date
vowelparrot   6b879cebd0  Migrate to callbacks   2023-04-30 16:04:25 -07:00
8 changed files with 41 additions and 21 deletions
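The commit replaces the callback_manager parameter with a callbacks argument across the agent constructors and initialize_agent, while still routing a legacy callback_manager (passed through **kwargs) into the new plumbing for backward compatibility. A minimal sketch of what the change looks like from a caller's side, assuming the langchain API of this era; the PrintingHandler class and the empty tools list are hypothetical placeholders, not part of the diff:

from langchain.agents import AgentType, initialize_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import OpenAI


class PrintingHandler(BaseCallbackHandler):
    """Hypothetical handler that prints whenever a tool starts."""

    def on_tool_start(self, serialized, input_str, **kwargs):
        print(f"Tool starting with input: {input_str}")


llm = OpenAI(temperature=0)
tools = []  # assume some tools are loaded here

# Before this commit, an explicit manager was passed:
#     initialize_agent(tools, llm, callback_manager=CallbackManager([PrintingHandler()]))
# After this commit, a plain list of handlers (the Callbacks type) is accepted:
agent_executor = initialize_agent(
    tools,
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callbacks=[PrintingHandler()],
)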

View File

@@ -117,7 +117,7 @@ class BaseSingleActionAgent(BaseModel):
         cls,
         llm: BaseLanguageModel,
         tools: Sequence[BaseTool],
-        callback_manager: Optional[BaseCallbackManager] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> BaseSingleActionAgent:
         raise NotImplementedError
@@ -513,7 +513,7 @@ class Agent(BaseSingleActionAgent):
         cls,
         llm: BaseLanguageModel,
         tools: Sequence[BaseTool],
-        callback_manager: Optional[BaseCallbackManager] = None,
+        callbacks: Callbacks = None,
         output_parser: Optional[AgentOutputParser] = None,
         **kwargs: Any,
     ) -> Agent:
@@ -522,7 +522,8 @@ class Agent(BaseSingleActionAgent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=cls.create_prompt(tools),
-            callback_manager=callback_manager,
+            callbacks=callbacks,
+            callback_manager=kwargs.get("callback_manager"),
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser or cls._get_default_output_parser()

View File

@@ -6,7 +6,7 @@ from langchain.agents.agent import Agent, AgentOutputParser
 from langchain.agents.chat.output_parser import ChatOutputParser
 from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
 from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.chat import (
@@ -80,7 +80,7 @@ class ChatAgent(Agent):
         cls,
         llm: BaseLanguageModel,
         tools: Sequence[BaseTool],
-        callback_manager: Optional[BaseCallbackManager] = None,
+        callbacks: Callbacks = None,
         output_parser: Optional[AgentOutputParser] = None,
         prefix: str = PREFIX,
         suffix: str = SUFFIX,
@@ -100,7 +100,8 @@ class ChatAgent(Agent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
+            callback_manager=kwargs.get("callback_manager"),
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser or cls._get_default_output_parser()

View File

@@ -10,7 +10,7 @@ from langchain.agents.agent_types import AgentType
 from langchain.agents.conversational.output_parser import ConvoOutputParser
 from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
 from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
 from langchain.tools.base import BaseTool
@@ -85,7 +85,7 @@ class ConversationalAgent(Agent):
         cls,
         llm: BaseLanguageModel,
         tools: Sequence[BaseTool],
-        callback_manager: Optional[BaseCallbackManager] = None,
+        callbacks: Callbacks = None,
         output_parser: Optional[AgentOutputParser] = None,
         prefix: str = PREFIX,
         suffix: str = SUFFIX,
@@ -109,7 +109,8 @@ class ConversationalAgent(Agent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
+            callback_manager=kwargs.get("callback_manager"),
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser or cls._get_default_output_parser(

View File

@@ -13,7 +13,7 @@ from langchain.agents.conversational_chat.prompt import (
     TEMPLATE_TOOL_RESPONSE,
 )
 from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
 from langchain.chains import LLMChain
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.chat import (
@@ -103,7 +103,7 @@ class ConversationalChatAgent(Agent):
         cls,
         llm: BaseLanguageModel,
         tools: Sequence[BaseTool],
-        callback_manager: Optional[BaseCallbackManager] = None,
+        callbacks: Callbacks = None,
         output_parser: Optional[AgentOutputParser] = None,
         system_message: str = PREFIX,
         human_message: str = SUFFIX,
@@ -123,7 +123,8 @@ class ConversationalChatAgent(Agent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
+            callback_manager=kwargs.get("callback_manager"),
         )
         tool_names = [tool.name for tool in tools]
         return cls(

View File

@@ -5,7 +5,7 @@ from langchain.agents.agent import AgentExecutor
 from langchain.agents.agent_types import AgentType
 from langchain.agents.loading import AGENT_TO_CLASS, load_agent
 from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import CallbackManager, Callbacks
 from langchain.tools.base import BaseTool
@@ -13,7 +13,7 @@ def initialize_agent(
     tools: Sequence[BaseTool],
     llm: BaseLanguageModel,
     agent: Optional[AgentType] = None,
-    callback_manager: Optional[BaseCallbackManager] = None,
+    callbacks: Callbacks = None,
     agent_path: Optional[str] = None,
     agent_kwargs: Optional[dict] = None,
     **kwargs: Any,
@@ -25,7 +25,7 @@ def initialize_agent(
         llm: Language model to use as the agent.
         agent: Agent type to use. If None and agent_path is also None, will default to
             AgentType.ZERO_SHOT_REACT_DESCRIPTION.
-        callback_manager: CallbackManager to use. Global callback manager is used if
+        callbacks: CallbackManager to use. Global callback manager is used if
             not provided. Defaults to None.
         agent_path: Path to serialized agent to use.
         agent_kwargs: Additional key word arguments to pass to the underlying agent
@@ -34,6 +34,11 @@ def initialize_agent(
     Returns:
         An agent executor
     """
+    verbose: bool = kwargs.get("verbose", False)
+    callback_manager = CallbackManager.configure(
+        callbacks or kwargs.get("callback_manager"), verbose=verbose
+    )
     if agent is None and agent_path is None:
         agent = AgentType.ZERO_SHOT_REACT_DESCRIPTION
     if agent is not None and agent_path is not None:
@@ -50,11 +55,17 @@ def initialize_agent(
         agent_cls = AGENT_TO_CLASS[agent]
         agent_kwargs = agent_kwargs or {}
         agent_obj = agent_cls.from_llm_and_tools(
-            llm, tools, callback_manager=callback_manager, **agent_kwargs
+            llm,
+            tools,
+            callbacks=callback_manager,
+            **agent_kwargs,
         )
     elif agent_path is not None:
         agent_obj = load_agent(
-            agent_path, llm=llm, tools=tools, callback_manager=callback_manager
+            agent_path,
+            llm=llm,
+            tools=tools,
+            callbacks=callback_manager,
         )
     else:
         raise ValueError(
@@ -64,6 +75,6 @@ def initialize_agent(
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
-        callback_manager=callback_manager,
+        callbacks=callback_manager,
         **kwargs,
     )
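Note how the hunks above funnel both the new callbacks argument and any legacy callback_manager supplied via **kwargs through CallbackManager.configure, then forward the configured manager on under the callbacks= keyword. That works because the Callbacks alias accepts either a bare list of handlers or an already-built manager. A sketch of the alias as defined in langchain.callbacks.manager around this point, shown only for context and not part of the diff:

from typing import List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager

# Either a list of handlers, a full manager, or None may be passed wherever
# a Callbacks value is expected.
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]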

View File

@@ -11,7 +11,7 @@ from langchain.agents.mrkl.output_parser import MRKLOutputParser
 from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
 from langchain.agents.tools import Tool
 from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
 from langchain.tools.base import BaseTool
@@ -89,7 +89,7 @@ class ZeroShotAgent(Agent):
         cls,
         llm: BaseLanguageModel,
         tools: Sequence[BaseTool],
-        callback_manager: Optional[BaseCallbackManager] = None,
+        callbacks: Callbacks = None,
         output_parser: Optional[AgentOutputParser] = None,
         prefix: str = PREFIX,
         suffix: str = SUFFIX,
@@ -109,7 +109,8 @@ class ZeroShotAgent(Agent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
+            callback_manager=kwargs.get("callback_manager"),
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser or cls._get_default_output_parser()

View File

@@ -30,7 +30,9 @@ class Chain(BaseModel, ABC):
     memory: Optional[BaseMemory] = None
     callbacks: Callbacks = None
+    """Callbacks to be used during chain execution."""
     callback_manager: Optional[BaseCallbackManager] = None
+    """Deprecated. Use 'callbacks' instead."""
     verbose: bool = Field(
         default_factory=_get_verbosity
     )  # Whether to print the response text

View File

@@ -36,7 +36,9 @@ class BaseChatModel(BaseLanguageModel, ABC):
     verbose: bool = Field(default_factory=_get_verbosity)
     """Whether to print out response text."""
     callbacks: Callbacks = None
+    """Callbacks to be used during chat."""
     callback_manager: Optional[BaseCallbackManager] = None
+    """Deprecated. Use 'callbacks' instead."""
     @root_validator()
     def raise_deprecation(cls, values: Dict) -> Dict:
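The two hunks above only add docstrings to the callbacks / callback_manager field pair; the raise_deprecation root validator that appears as context is where the old field is steered toward the new one. The validator body is not part of this diff, but a minimal sketch of the usual pattern, assuming pydantic v1 (which langchain pinned at the time) and an illustrative Example model:

import warnings
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, root_validator


class Example(BaseModel):
    """Illustrative model mirroring the callbacks / callback_manager pair."""

    callbacks: Optional[List[Any]] = None
    callback_manager: Optional[List[Any]] = None

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn on the deprecated field and fold it into its replacement."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            values["callbacks"] = values["callback_manager"]
            values["callback_manager"] = None
        return values


# Constructing with the deprecated field emits a DeprecationWarning and the
# value is carried over to the new field.
example = Example(callback_manager=["handler"])
assert example.callbacks == ["handler"]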