mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
StreamlitCallbackHandler (#6315)
A new implementation of `StreamlitCallbackHandler`. It formats Agent thoughts into Streamlit expanders. You can see the handler in action here: https://langchain-mrkl.streamlit.app/ Per a discussion with Harrison, we'll be adding a `StreamlitCallbackHandler` implementation to an upcoming [Streamlit](https://github.com/streamlit/streamlit) release as well, and will be updating it as we add new LLM- and LangChain-specific features to Streamlit. The idea with this PR is that the LangChain `StreamlitCallbackHandler` will "auto-update" in a way that keeps it forward- (and backward-) compatible with Streamlit. If the user has an older Streamlit version installed, the LangChain `StreamlitCallbackHandler` will be used; if they have a newer Streamlit version that has an updated `StreamlitCallbackHandler`, that implementation will be used instead. (I'm opening this as a draft to get the conversation going and make sure we're on the same page. We're really excited to land this into LangChain!) #### Who can review? @agola11, @hwchase17
This commit is contained in:
parent
74ac6fb6b9
commit
c28990d871
@ -21,9 +21,7 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
|||||||
from langchain.callbacks.streaming_stdout_final_only import (
|
from langchain.callbacks.streaming_stdout_final_only import (
|
||||||
FinalStreamingStdOutCallbackHandler,
|
FinalStreamingStdOutCallbackHandler,
|
||||||
)
|
)
|
||||||
|
from langchain.callbacks.streamlit import LLMThoughtLabeler, StreamlitCallbackHandler
|
||||||
# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out here.
|
|
||||||
# from langchain.callbacks.streamlit import StreamlitCallbackHandler
|
|
||||||
from langchain.callbacks.wandb_callback import WandbCallbackHandler
|
from langchain.callbacks.wandb_callback import WandbCallbackHandler
|
||||||
from langchain.callbacks.whylabs_callback import WhyLabsCallbackHandler
|
from langchain.callbacks.whylabs_callback import WhyLabsCallbackHandler
|
||||||
|
|
||||||
@ -42,8 +40,8 @@ __all__ = [
|
|||||||
"OpenAICallbackHandler",
|
"OpenAICallbackHandler",
|
||||||
"StdOutCallbackHandler",
|
"StdOutCallbackHandler",
|
||||||
"StreamingStdOutCallbackHandler",
|
"StreamingStdOutCallbackHandler",
|
||||||
# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out here.
|
"StreamlitCallbackHandler",
|
||||||
# "StreamlitCallbackHandler",
|
"LLMThoughtLabeler",
|
||||||
"WandbCallbackHandler",
|
"WandbCallbackHandler",
|
||||||
"WhyLabsCallbackHandler",
|
"WhyLabsCallbackHandler",
|
||||||
"get_openai_callback",
|
"get_openai_callback",
|
||||||
|
@ -1,104 +0,0 @@
|
|||||||
"""Callback Handler that logs to streamlit."""
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
|
||||||
|
|
||||||
|
|
||||||
class StreamlitCallbackHandler(BaseCallbackHandler):
|
|
||||||
"""Callback Handler that logs to streamlit."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
try:
|
|
||||||
import streamlit as st
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Could not import streamlit Python package. "
|
|
||||||
"Please install it with `pip install streamlit`."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
self.tokens_area = st.empty()
|
|
||||||
self.tokens_stream = ""
|
|
||||||
self.st = st
|
|
||||||
|
|
||||||
def on_llm_start(
|
|
||||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Print out the prompts."""
|
|
||||||
self.st.write("Prompts after formatting:")
|
|
||||||
for prompt in prompts:
|
|
||||||
self.st.write(prompt)
|
|
||||||
|
|
||||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
|
||||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
|
||||||
self.tokens_stream += token
|
|
||||||
self.tokens_area.write(self.tokens_stream)
|
|
||||||
|
|
||||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
|
||||||
"""Do nothing."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_llm_error(
|
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_chain_start(
|
|
||||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Print out that we are entering a chain."""
|
|
||||||
class_name = serialized["name"]
|
|
||||||
self.st.write(f"Entering new {class_name} chain...")
|
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
|
||||||
"""Print out that we finished a chain."""
|
|
||||||
self.st.write("Finished chain.")
|
|
||||||
|
|
||||||
def on_chain_error(
|
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_tool_start(
|
|
||||||
self,
|
|
||||||
serialized: Dict[str, Any],
|
|
||||||
input_str: str,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
"""Print out the log in specified color."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
|
||||||
"""Run on agent action."""
|
|
||||||
# st.write requires two spaces before a newline to render it
|
|
||||||
|
|
||||||
self.st.markdown(action.log.replace("\n", " \n"))
|
|
||||||
|
|
||||||
def on_tool_end(
|
|
||||||
self,
|
|
||||||
output: str,
|
|
||||||
observation_prefix: Optional[str] = None,
|
|
||||||
llm_prefix: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
"""If not the final action, print out observation."""
|
|
||||||
self.st.write(f"{observation_prefix}{output}")
|
|
||||||
self.st.write(llm_prefix)
|
|
||||||
|
|
||||||
def on_tool_error(
|
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
|
||||||
"""Run on text."""
|
|
||||||
# st.write requires two spaces before a newline to render it
|
|
||||||
self.st.write(text.replace("\n", " \n"))
|
|
||||||
|
|
||||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
|
||||||
"""Run on agent end."""
|
|
||||||
# st.write requires two spaces before a newline to render it
|
|
||||||
self.st.write(finish.log.replace("\n", " \n"))
|
|
79
langchain/callbacks/streamlit/__init__.py
Normal file
79
langchain/callbacks/streamlit/__init__.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
|
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
||||||
|
LLMThoughtLabeler as LLMThoughtLabeler,
|
||||||
|
)
|
||||||
|
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
||||||
|
StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from streamlit.delta_generator import DeltaGenerator
|
||||||
|
|
||||||
|
|
||||||
|
def StreamlitCallbackHandler(
|
||||||
|
parent_container: DeltaGenerator,
|
||||||
|
*,
|
||||||
|
max_thought_containers: int = 4,
|
||||||
|
expand_new_thoughts: bool = True,
|
||||||
|
collapse_completed_thoughts: bool = True,
|
||||||
|
thought_labeler: Optional[LLMThoughtLabeler] = None,
|
||||||
|
) -> BaseCallbackHandler:
|
||||||
|
"""Construct a new StreamlitCallbackHandler. This CallbackHandler is geared towards
|
||||||
|
use with a LangChain Agent; it displays the Agent's LLM and tool-usage "thoughts"
|
||||||
|
inside a series of Streamlit expanders.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
parent_container
|
||||||
|
The `st.container` that will contain all the Streamlit elements that the
|
||||||
|
Handler creates.
|
||||||
|
max_thought_containers
|
||||||
|
The max number of completed LLM thought containers to show at once. When this
|
||||||
|
threshold is reached, a new thought will cause the oldest thoughts to be
|
||||||
|
collapsed into a "History" expander. Defaults to 4.
|
||||||
|
expand_new_thoughts
|
||||||
|
Each LLM "thought" gets its own `st.expander`. This param controls whether that
|
||||||
|
expander is expanded by default. Defaults to True.
|
||||||
|
collapse_completed_thoughts
|
||||||
|
If True, LLM thought expanders will be collapsed when completed.
|
||||||
|
Defaults to True.
|
||||||
|
thought_labeler
|
||||||
|
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
|
||||||
|
will use the default thought labeling logic. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A new StreamlitCallbackHandler instance.
|
||||||
|
|
||||||
|
Note that this is an "auto-updating" API: if the installed version of Streamlit
|
||||||
|
has a more recent StreamlitCallbackHandler implementation, an instance of that class
|
||||||
|
will be used.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# If we're using a version of Streamlit that implements StreamlitCallbackHandler,
|
||||||
|
# delegate to it instead of using our built-in handler. The official handler is
|
||||||
|
# guaranteed to support the same set of kwargs.
|
||||||
|
try:
|
||||||
|
from streamlit.external.langchain import (
|
||||||
|
StreamlitCallbackHandler as OfficialStreamlitCallbackHandler, # type: ignore # noqa: 501
|
||||||
|
)
|
||||||
|
|
||||||
|
return OfficialStreamlitCallbackHandler(
|
||||||
|
parent_container,
|
||||||
|
max_thought_containers=max_thought_containers,
|
||||||
|
expand_new_thoughts=expand_new_thoughts,
|
||||||
|
collapse_completed_thoughts=collapse_completed_thoughts,
|
||||||
|
thought_labeler=thought_labeler,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
return _InternalStreamlitCallbackHandler(
|
||||||
|
parent_container,
|
||||||
|
max_thought_containers=max_thought_containers,
|
||||||
|
expand_new_thoughts=expand_new_thoughts,
|
||||||
|
collapse_completed_thoughts=collapse_completed_thoughts,
|
||||||
|
thought_labeler=thought_labeler,
|
||||||
|
)
|
152
langchain/callbacks/streamlit/mutable_expander.py
Normal file
152
langchain/callbacks/streamlit/mutable_expander.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from streamlit.delta_generator import DeltaGenerator
|
||||||
|
from streamlit.type_util import SupportsStr
|
||||||
|
|
||||||
|
|
||||||
|
class ChildType(Enum):
|
||||||
|
MARKDOWN = "MARKDOWN"
|
||||||
|
EXCEPTION = "EXCEPTION"
|
||||||
|
|
||||||
|
|
||||||
|
class ChildRecord(NamedTuple):
|
||||||
|
type: ChildType
|
||||||
|
kwargs: Dict[str, Any]
|
||||||
|
dg: DeltaGenerator
|
||||||
|
|
||||||
|
|
||||||
|
class MutableExpander:
|
||||||
|
"""A Streamlit expander that can be renamed and dynamically expanded/collapsed."""
|
||||||
|
|
||||||
|
def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
|
||||||
|
"""Create a new MutableExpander.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
parent_container
|
||||||
|
The `st.container` that the expander will be created inside.
|
||||||
|
|
||||||
|
The expander transparently deletes and recreates its underlying
|
||||||
|
`st.expander` instance when its label changes, and it uses
|
||||||
|
`parent_container` to ensure it recreates this underlying expander in the
|
||||||
|
same location onscreen.
|
||||||
|
label
|
||||||
|
The expander's initial label.
|
||||||
|
expanded
|
||||||
|
The expander's initial `expanded` value.
|
||||||
|
"""
|
||||||
|
self._label = label
|
||||||
|
self._expanded = expanded
|
||||||
|
self._parent_cursor = parent_container.empty()
|
||||||
|
self._container = self._parent_cursor.expander(label, expanded)
|
||||||
|
self._child_records: List[ChildRecord] = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def label(self) -> str:
|
||||||
|
"""The expander's label string."""
|
||||||
|
return self._label
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expanded(self) -> bool:
|
||||||
|
"""True if the expander was created with `expanded=True`."""
|
||||||
|
return self._expanded
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Remove the container and its contents entirely. A cleared container can't
|
||||||
|
be reused.
|
||||||
|
"""
|
||||||
|
self._container = self._parent_cursor.empty()
|
||||||
|
self._child_records.clear()
|
||||||
|
|
||||||
|
def append_copy(self, other: MutableExpander) -> None:
|
||||||
|
"""Append a copy of another MutableExpander's children to this
|
||||||
|
MutableExpander.
|
||||||
|
"""
|
||||||
|
other_records = other._child_records.copy()
|
||||||
|
for record in other_records:
|
||||||
|
self._create_child(record.type, record.kwargs)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
|
||||||
|
) -> None:
|
||||||
|
"""Change the expander's label and expanded state"""
|
||||||
|
if new_label is None:
|
||||||
|
new_label = self._label
|
||||||
|
if new_expanded is None:
|
||||||
|
new_expanded = self._expanded
|
||||||
|
|
||||||
|
if self._label == new_label and self._expanded == new_expanded:
|
||||||
|
# No change!
|
||||||
|
return
|
||||||
|
|
||||||
|
self._label = new_label
|
||||||
|
self._expanded = new_expanded
|
||||||
|
self._container = self._parent_cursor.expander(new_label, new_expanded)
|
||||||
|
|
||||||
|
prev_records = self._child_records
|
||||||
|
self._child_records = []
|
||||||
|
|
||||||
|
# Replay all children into the new container
|
||||||
|
for record in prev_records:
|
||||||
|
self._create_child(record.type, record.kwargs)
|
||||||
|
|
||||||
|
def markdown(
|
||||||
|
self,
|
||||||
|
body: SupportsStr,
|
||||||
|
unsafe_allow_html: bool = False,
|
||||||
|
*,
|
||||||
|
help: Optional[str] = None,
|
||||||
|
index: Optional[int] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Add a Markdown element to the container and return its index."""
|
||||||
|
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
|
||||||
|
new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type]
|
||||||
|
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
|
||||||
|
return self._add_record(record, index)
|
||||||
|
|
||||||
|
def exception(
|
||||||
|
self, exception: BaseException, *, index: Optional[int] = None
|
||||||
|
) -> int:
|
||||||
|
"""Add an Exception element to the container and return its index."""
|
||||||
|
kwargs = {"exception": exception}
|
||||||
|
new_dg = self._get_dg(index).exception(**kwargs)
|
||||||
|
record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg)
|
||||||
|
return self._add_record(record, index)
|
||||||
|
|
||||||
|
def _create_child(self, type: ChildType, kwargs: Dict[str, Any]) -> None:
|
||||||
|
"""Create a new child with the given params"""
|
||||||
|
if type == ChildType.MARKDOWN:
|
||||||
|
self.markdown(**kwargs)
|
||||||
|
elif type == ChildType.EXCEPTION:
|
||||||
|
self.exception(**kwargs)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unexpected child type {type}")
|
||||||
|
|
||||||
|
def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
|
||||||
|
"""Add a ChildRecord to self._children. If `index` is specified, replace
|
||||||
|
the existing record at that index. Otherwise, append the record to the
|
||||||
|
end of the list.
|
||||||
|
|
||||||
|
Return the index of the added record.
|
||||||
|
"""
|
||||||
|
if index is not None:
|
||||||
|
# Replace existing child
|
||||||
|
self._child_records[index] = record
|
||||||
|
return index
|
||||||
|
|
||||||
|
# Append new child
|
||||||
|
self._child_records.append(record)
|
||||||
|
return len(self._child_records) - 1
|
||||||
|
|
||||||
|
def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
|
||||||
|
if index is not None:
|
||||||
|
# Existing index: reuse child's DeltaGenerator
|
||||||
|
assert 0 <= index < len(self._child_records), f"Bad index: {index}"
|
||||||
|
return self._child_records[index].dg
|
||||||
|
|
||||||
|
# No index: use container's DeltaGenerator
|
||||||
|
return self._container
|
406
langchain/callbacks/streamlit/streamlit_callback_handler.py
Normal file
406
langchain/callbacks/streamlit/streamlit_callback_handler.py
Normal file
@ -0,0 +1,406 @@
|
|||||||
|
"""Callback Handler that prints to streamlit."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
||||||
|
|
||||||
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
|
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from streamlit.delta_generator import DeltaGenerator
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_newlines(text: str) -> str:
|
||||||
|
"""Convert newline characters to markdown newline sequences
|
||||||
|
(space, space, newline).
|
||||||
|
"""
|
||||||
|
return text.replace("\n", " \n")
|
||||||
|
|
||||||
|
|
||||||
|
CHECKMARK_EMOJI = "✅"
|
||||||
|
THINKING_EMOJI = ":thinking_face:"
|
||||||
|
HISTORY_EMOJI = ":books:"
|
||||||
|
EXCEPTION_EMOJI = "⚠️"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMThoughtState(Enum):
|
||||||
|
# The LLM is thinking about what to do next. We don't know which tool we'll run.
|
||||||
|
THINKING = "THINKING"
|
||||||
|
# The LLM has decided to run a tool. We don't have results from the tool yet.
|
||||||
|
RUNNING_TOOL = "RUNNING_TOOL"
|
||||||
|
# We have results from the tool.
|
||||||
|
COMPLETE = "COMPLETE"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRecord(NamedTuple):
|
||||||
|
name: str
|
||||||
|
input_str: str
|
||||||
|
|
||||||
|
|
||||||
|
class LLMThoughtLabeler:
|
||||||
|
"""
|
||||||
|
Generates markdown labels for LLMThought containers. Pass a custom
|
||||||
|
subclass of this to StreamlitCallbackHandler to override its default
|
||||||
|
labeling logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_initial_label(self) -> str:
|
||||||
|
"""Return the markdown label for a new LLMThought that doesn't have
|
||||||
|
an associated tool yet.
|
||||||
|
"""
|
||||||
|
return f"{THINKING_EMOJI} **Thinking...**"
|
||||||
|
|
||||||
|
def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
|
||||||
|
"""Return the label for an LLMThought that has an associated
|
||||||
|
tool.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tool
|
||||||
|
The tool's ToolRecord
|
||||||
|
|
||||||
|
is_complete
|
||||||
|
True if the thought is complete; False if the thought
|
||||||
|
is still receiving input.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
The markdown label for the thought's container.
|
||||||
|
|
||||||
|
"""
|
||||||
|
input = tool.input_str
|
||||||
|
name = tool.name
|
||||||
|
emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
|
||||||
|
if name == "_Exception":
|
||||||
|
emoji = EXCEPTION_EMOJI
|
||||||
|
name = "Parsing error"
|
||||||
|
idx = min([60, len(input)])
|
||||||
|
input = input[0:idx]
|
||||||
|
if len(tool.input_str) > idx:
|
||||||
|
input = input + "..."
|
||||||
|
input = input.replace("\n", " ")
|
||||||
|
label = f"{emoji} **{name}:** {input}"
|
||||||
|
return label
|
||||||
|
|
||||||
|
def get_history_label(self) -> str:
|
||||||
|
"""Return a markdown label for the special 'history' container
|
||||||
|
that contains overflow thoughts.
|
||||||
|
"""
|
||||||
|
return f"{HISTORY_EMOJI} **History**"
|
||||||
|
|
||||||
|
def get_final_agent_thought_label(self) -> str:
|
||||||
|
"""Return the markdown label for the agent's final thought -
|
||||||
|
the "Now I have the answer" thought, that doesn't involve
|
||||||
|
a tool.
|
||||||
|
"""
|
||||||
|
return f"{CHECKMARK_EMOJI} **Complete!**"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMThought:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent_container: DeltaGenerator,
|
||||||
|
labeler: LLMThoughtLabeler,
|
||||||
|
expanded: bool,
|
||||||
|
collapse_on_complete: bool,
|
||||||
|
):
|
||||||
|
self._container = MutableExpander(
|
||||||
|
parent_container=parent_container,
|
||||||
|
label=labeler.get_initial_label(),
|
||||||
|
expanded=expanded,
|
||||||
|
)
|
||||||
|
self._state = LLMThoughtState.THINKING
|
||||||
|
self._llm_token_stream = ""
|
||||||
|
self._llm_token_writer_idx: Optional[int] = None
|
||||||
|
self._last_tool: Optional[ToolRecord] = None
|
||||||
|
self._collapse_on_complete = collapse_on_complete
|
||||||
|
self._labeler = labeler
|
||||||
|
|
||||||
|
@property
|
||||||
|
def container(self) -> MutableExpander:
|
||||||
|
"""The container we're writing into."""
|
||||||
|
return self._container
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_tool(self) -> Optional[ToolRecord]:
|
||||||
|
"""The last tool executed by this thought"""
|
||||||
|
return self._last_tool
|
||||||
|
|
||||||
|
def _reset_llm_token_stream(self) -> None:
|
||||||
|
self._llm_token_stream = ""
|
||||||
|
self._llm_token_writer_idx = None
|
||||||
|
|
||||||
|
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
|
||||||
|
self._reset_llm_token_stream()
|
||||||
|
|
||||||
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||||
|
# This is only called when the LLM is initialized with `streaming=True`
|
||||||
|
self._llm_token_stream += _convert_newlines(token)
|
||||||
|
self._llm_token_writer_idx = self._container.markdown(
|
||||||
|
self._llm_token_stream, index=self._llm_token_writer_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
|
# `response` is the concatenation of all the tokens received by the LLM.
|
||||||
|
# If we're receiving streaming tokens from `on_llm_new_token`, this response
|
||||||
|
# data is redundant
|
||||||
|
self._reset_llm_token_stream()
|
||||||
|
|
||||||
|
def on_llm_error(
|
||||||
|
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
self._container.markdown("**LLM encountered an error...**")
|
||||||
|
self._container.exception(error)
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
# Called with the name of the tool we're about to run (in `serialized[name]`),
|
||||||
|
# and its input. We change our container's label to be the tool name.
|
||||||
|
self._state = LLMThoughtState.RUNNING_TOOL
|
||||||
|
tool_name = serialized["name"]
|
||||||
|
self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
|
||||||
|
self._container.update(
|
||||||
|
new_label=self._labeler.get_tool_label(self._last_tool, is_complete=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_tool_end(
|
||||||
|
self,
|
||||||
|
output: str,
|
||||||
|
color: Optional[str] = None,
|
||||||
|
observation_prefix: Optional[str] = None,
|
||||||
|
llm_prefix: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._container.markdown(f"**{output}**")
|
||||||
|
|
||||||
|
def on_tool_error(
|
||||||
|
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
self._container.markdown("**Tool encountered an error...**")
|
||||||
|
self._container.exception(error)
|
||||||
|
|
||||||
|
def on_agent_action(
|
||||||
|
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
# Called when we're about to kick off a new tool. The `action` data
|
||||||
|
# tells us the tool we're about to use, and the input we'll give it.
|
||||||
|
# We don't output anything here, because we'll receive this same data
|
||||||
|
# when `on_tool_start` is called immediately after.
|
||||||
|
pass
|
||||||
|
|
||||||
|
def complete(self, final_label: Optional[str] = None) -> None:
|
||||||
|
"""Finish the thought."""
|
||||||
|
if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
|
||||||
|
assert (
|
||||||
|
self._last_tool is not None
|
||||||
|
), "_last_tool should never be null when _state == RUNNING_TOOL"
|
||||||
|
final_label = self._labeler.get_tool_label(
|
||||||
|
self._last_tool, is_complete=True
|
||||||
|
)
|
||||||
|
self._state = LLMThoughtState.COMPLETE
|
||||||
|
if self._collapse_on_complete:
|
||||||
|
self._container.update(new_label=final_label, new_expanded=False)
|
||||||
|
else:
|
||||||
|
self._container.update(new_label=final_label)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Remove the thought from the screen. A cleared thought can't be reused."""
|
||||||
|
self._container.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent_container: DeltaGenerator,
|
||||||
|
*,
|
||||||
|
max_thought_containers: int = 4,
|
||||||
|
expand_new_thoughts: bool = True,
|
||||||
|
collapse_completed_thoughts: bool = True,
|
||||||
|
thought_labeler: Optional[LLMThoughtLabeler] = None,
|
||||||
|
):
|
||||||
|
"""Create a StreamlitCallbackHandler instance.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
parent_container
|
||||||
|
The `st.container` that will contain all the Streamlit elements that the
|
||||||
|
Handler creates.
|
||||||
|
max_thought_containers
|
||||||
|
The max number of completed LLM thought containers to show at once. When
|
||||||
|
this threshold is reached, a new thought will cause the oldest thoughts to
|
||||||
|
be collapsed into a "History" expander. Defaults to 4.
|
||||||
|
expand_new_thoughts
|
||||||
|
Each LLM "thought" gets its own `st.expander`. This param controls whether
|
||||||
|
that expander is expanded by default. Defaults to True.
|
||||||
|
collapse_completed_thoughts
|
||||||
|
If True, LLM thought expanders will be collapsed when completed.
|
||||||
|
Defaults to True.
|
||||||
|
thought_labeler
|
||||||
|
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
|
||||||
|
will use the default thought labeling logic. Defaults to None.
|
||||||
|
"""
|
||||||
|
self._parent_container = parent_container
|
||||||
|
self._history_parent = parent_container.container()
|
||||||
|
self._history_container: Optional[MutableExpander] = None
|
||||||
|
self._current_thought: Optional[LLMThought] = None
|
||||||
|
self._completed_thoughts: List[LLMThought] = []
|
||||||
|
self._max_thought_containers = max(max_thought_containers, 1)
|
||||||
|
self._expand_new_thoughts = expand_new_thoughts
|
||||||
|
self._collapse_completed_thoughts = collapse_completed_thoughts
|
||||||
|
self._thought_labeler = thought_labeler or LLMThoughtLabeler()
|
||||||
|
|
||||||
|
def _require_current_thought(self) -> LLMThought:
|
||||||
|
"""Return our current LLMThought. Raise an error if we have no current
|
||||||
|
thought.
|
||||||
|
"""
|
||||||
|
if self._current_thought is None:
|
||||||
|
raise RuntimeError("Current LLMThought is unexpectedly None!")
|
||||||
|
return self._current_thought
|
||||||
|
|
||||||
|
def _get_last_completed_thought(self) -> Optional[LLMThought]:
|
||||||
|
"""Return our most recent completed LLMThought, or None if we don't have one."""
|
||||||
|
if len(self._completed_thoughts) > 0:
|
||||||
|
return self._completed_thoughts[len(self._completed_thoughts) - 1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _num_thought_containers(self) -> int:
|
||||||
|
"""The number of 'thought containers' we're currently showing: the
|
||||||
|
number of completed thought containers, the history container (if it exists),
|
||||||
|
and the current thought container (if it exists).
|
||||||
|
"""
|
||||||
|
count = len(self._completed_thoughts)
|
||||||
|
if self._history_container is not None:
|
||||||
|
count += 1
|
||||||
|
if self._current_thought is not None:
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
|
||||||
|
"""Complete the current thought, optionally assigning it a new label.
|
||||||
|
Add it to our _completed_thoughts list.
|
||||||
|
"""
|
||||||
|
thought = self._require_current_thought()
|
||||||
|
thought.complete(final_label)
|
||||||
|
self._completed_thoughts.append(thought)
|
||||||
|
self._current_thought = None
|
||||||
|
|
||||||
|
def _prune_old_thought_containers(self) -> None:
|
||||||
|
"""If we have too many thoughts onscreen, move older thoughts to the
|
||||||
|
'history container.'
|
||||||
|
"""
|
||||||
|
while (
|
||||||
|
self._num_thought_containers > self._max_thought_containers
|
||||||
|
and len(self._completed_thoughts) > 0
|
||||||
|
):
|
||||||
|
# Create our history container if it doesn't exist, and if
|
||||||
|
# max_thought_containers is > 1. (if max_thought_containers is 1, we don't
|
||||||
|
# have room to show history.)
|
||||||
|
if self._history_container is None and self._max_thought_containers > 1:
|
||||||
|
self._history_container = MutableExpander(
|
||||||
|
self._history_parent,
|
||||||
|
label=self._thought_labeler.get_history_label(),
|
||||||
|
expanded=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
oldest_thought = self._completed_thoughts.pop(0)
|
||||||
|
if self._history_container is not None:
|
||||||
|
self._history_container.markdown(oldest_thought.container.label)
|
||||||
|
self._history_container.append_copy(oldest_thought.container)
|
||||||
|
oldest_thought.clear()
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
if self._current_thought is None:
|
||||||
|
self._current_thought = LLMThought(
|
||||||
|
parent_container=self._parent_container,
|
||||||
|
expanded=self._expand_new_thoughts,
|
||||||
|
collapse_on_complete=self._collapse_completed_thoughts,
|
||||||
|
labeler=self._thought_labeler,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._current_thought.on_llm_start(serialized, prompts)
|
||||||
|
|
||||||
|
# We don't prune_old_thought_containers here, because our container won't
|
||||||
|
# be visible until it has a child.
|
||||||
|
|
||||||
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||||
|
self._require_current_thought().on_llm_new_token(token, **kwargs)
|
||||||
|
self._prune_old_thought_containers()
|
||||||
|
|
||||||
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
|
self._require_current_thought().on_llm_end(response, **kwargs)
|
||||||
|
self._prune_old_thought_containers()
|
||||||
|
|
||||||
|
def on_llm_error(
|
||||||
|
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
self._require_current_thought().on_llm_error(error, **kwargs)
|
||||||
|
self._prune_old_thought_containers()
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
|
||||||
|
self._prune_old_thought_containers()
|
||||||
|
|
||||||
|
def on_tool_end(
|
||||||
|
self,
|
||||||
|
output: str,
|
||||||
|
color: Optional[str] = None,
|
||||||
|
observation_prefix: Optional[str] = None,
|
||||||
|
llm_prefix: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._require_current_thought().on_tool_end(
|
||||||
|
output, color, observation_prefix, llm_prefix, **kwargs
|
||||||
|
)
|
||||||
|
self._complete_current_thought()
|
||||||
|
|
||||||
|
def on_tool_error(
|
||||||
|
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
self._require_current_thought().on_tool_error(error, **kwargs)
|
||||||
|
self._prune_old_thought_containers()
|
||||||
|
|
||||||
|
def on_text(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
color: Optional[str] = None,
|
||||||
|
end: str = "",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_chain_start(
|
||||||
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_chain_error(
|
||||||
|
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_agent_action(
|
||||||
|
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
self._require_current_thought().on_agent_action(action, color, **kwargs)
|
||||||
|
self._prune_old_thought_containers()
|
||||||
|
|
||||||
|
def on_agent_finish(
|
||||||
|
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
if self._current_thought is not None:
|
||||||
|
self._current_thought.complete(
|
||||||
|
self._thought_labeler.get_final_agent_thought_label()
|
||||||
|
)
|
||||||
|
self._current_thought = None
|
703
poetry.lock
generated
703
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -110,8 +110,7 @@ langchainplus-sdk = ">=0.0.13"
|
|||||||
awadb = {version = "^0.3.3", optional = true}
|
awadb = {version = "^0.3.3", optional = true}
|
||||||
azure-search-documents = {version = "11.4.0a20230509004", source = "azure-sdk-dev", optional = true}
|
azure-search-documents = {version = "11.4.0a20230509004", source = "azure-sdk-dev", optional = true}
|
||||||
openllm = {version = ">=0.1.6", optional = true}
|
openllm = {version = ">=0.1.6", optional = true}
|
||||||
# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out.
|
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
|
||||||
#streamlit = {version = "^1.18.0", optional = true}
|
|
||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
autodoc_pydantic = "^1.8.0"
|
autodoc_pydantic = "^1.8.0"
|
||||||
@ -332,8 +331,7 @@ extended_testing = [
|
|||||||
"html2text",
|
"html2text",
|
||||||
"py-trello",
|
"py-trello",
|
||||||
"scikit-learn",
|
"scikit-learn",
|
||||||
# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out.
|
"streamlit",
|
||||||
# "streamlit",
|
|
||||||
"pyspark",
|
"pyspark",
|
||||||
"openai"
|
"openai"
|
||||||
]
|
]
|
||||||
|
31
tests/integration_tests/callbacks/test_streamlit_callback.py
Normal file
31
tests/integration_tests/callbacks/test_streamlit_callback.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
"""Integration tests for the StreamlitCallbackHandler module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||||
|
|
||||||
|
# Import the internal StreamlitCallbackHandler from its module - and not from
|
||||||
|
# the `langchain.callbacks.streamlit` package - so that we don't end up using
|
||||||
|
# Streamlit's externally-provided callback handler.
|
||||||
|
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
||||||
|
StreamlitCallbackHandler,
|
||||||
|
)
|
||||||
|
from langchain.llms import OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("streamlit")
|
||||||
|
def test_streamlit_callback_agent() -> None:
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
streamlit_callback = StreamlitCallbackHandler(st.container())
|
||||||
|
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
tools = load_tools(["serpapi", "llm-math"], llm=llm)
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
|
)
|
||||||
|
agent.run(
|
||||||
|
"Who is Olivia Wilde's boyfriend? "
|
||||||
|
"What is his current age raised to the 0.23 power?",
|
||||||
|
callbacks=[streamlit_callback],
|
||||||
|
)
|
86
tests/unit_tests/callbacks/test_streamlit_callback.py
Normal file
86
tests/unit_tests/callbacks/test_streamlit_callback.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import builtins
|
||||||
|
import unittest
|
||||||
|
from typing import Any
|
||||||
|
from unittest import mock
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from langchain.callbacks.streamlit import StreamlitCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
class TestImport(unittest.TestCase):
|
||||||
|
"""Test the StreamlitCallbackHandler 'auto-updating' API"""
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.builtins_import = builtins.__import__
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
builtins.__import__ = self.builtins_import
|
||||||
|
|
||||||
|
@mock.patch("langchain.callbacks.streamlit._InternalStreamlitCallbackHandler")
|
||||||
|
def test_create_internal_handler(self, mock_internal_handler: Any) -> None:
|
||||||
|
"""If we're using a Streamlit that does not expose its own
|
||||||
|
StreamlitCallbackHandler, use our own implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def external_import_error(
|
||||||
|
name: str, globals: Any, locals: Any, fromlist: Any, level: int
|
||||||
|
) -> Any:
|
||||||
|
if name == "streamlit.external.langchain":
|
||||||
|
raise ImportError
|
||||||
|
return self.builtins_import(name, globals, locals, fromlist, level)
|
||||||
|
|
||||||
|
builtins.__import__ = external_import_error # type: ignore[assignment]
|
||||||
|
|
||||||
|
parent_container = MagicMock()
|
||||||
|
thought_labeler = MagicMock()
|
||||||
|
StreamlitCallbackHandler(
|
||||||
|
parent_container,
|
||||||
|
max_thought_containers=1,
|
||||||
|
expand_new_thoughts=True,
|
||||||
|
collapse_completed_thoughts=False,
|
||||||
|
thought_labeler=thought_labeler,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Our internal handler should be created
|
||||||
|
mock_internal_handler.assert_called_once_with(
|
||||||
|
parent_container,
|
||||||
|
max_thought_containers=1,
|
||||||
|
expand_new_thoughts=True,
|
||||||
|
collapse_completed_thoughts=False,
|
||||||
|
thought_labeler=thought_labeler,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_create_external_handler(self) -> None:
|
||||||
|
"""If we're using a Streamlit that *does* expose its own callback handler,
|
||||||
|
delegate to that implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
mock_streamlit_module = MagicMock()
|
||||||
|
|
||||||
|
def external_import_success(
|
||||||
|
name: str, globals: Any, locals: Any, fromlist: Any, level: int
|
||||||
|
) -> Any:
|
||||||
|
if name == "streamlit.external.langchain":
|
||||||
|
return mock_streamlit_module
|
||||||
|
return self.builtins_import(name, globals, locals, fromlist, level)
|
||||||
|
|
||||||
|
builtins.__import__ = external_import_success # type: ignore[assignment]
|
||||||
|
|
||||||
|
parent_container = MagicMock()
|
||||||
|
thought_labeler = MagicMock()
|
||||||
|
StreamlitCallbackHandler(
|
||||||
|
parent_container,
|
||||||
|
max_thought_containers=1,
|
||||||
|
expand_new_thoughts=True,
|
||||||
|
collapse_completed_thoughts=False,
|
||||||
|
thought_labeler=thought_labeler,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Streamlit's handler should be created
|
||||||
|
mock_streamlit_module.StreamlitCallbackHandler.assert_called_once_with(
|
||||||
|
parent_container,
|
||||||
|
max_thought_containers=1,
|
||||||
|
expand_new_thoughts=True,
|
||||||
|
collapse_completed_thoughts=False,
|
||||||
|
thought_labeler=thought_labeler,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user