From a3d2a2ec2a1abfceb511364c19381253f48afac6 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 30 Dec 2022 14:43:28 -0500 Subject: [PATCH] Harrison/streamlit handler (#488) also add a set handler method usage is: ``` from langchain.callbacks.streamlit import StreamlitCallbackHandler import langchain langchain.set_handler(StreamlitCallbackHandler()) ``` produces the following output ![Screen Shot 2022-12-29 at 10 50 33 PM](https://user-images.githubusercontent.com/11986836/210032762-7f53fffa-cb2f-4dac-af39-7d4cf81e55dd.png) only works for agent stuff currently --- langchain/__init__.py | 3 +- langchain/callbacks/__init__.py | 8 +++- langchain/callbacks/base.py | 8 ++++ langchain/callbacks/shared.py | 5 +++ langchain/callbacks/streamlit.py | 72 ++++++++++++++++++++++++++++++++ 5 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 langchain/callbacks/streamlit.py diff --git a/langchain/__init__.py b/langchain/__init__.py index 4a0c81555fd..02ab91ad9eb 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -4,7 +4,7 @@ from typing import Optional from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.cache import BaseCache -from langchain.callbacks import set_default_callback_manager +from langchain.callbacks import set_default_callback_manager, set_handler from langchain.chains import ( ConversationChain, LLMBashChain, @@ -63,4 +63,5 @@ __all__ = [ "VectorDBQAWithSourcesChain", "QAWithSourcesChain", "PALChain", + "set_handler", ] diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py index 9c9fb210232..959bb1c40a7 100644 --- a/langchain/callbacks/__init__.py +++ b/langchain/callbacks/__init__.py @@ -1,5 +1,5 @@ """Callback handlers that allow listening to events in LangChain.""" -from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager from langchain.callbacks.shared import SharedCallbackManager from langchain.callbacks.stdout import StdOutCallbackHandler @@ -13,3 +13,9 @@ def set_default_callback_manager() -> None: """Set default callback manager.""" callback = get_callback_manager() callback.add_handler(StdOutCallbackHandler()) + + +def set_handler(handler: BaseCallbackHandler) -> None: + """Set handler.""" + callback = get_callback_manager() + callback.set_handler(handler) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index b8709cce194..bcf48eb8693 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -70,6 +70,10 @@ class BaseCallbackManager(BaseCallbackHandler, ABC): def remove_handler(self, handler: BaseCallbackHandler) -> None: """Remove a handler from the callback manager.""" + @abstractmethod + def set_handler(self, handler: BaseCallbackHandler) -> None: + """Set handler as the only handler on the callback manager.""" + class CallbackManager(BaseCallbackManager): """Callback manager that can be used to handle callbacks from langchain.""" @@ -144,3 +148,7 @@ class CallbackManager(BaseCallbackManager): def remove_handler(self, handler: BaseCallbackHandler) -> None: """Remove a handler from the callback manager.""" self.handlers.remove(handler) + + def set_handler(self, handler: BaseCallbackHandler) -> None: + """Set handler as the only handler on the callback manager.""" + self.handlers = [handler] diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py index e255bd34417..763b1de15bb 100644 --- a/langchain/callbacks/shared.py +++ b/langchain/callbacks/shared.py @@ -102,3 +102,8 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): """Remove a callback from the callback manager.""" with self._lock: self._callback_manager.remove_handler(callback) + + def set_handler(self, handler: BaseCallbackHandler) -> None: + """Set handler as the only handler on the callback manager.""" + with self._lock: + self._callback_manager.handlers = [handler] diff --git a/langchain/callbacks/streamlit.py b/langchain/callbacks/streamlit.py new file mode 100644 index 00000000000..9079b45cace --- /dev/null +++ b/langchain/callbacks/streamlit.py @@ -0,0 +1,72 @@ +"""Callback Handler that logs to streamlit.""" +from typing import Any, Dict, List, Optional + +import streamlit as st + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, LLMResult + + +class StreamlitCallbackHandler(BaseCallbackHandler): + """Callback Handler that logs to streamlit.""" + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Print out the prompts.""" + st.write("Prompts after formatting:") + for prompt in prompts: + st.write(prompt) + + def on_llm_end(self, response: LLMResult) -> None: + """Do nothing.""" + pass + + def on_llm_error(self, error: Exception) -> None: + """Do nothing.""" + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Print out that we are entering a chain.""" + class_name = serialized["name"] + st.write(f"Entering new {class_name} chain...") + + def on_chain_end(self, outputs: Dict[str, Any]) -> None: + """Print out that we finished a chain.""" + st.write("Finished chain.") + + def on_chain_error(self, error: Exception) -> None: + """Do nothing.""" + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + action: AgentAction, + **kwargs: Any, + ) -> None: + """Print out the log in specified color.""" + # st.write requires two spaces before a newline to render it + st.markdown(action.log.replace("\n", " \n")) + + def on_tool_end( + self, + output: str, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """If not the final action, print out observation.""" + st.write(f"{observation_prefix}{output}") + st.write(llm_prefix) + + def on_tool_error(self, error: Exception) -> None: + """Do nothing.""" + pass + + def on_text(self, text: str, **kwargs: Any) -> None: + """Run on text.""" + # st.write requires two spaces before a newline to render it + st.write(text.replace("\n", " \n"))