Mirror of https://github.com/hwchase17/langchain.git
WIP: stdout callback (#479)
First pass at a stdout callback. For the most part it went pretty smoothly. Aside from the code here, some comments/observations:

1. We should somehow default to `StdOutCallbackHandler` so I don't have to do:

```
from langchain.callbacks import get_callback_manager
from langchain.callbacks.stdout import StdOutCallbackHandler
get_callback_manager().add_handler(StdOutCallbackHandler())
```

2. I kept the verbosity flag around. 1) It is pretty important for getting the stdout to look good for agents (and other things). 2) I actually added it to the LLM class, since it didn't have it.

3. The only part that isn't moved over basically perfectly is the end of the agent run. [Screenshot: the new stdout tracing.] Notice it is missing the logging of the final thought. [Screenshot: what it looked like before.] The reason it's missing is that this was previously logged as part of agent end (lines 205 and 206). This is probably only relevant for the stdout logger; any thoughts for how to get it back in?
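As a concrete sketch of point 1: the helpers this PR itself adds already collapse the three-line registration into one call (`set_default_callback_manager` appears in `langchain/callbacks/__init__.py` below, and `langchain/__init__.py` now invokes it at import time):

```python
from langchain.callbacks import set_default_callback_manager

# Equivalent to:
#   get_callback_manager().add_handler(StdOutCallbackHandler())
set_default_callback_manager()
```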
langchain/__init__.py:

```diff
@@ -4,6 +4,7 @@ from typing import Optional
 
 from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
 from langchain.cache import BaseCache
+from langchain.callbacks import set_default_callback_manager
 from langchain.chains import (
     ConversationChain,
     LLMBashChain,
@@ -33,6 +34,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
 logger: BaseLogger = StdOutLogger()
 verbose: bool = False
 llm_cache: Optional[BaseCache] = None
+set_default_callback_manager()
 
 __all__ = [
     "LLMChain",
```
langchain/agents/agent.py:

```diff
@@ -3,11 +3,10 @@ from __future__ import annotations
 
 import logging
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import BaseModel, root_validator
 
-import langchain
 from langchain.agents.tools import Tool
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.base import Chain
@@ -17,7 +16,7 @@ from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import AgentAction, AgentFinish
+from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish
 
 logger = logging.getLogger()
 
@@ -47,7 +46,7 @@ class Agent(BaseModel):
 
     def plan(
         self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
-    ) -> Union[AgentFinish, AgentAction]:
+    ) -> AgentAction:
         """Given input, decided what to do.
 
         Args:
@@ -74,7 +73,7 @@ class Agent(BaseModel):
         parsed_output = self._extract_tool_and_input(full_output)
         tool, tool_input = parsed_output
         if tool == self.finish_tool_name:
-            return AgentFinish({"output": tool_input}, full_output)
+            return AgentFinish(tool, tool_input, full_output, {"output": tool_input})
         return AgentAction(tool, tool_input, full_output)
 
     def prepare_for_new_call(self) -> None:
@@ -138,11 +137,15 @@ class Agent(BaseModel):
         llm: BaseLLM,
         tools: List[Tool],
         callback_manager: Optional[BaseCallbackManager] = None,
+        verbose: bool = False,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
         llm_chain = LLMChain(
-            llm=llm, prompt=cls.create_prompt(tools), callback_manager=callback_manager
+            llm=llm,
+            prompt=cls.create_prompt(tools),
+            callback_manager=callback_manager,
+            verbose=verbose,
         )
         return cls(llm_chain=llm_chain)
 
@@ -217,28 +220,34 @@ class AgentExecutor(Chain, BaseModel):
         # If the tool chosen is the finishing tool, then we end and return.
         if isinstance(output, AgentFinish):
             if self.verbose:
-                langchain.logger.log_agent_end(output, color="green")
+                self._get_callback_manager().on_tool_start(
+                    {"name": "Finish"}, output, color="green"
+                )
+                self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION)
             final_output = output.return_values
             if self.return_intermediate_steps:
                 final_output["intermediate_steps"] = intermediate_steps
             return final_output
-        if self.verbose:
-            langchain.logger.log_agent_action(output, color="green")
         # And then we lookup the tool
         if output.tool in name_to_tool_map:
             chain = name_to_tool_map[output.tool]
-            self._get_callback_manager().on_tool_start(
-                {"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
-            )
+            if self.verbose:
+                self._get_callback_manager().on_tool_start(
+                    {"name": str(chain)[:60] + "..."}, output, color="green"
+                )
             # We then call the tool on the tool input to get an observation
             observation = chain(output.tool_input)
-            self._get_callback_manager().on_tool_end(observation)
             color = color_mapping[output.tool]
         else:
+            if self.verbose:
+                self._get_callback_manager().on_tool_start(
+                    {"name": "N/A"}, output, color="green"
+                )
             observation = f"{output.tool} is not a valid tool, try another one."
             color = None
         if self.verbose:
-            langchain.logger.log_agent_observation(
+            self._get_callback_manager().on_tool_end(
                 observation,
                 color=color,
                 observation_prefix=self.agent.observation_prefix,
```
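The finish branch above now reports the final step as a synthetic "Finish" tool and signals completion with the `AGENT_FINISH_OBSERVATION` sentinel, so a handler can tell a real tool observation from the end of the run. A minimal sketch of such a consumer, assuming only the schema constant added in this PR:

```python
from typing import Any

from langchain.schema import AGENT_FINISH_OBSERVATION


def on_tool_end(output: str, **kwargs: Any) -> None:
    """Example handler hook: skip the sentinel, print real observations."""
    if output == AGENT_FINISH_OBSERVATION:
        return  # the "Finish" pseudo-tool carries no observation
    print(f"Observation: {output}")
```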
langchain/agents/initialize.py:

```diff
@@ -21,6 +21,7 @@ def initialize_agent(
     llm: BaseLLM,
     agent: str = "zero-shot-react-description",
     callback_manager: Optional[BaseCallbackManager] = None,
+    verbose: bool = False,
     **kwargs: Any,
 ) -> AgentExecutor:
     """Load agent given tools and LLM.
@@ -30,6 +31,7 @@ def initialize_agent(
         llm: Language model to use as the agent.
         agent: The agent to use. Valid options are:
             `zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
+        verbose: Whether to use the callback manager for this particular agent.
         callback_manager: CallbackManager to use. Global callback manager is used if
             not provided. Defaults to None.
         **kwargs: Additional key word arguments to pass to the agent.
@@ -44,8 +46,12 @@ def initialize_agent(
     )
     agent_cls = AGENT_TO_CLASS[agent]
     agent_obj = agent_cls.from_llm_and_tools(
-        llm, tools, callback_manager=callback_manager
+        llm, tools, callback_manager=callback_manager, verbose=verbose
     )
     return AgentExecutor.from_agent_and_tools(
-        agent=agent_obj, tools=tools, callback_manager=callback_manager, **kwargs
+        agent=agent_obj,
+        tools=tools,
+        callback_manager=callback_manager,
+        verbose=verbose,
+        **kwargs,
     )
```
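A sketch of how the new flag threads through `initialize_agent` (the one-tool setup mirrors the unit tests further down; the `initialize_agent` and `Tool` imports are assumed to resolve as they do in those tests):

```python
from langchain.agents import initialize_agent
from langchain.agents.tools import Tool
from langchain.llms.base import BaseLLM


def build_verbose_agent(llm: BaseLLM):
    """Wire up a one-tool agent with stdout tracing enabled (sketch)."""
    tools = [Tool("Search", lambda x: x, "Useful for searching")]
    return initialize_agent(
        tools,
        llm,
        agent="zero-shot-react-description",
        verbose=True,  # forwarded to both the LLMChain and the AgentExecutor
    )
```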
langchain/callbacks/__init__.py:

```diff
@@ -1,8 +1,15 @@
 """Callback handlers that allow listening to events in LangChain."""
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.shared import SharedCallbackManager
+from langchain.callbacks.stdout import StdOutCallbackHandler
 
 
 def get_callback_manager() -> BaseCallbackManager:
     """Return the shared callback manager."""
     return SharedCallbackManager()
+
+
+def set_default_callback_manager() -> None:
+    """Set default callback manager."""
+    callback = get_callback_manager()
+    callback.add_handler(StdOutCallbackHandler())
```
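Since `SharedCallbackManager` is a singleton (see `langchain/callbacks/shared.py` below), `set_default_callback_manager()` installs the stdout handler process-wide; every caller gets the same manager:

```python
from langchain.callbacks import get_callback_manager

# Both calls return the one shared manager, so a handler added anywhere
# is seen by all chains, LLMs, and agents.
assert get_callback_manager() is get_callback_manager()
```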
langchain/callbacks/base.py:

```diff
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 
-from langchain.schema import LLMResult
+from langchain.schema import AgentAction, LLMResult
 
 
 class BaseCallbackHandler(ABC):
@@ -11,7 +11,7 @@ class BaseCallbackHandler(ABC):
 
     @abstractmethod
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
 
@@ -28,7 +28,7 @@ class BaseCallbackHandler(ABC):
 
     @abstractmethod
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Run when chain starts running."""
 
@@ -42,12 +42,12 @@ class BaseCallbackHandler(ABC):
 
     @abstractmethod
     def on_tool_start(
-        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
 
     @abstractmethod
-    def on_tool_end(self, output: str) -> None:
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
         """Run when tool ends running."""
 
     @abstractmethod
@@ -75,11 +75,11 @@ class CallbackManager(BaseCallbackManager):
         self.handlers = handlers
 
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
         for handler in self.handlers:
-            handler.on_llm_start(serialized, prompts, **extra)
+            handler.on_llm_start(serialized, prompts, **kwargs)
 
     def on_llm_end(
         self,
@@ -95,11 +95,11 @@ class CallbackManager(BaseCallbackManager):
             handler.on_llm_error(error)
 
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Run when chain starts running."""
         for handler in self.handlers:
-            handler.on_chain_start(serialized, inputs, **extra)
+            handler.on_chain_start(serialized, inputs, **kwargs)
 
     def on_chain_end(self, outputs: Dict[str, Any]) -> None:
         """Run when chain ends running."""
@@ -112,16 +112,16 @@ class CallbackManager(BaseCallbackManager):
             handler.on_chain_error(error)
 
     def on_tool_start(
-        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
         for handler in self.handlers:
-            handler.on_tool_start(serialized, action, tool_input, **extra)
+            handler.on_tool_start(serialized, action, **kwargs)
 
-    def on_tool_end(self, output: str) -> None:
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
         """Run when tool ends running."""
         for handler in self.handlers:
-            handler.on_tool_end(output)
+            handler.on_tool_end(output, **kwargs)
 
     def on_tool_error(self, error: Exception) -> None:
         """Run when tool errors."""
```
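A sketch of driving the new signatures by hand, reusing `FakeCallbackHandler` from the unit tests below for brevity; `on_tool_start` now receives a full `AgentAction`, and extra keyword arguments flow through to every handler:

```python
from langchain.callbacks.base import CallbackManager
from langchain.schema import AgentAction
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

handler = FakeCallbackHandler()
manager = CallbackManager([handler])

# The manager fans each event out to all registered handlers.
manager.on_tool_start({"name": "Search"}, AgentAction("Search", "langchain", "log"))
manager.on_tool_end("some observation", color="blue")  # kwargs pass through
assert handler.starts == 1 and handler.ends == 1
```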
langchain/callbacks/shared.py:

```diff
@@ -8,7 +8,7 @@ from langchain.callbacks.base import (
     BaseCallbackManager,
     CallbackManager,
 )
-from langchain.schema import LLMResult
+from langchain.schema import AgentAction, LLMResult
 
 
 class Singleton:
@@ -35,11 +35,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
     _callback_manager: CallbackManager = CallbackManager([])
 
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
         with self._lock:
-            self._callback_manager.on_llm_start(serialized, prompts, **extra)
+            self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
 
     def on_llm_end(
         self,
@@ -55,11 +55,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
             self._callback_manager.on_llm_error(error)
 
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Run when chain starts running."""
         with self._lock:
-            self._callback_manager.on_chain_start(serialized, inputs, **extra)
+            self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
 
     def on_chain_end(self, outputs: Dict[str, Any]) -> None:
         """Run when chain ends running."""
@@ -72,18 +72,16 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
             self._callback_manager.on_chain_error(error)
 
     def on_tool_start(
-        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
         with self._lock:
-            self._callback_manager.on_tool_start(
-                serialized, action, tool_input, **extra
-            )
+            self._callback_manager.on_tool_start(serialized, action, **kwargs)
 
-    def on_tool_end(self, output: str) -> None:
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
         """Run when tool ends running."""
         with self._lock:
-            self._callback_manager.on_tool_end(output)
+            self._callback_manager.on_tool_end(output, **kwargs)
 
     def on_tool_error(self, error: Exception) -> None:
         """Run when tool errors."""
```
langchain/callbacks/stdout.py (new file, 69 lines):

```diff
@@ -0,0 +1,69 @@
+"""Callback Handler that prints to std out."""
+from typing import Any, Dict, List, Optional
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.input import print_text
+from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult
+
+
+class StdOutCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that prints to std out."""
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Print out the prompts."""
+        print("Prompts after formatting:")
+        for prompt in prompts:
+            print_text(prompt, color="green", end="\n")
+
+    def on_llm_end(self, response: LLMResult) -> None:
+        """Do nothing."""
+        pass
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Print out that we are entering a chain."""
+        class_name = serialized["name"]
+        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Print out that we finished a chain."""
+        print("\n\033[1m> Finished chain.\033[0m")
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_tool_start(
+        self,
+        serialized: Dict[str, Any],
+        action: AgentAction,
+        color: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Print out the log in specified color."""
+        print_text(action.log, color=color)
+
+    def on_tool_end(
+        self,
+        output: str,
+        color: Optional[str] = None,
+        observation_prefix: Optional[str] = None,
+        llm_prefix: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """If not the final action, print out observation."""
+        if output != AGENT_FINISH_OBSERVATION:
+            print_text(f"\n{observation_prefix}")
+            print_text(output, color=color)
+            print_text(f"\n{llm_prefix}")
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
```
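A sketch of the new handler in isolation, showing what each event prints:

```python
from langchain.callbacks.stdout import StdOutCallbackHandler

handler = StdOutCallbackHandler()
handler.on_chain_start({"name": "LLMChain"}, {"question": "hi"})
# prints (bold, via ANSI escapes): "> Entering new LLMChain chain..."
handler.on_llm_start({"name": "OpenAI"}, ["formatted prompt"])
# prints "Prompts after formatting:" and then the prompt in green
handler.on_chain_end({})
# prints "> Finished chain."
```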
langchain/chains/base.py:

```diff
@@ -120,16 +120,12 @@ class Chain(BaseModel, ABC):
         inputs = dict(inputs, **external_context)
         self._validate_inputs(inputs)
         if self.verbose:
-            print(
-                f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
-            )
-        self._get_callback_manager().on_chain_start(
-            {"name": self.__class__.__name__}, inputs
-        )
+            self._get_callback_manager().on_chain_start(
+                {"name": self.__class__.__name__}, inputs
+            )
         outputs = self._call(inputs)
-        self._get_callback_manager().on_chain_end(outputs)
         if self.verbose:
-            print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
+            self._get_callback_manager().on_chain_end(outputs)
         self._validate_outputs(outputs)
         if self.memory is not None:
             self.memory.save_context(inputs, outputs)
```
langchain/chains/llm.py:

```diff
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Sequence, Union
 
 from pydantic import BaseModel, Extra
 
-import langchain
 from langchain.chains.base import Chain
 from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -54,6 +53,7 @@ class LLMChain(Chain, BaseModel):
 
     def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
         """Generate LLM result from inputs."""
+        self.llm.verbose = self.verbose
         stop = None
         if "stop" in input_list[0]:
             stop = input_list[0]["stop"]
@@ -61,8 +61,6 @@ class LLMChain(Chain, BaseModel):
         for inputs in input_list:
             selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
             prompt = self.prompt.format(**selected_inputs)
-            if self.verbose:
-                langchain.logger.log_llm_inputs(selected_inputs, prompt)
             if "stop" in inputs and inputs["stop"] != stop:
                 raise ValueError(
                     "If `stop` is present in any inputs, should be present in all."
@@ -78,8 +76,6 @@ class LLMChain(Chain, BaseModel):
         for generation in response.generations:
             # Get the text of the top generated string.
             response_str = generation[0].text
-            if self.verbose:
-                langchain.logger.log_llm_response(response_str)
             outputs.append({self.output_key: response_str})
         return outputs
 
```
langchain/llms/base.py:

```diff
@@ -5,7 +5,7 @@ from pathlib import Path
 from typing import Any, Dict, List, Mapping, Optional, Union
 
 import yaml
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, Field
 
 import langchain
 from langchain.callbacks import get_callback_manager
@@ -13,10 +13,16 @@ from langchain.callbacks.base import BaseCallbackManager
 from langchain.schema import Generation, LLMResult
 
 
+def _get_verbosity() -> bool:
+    return langchain.verbose
+
+
 class BaseLLM(BaseModel, ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
     cache: Optional[bool] = None
+    verbose: bool = Field(default_factory=_get_verbosity)
+    """Whether to print out response text."""
     callback_manager: Optional[BaseCallbackManager] = None
 
     class Config:
@@ -48,10 +54,12 @@ class BaseLLM(BaseModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            self._get_callback_manager().on_llm_start(
-                {"name": self.__class__.__name__}, prompts
-            )
+            if self.verbose:
+                self._get_callback_manager().on_llm_start(
+                    {"name": self.__class__.__name__}, prompts
+                )
             output = self._generate(prompts, stop=stop)
-            self._get_callback_manager().on_llm_end(output)
+            if self.verbose:
+                self._get_callback_manager().on_llm_end(output)
             return output
         params = self._llm_dict()
```
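The reason for `Field(default_factory=_get_verbosity)` rather than a plain `verbose: bool = False`: the global `langchain.verbose` is read when each LLM instance is constructed, and an explicit keyword argument still wins. A self-contained pydantic sketch with stand-in names:

```python
from pydantic import BaseModel, Field

_globals = {"verbose": False}  # stand-in for the langchain.verbose global


def _get_verbosity() -> bool:
    return _globals["verbose"]


class FakeLLMConfig(BaseModel):  # stand-in for BaseLLM
    verbose: bool = Field(default_factory=_get_verbosity)


a = FakeLLMConfig()                # default resolved at construction: False
_globals["verbose"] = True
b = FakeLLMConfig()                # default resolved at construction: True
c = FakeLLMConfig(verbose=False)   # explicit value overrides the factory
assert (a.verbose, b.verbose, c.verbose) == (False, True, False)
```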
langchain/schema.py:

```diff
@@ -1,9 +1,13 @@
 """Common schema objects."""
 
+from dataclasses import dataclass
 from typing import List, NamedTuple, Optional
 
+AGENT_FINISH_OBSERVATION = "__agent_finish__"
+
 
-class AgentAction(NamedTuple):
+@dataclass
+class AgentAction:
     """Agent's action to take."""
 
     tool: str
@@ -11,11 +15,11 @@ class AgentAction(NamedTuple):
     log: str
 
 
-class AgentFinish(NamedTuple):
+@dataclass
+class AgentFinish(AgentAction):
     """Agent's return value."""
 
     return_values: dict
-    log: str
 
 
 class Generation(NamedTuple):
```
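Moving from `NamedTuple` to dataclasses is what lets `AgentFinish` subclass `AgentAction`, so the executor can report the finish as one more tool step. A sketch of the new construction; the positional order (`tool`, `tool_input`, `log`, then `return_values`) is inferred from the `AgentFinish(tool, tool_input, full_output, {"output": tool_input})` call in the `agent.py` hunk above:

```python
from langchain.schema import AgentAction, AgentFinish

action = AgentAction(tool="Search", tool_input="langchain", log="thought...")
finish = AgentFinish("Final Answer", "42", "full LLM output", {"output": "42"})

assert isinstance(finish, AgentAction)  # a finish now *is* an action
assert finish.return_values == {"output": "42"}
```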
Agent tests:

```diff
@@ -74,7 +74,7 @@ def test_agent_with_callbacks() -> None:
         f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
         "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
     ]
-    fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
+    fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
     tools = [
         Tool("Search", lambda x: x, "Useful for searching"),
     ]
@@ -89,7 +89,36 @@ def test_agent_with_callbacks() -> None:
     output = agent.run("when was langchain made")
     assert output == "curses foiled again"
 
-    # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
-    assert handler.starts == 6
-    assert handler.ends == 6
+    # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 ending
+    assert handler.starts == 7
+    assert handler.ends == 7
+    assert handler.errors == 0
+
+
+def test_agent_with_callbacks_not_verbose() -> None:
+    """Test react chain with callbacks but not verbose."""
+    handler = FakeCallbackHandler()
+    manager = CallbackManager([handler])
+    tool = "Search"
+    responses = [
+        f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
+        "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
+    ]
+    fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
+    tools = [
+        Tool("Search", lambda x: x, "Useful for searching"),
+    ]
+    agent = initialize_agent(
+        tools,
+        fake_llm,
+        agent="zero-shot-react-description",
+        callback_manager=manager,
+    )
+
+    output = agent.run("when was langchain made")
+    assert output == "curses foiled again"
+
+    # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
+    assert handler.starts == 0
+    assert handler.ends == 0
     assert handler.errors == 0
```
tests/unit_tests/callbacks/fake_callback_handler.py:

```diff
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult
+from langchain.schema import AgentAction, LLMResult
 
 
 class FakeCallbackHandler(BaseCallbackHandler):
@@ -15,7 +15,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
         self.errors = 0
 
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
         self.starts += 1
@@ -32,7 +32,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
         self.errors += 1
 
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Run when chain starts running."""
         self.starts += 1
@@ -46,12 +46,12 @@ class FakeCallbackHandler(BaseCallbackHandler):
         self.errors += 1
 
     def on_tool_start(
-        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
         self.starts += 1
 
-    def on_tool_end(self, output: str) -> None:
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
         """Run when tool ends running."""
         self.ends += 1
 
```
Callback manager tests:

```diff
@@ -2,7 +2,7 @@
 
 from langchain.callbacks.base import BaseCallbackManager, CallbackManager
 from langchain.callbacks.shared import SharedCallbackManager
-from langchain.schema import LLMResult
+from langchain.schema import AgentAction, LLMResult
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
@@ -13,10 +13,10 @@ def _test_callback_manager(
     manager.on_llm_start({}, [])
     manager.on_llm_end(LLMResult(generations=[]))
    manager.on_llm_error(Exception())
-    manager.on_chain_start({}, {})
+    manager.on_chain_start({"name": "foo"}, {})
     manager.on_chain_end({})
     manager.on_chain_error(Exception())
-    manager.on_tool_start({}, "", "")
+    manager.on_tool_start({}, AgentAction("", "", ""))
     manager.on_tool_end("")
     manager.on_tool_error(Exception())
     for handler in handlers:
```
Chain tests:

```diff
@@ -140,9 +140,20 @@ def test_run_arg_with_memory() -> None:
 def test_run_with_callback() -> None:
     """Test run method works when callback manager is passed."""
     handler = FakeCallbackHandler()
-    chain = FakeChain(callback_manager=CallbackManager([handler]))
+    chain = FakeChain(callback_manager=CallbackManager([handler]), verbose=True)
     output = chain.run("bar")
     assert output == "baz"
     assert handler.starts == 1
     assert handler.ends == 1
     assert handler.errors == 0
+
+
+def test_run_with_callback_not_verbose() -> None:
+    """Test run method works when callback manager is passed and not verbose."""
+    handler = FakeCallbackHandler()
+    chain = FakeChain(callback_manager=CallbackManager([handler]))
+    output = chain.run("bar")
+    assert output == "baz"
+    assert handler.starts == 0
+    assert handler.ends == 0
+    assert handler.errors == 0
```
LLM tests:

```diff
@@ -7,9 +7,20 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
 def test_llm_with_callbacks() -> None:
     """Test LLM callbacks."""
     handler = FakeCallbackHandler()
-    llm = FakeLLM(callback_manager=CallbackManager([handler]))
+    llm = FakeLLM(callback_manager=CallbackManager([handler]), verbose=True)
     output = llm("foo")
     assert output == "foo"
     assert handler.starts == 1
     assert handler.ends == 1
     assert handler.errors == 0
+
+
+def test_llm_with_callbacks_not_verbose() -> None:
+    """Test LLM callbacks but not verbose."""
+    handler = FakeCallbackHandler()
+    llm = FakeLLM(callback_manager=CallbackManager([handler]))
+    output = llm("foo")
+    assert output == "foo"
+    assert handler.starts == 0
+    assert handler.ends == 0
+    assert handler.errors == 0
```