Harrison/streamlit handler (#488)

also add a set handler method

usage is:
```
from langchain.callbacks.streamlit import StreamlitCallbackHandler
import langchain
langchain.set_handler(StreamlitCallbackHandler())
```

produces the following output


![Screen Shot 2022-12-29 at 10 50 33
PM](https://user-images.githubusercontent.com/11986836/210032762-7f53fffa-cb2f-4dac-af39-7d4cf81e55dd.png)

only works for agent stuff currently
This commit is contained in:
Harrison Chase 2022-12-30 14:43:28 -05:00 committed by GitHub
parent 45d6de177e
commit a3d2a2ec2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 94 additions and 2 deletions

View File

@ -4,7 +4,7 @@ from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache from langchain.cache import BaseCache
from langchain.callbacks import set_default_callback_manager from langchain.callbacks import set_default_callback_manager, set_handler
from langchain.chains import ( from langchain.chains import (
ConversationChain, ConversationChain,
LLMBashChain, LLMBashChain,
@ -63,4 +63,5 @@ __all__ = [
"VectorDBQAWithSourcesChain", "VectorDBQAWithSourcesChain",
"QAWithSourcesChain", "QAWithSourcesChain",
"PALChain", "PALChain",
"set_handler",
] ]

View File

@ -1,5 +1,5 @@
"""Callback handlers that allow listening to events in LangChain.""" """Callback handlers that allow listening to events in LangChain."""
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
from langchain.callbacks.shared import SharedCallbackManager from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler
@ -13,3 +13,9 @@ def set_default_callback_manager() -> None:
"""Set default callback manager.""" """Set default callback manager."""
callback = get_callback_manager() callback = get_callback_manager()
callback.add_handler(StdOutCallbackHandler()) callback.add_handler(StdOutCallbackHandler())
def set_handler(handler: BaseCallbackHandler) -> None:
"""Set handler."""
callback = get_callback_manager()
callback.set_handler(handler)

View File

@ -70,6 +70,10 @@ class BaseCallbackManager(BaseCallbackHandler, ABC):
def remove_handler(self, handler: BaseCallbackHandler) -> None: def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager.""" """Remove a handler from the callback manager."""
@abstractmethod
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
class CallbackManager(BaseCallbackManager): class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain.""" """Callback manager that can be used to handle callbacks from langchain."""
@ -144,3 +148,7 @@ class CallbackManager(BaseCallbackManager):
def remove_handler(self, handler: BaseCallbackHandler) -> None: def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager.""" """Remove a handler from the callback manager."""
self.handlers.remove(handler) self.handlers.remove(handler)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
self.handlers = [handler]

View File

@ -102,3 +102,8 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
"""Remove a callback from the callback manager.""" """Remove a callback from the callback manager."""
with self._lock: with self._lock:
self._callback_manager.remove_handler(callback) self._callback_manager.remove_handler(callback)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
with self._lock:
self._callback_manager.handlers = [handler]

View File

@ -0,0 +1,72 @@
"""Callback Handler that logs to streamlit."""
from typing import Any, Dict, List, Optional
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult
class StreamlitCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to streamlit."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
st.write("Prompts after formatting:")
for prompt in prompts:
st.write(prompt)
def on_llm_end(self, response: LLMResult) -> None:
"""Do nothing."""
pass
def on_llm_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
st.write(f"Entering new {class_name} chain...")
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Print out that we finished a chain."""
st.write("Finished chain.")
def on_chain_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
# st.write requires two spaces before a newline to render it
st.markdown(action.log.replace("\n", " \n"))
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
st.write(f"{observation_prefix}{output}")
st.write(llm_prefix)
def on_tool_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on text."""
# st.write requires two spaces before a newline to render it
st.write(text.replace("\n", " \n"))