mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-24 23:54:14 +00:00)
StreamlitCallbackHandler (#6315)
A new implementation of `StreamlitCallbackHandler`. It formats Agent thoughts into Streamlit expanders. You can see the handler in action here: https://langchain-mrkl.streamlit.app/

Per a discussion with Harrison, we'll also be adding a `StreamlitCallbackHandler` implementation to an upcoming [Streamlit](https://github.com/streamlit/streamlit) release, and we'll update it as we add new LLM- and LangChain-specific features to Streamlit. The idea with this PR is that the LangChain `StreamlitCallbackHandler` will "auto-update" in a way that keeps it forward- (and backward-) compatible with Streamlit: if the user has an older Streamlit version installed, the LangChain `StreamlitCallbackHandler` will be used; if they have a newer Streamlit version that ships an updated `StreamlitCallbackHandler`, that implementation will be used instead.

(I'm opening this as a draft to get the conversation going and make sure we're on the same page. We're really excited to land this in LangChain!)

#### Who can review?

@agola11, @hwchase17
parent 74ac6fb6b9
commit c28990d871
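For reviewers who want to try it locally, here's a minimal sketch of a Streamlit script that drives the new handler. It's modeled on the integration test at the bottom of this diff; the `serpapi`/`llm-math` tools and an `OPENAI_API_KEY` in the environment are assumptions for the demo, not part of the diff:

```python
import streamlit as st

from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import StreamlitCallbackHandler
from langchain.llms import OpenAI

# The handler renders the agent's LLM and tool-usage "thoughts" as expanders
# inside this container.
st_callback = StreamlitCallbackHandler(st.container())

llm = OpenAI(temperature=0)
tools = load_tools(["serpapi", "llm-math"], llm=llm)
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)

if prompt := st.text_input("Ask the agent a question"):
    # Passing the handler via `callbacks` streams each thought to the UI.
    agent.run(prompt, callbacks=[st_callback])
```

Run it with `streamlit run app.py`.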
langchain/callbacks/__init__.py
@@ -21,9 +21,7 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.callbacks.streaming_stdout_final_only import (
     FinalStreamingStdOutCallbackHandler,
 )
-
-# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out here.
-# from langchain.callbacks.streamlit import StreamlitCallbackHandler
+from langchain.callbacks.streamlit import LLMThoughtLabeler, StreamlitCallbackHandler
 from langchain.callbacks.wandb_callback import WandbCallbackHandler
 from langchain.callbacks.whylabs_callback import WhyLabsCallbackHandler
@@ -42,8 +40,8 @@ __all__ = [
     "OpenAICallbackHandler",
     "StdOutCallbackHandler",
     "StreamingStdOutCallbackHandler",
-    # now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out here.
-    # "StreamlitCallbackHandler",
+    "StreamlitCallbackHandler",
+    "LLMThoughtLabeler",
     "WandbCallbackHandler",
     "WhyLabsCallbackHandler",
     "get_openai_callback",
langchain/callbacks/streamlit.py (deleted, 104 lines)
@@ -1,104 +0,0 @@
"""Callback Handler that logs to streamlit."""
from typing import Any, Dict, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult


class StreamlitCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to streamlit."""

    def __init__(self) -> None:
        try:
            import streamlit as st
        except ImportError as e:
            raise ImportError(
                "Could not import streamlit Python package. "
                "Please install it with `pip install streamlit`."
            ) from e

        self.tokens_area = st.empty()
        self.tokens_stream = ""
        self.st = st

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print out the prompts."""
        self.st.write("Prompts after formatting:")
        for prompt in prompts:
            self.st.write(prompt)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.tokens_stream += token
        self.tokens_area.write(self.tokens_stream)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        class_name = serialized["name"]
        self.st.write(f"Entering new {class_name} chain...")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
        self.st.write("Finished chain.")

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Print out the log in specified color."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        # st.write requires two spaces before a newline to render it
        self.st.markdown(action.log.replace("\n", "  \n"))

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """If not the final action, print out observation."""
        self.st.write(f"{observation_prefix}{output}")
        self.st.write(llm_prefix)

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on text."""
        # st.write requires two spaces before a newline to render it
        self.st.write(text.replace("\n", "  \n"))

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""
        # st.write requires two spaces before a newline to render it
        self.st.write(finish.log.replace("\n", "  \n"))
langchain/callbacks/streamlit/__init__.py (new file, 79 lines)
@@ -0,0 +1,79 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    LLMThoughtLabeler as LLMThoughtLabeler,
)
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
)

if TYPE_CHECKING:
    from streamlit.delta_generator import DeltaGenerator


def StreamlitCallbackHandler(
    parent_container: DeltaGenerator,
    *,
    max_thought_containers: int = 4,
    expand_new_thoughts: bool = True,
    collapse_completed_thoughts: bool = True,
    thought_labeler: Optional[LLMThoughtLabeler] = None,
) -> BaseCallbackHandler:
    """Construct a new StreamlitCallbackHandler. This CallbackHandler is geared towards
    use with a LangChain Agent; it displays the Agent's LLM and tool-usage "thoughts"
    inside a series of Streamlit expanders.

    Parameters
    ----------
    parent_container
        The `st.container` that will contain all the Streamlit elements that the
        Handler creates.
    max_thought_containers
        The max number of completed LLM thought containers to show at once. When this
        threshold is reached, a new thought will cause the oldest thoughts to be
        collapsed into a "History" expander. Defaults to 4.
    expand_new_thoughts
        Each LLM "thought" gets its own `st.expander`. This param controls whether that
        expander is expanded by default. Defaults to True.
    collapse_completed_thoughts
        If True, LLM thought expanders will be collapsed when completed.
        Defaults to True.
    thought_labeler
        An optional custom LLMThoughtLabeler instance. If unspecified, the handler
        will use the default thought labeling logic. Defaults to None.

    Returns
    -------
    A new StreamlitCallbackHandler instance.

    Note that this is an "auto-updating" API: if the installed version of Streamlit
    has a more recent StreamlitCallbackHandler implementation, an instance of that
    class will be used.

    """
    # If we're using a version of Streamlit that implements StreamlitCallbackHandler,
    # delegate to it instead of using our built-in handler. The official handler is
    # guaranteed to support the same set of kwargs.
    try:
        from streamlit.external.langchain import (
            StreamlitCallbackHandler as OfficialStreamlitCallbackHandler,  # type: ignore # noqa: 501
        )

        return OfficialStreamlitCallbackHandler(
            parent_container,
            max_thought_containers=max_thought_containers,
            expand_new_thoughts=expand_new_thoughts,
            collapse_completed_thoughts=collapse_completed_thoughts,
            thought_labeler=thought_labeler,
        )
    except ImportError:
        return _InternalStreamlitCallbackHandler(
            parent_container,
            max_thought_containers=max_thought_containers,
            expand_new_thoughts=expand_new_thoughts,
            collapse_completed_thoughts=collapse_completed_thoughts,
            thought_labeler=thought_labeler,
        )
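One way to sanity-check which implementation the factory resolved to (purely illustrative; `type(handler).__module__` is plain Python, not anything added by this diff):

```python
import streamlit as st

from langchain.callbacks.streamlit import StreamlitCallbackHandler

handler = StreamlitCallbackHandler(st.container())

# On an older Streamlit this prints the langchain-internal module; on a
# Streamlit that ships `streamlit.external.langchain`, it prints that module.
st.write(type(handler).__module__)
```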
langchain/callbacks/streamlit/mutable_expander.py (new file, 152 lines)
@@ -0,0 +1,152 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional

if TYPE_CHECKING:
    from streamlit.delta_generator import DeltaGenerator
    from streamlit.type_util import SupportsStr


class ChildType(Enum):
    MARKDOWN = "MARKDOWN"
    EXCEPTION = "EXCEPTION"


class ChildRecord(NamedTuple):
    type: ChildType
    kwargs: Dict[str, Any]
    dg: DeltaGenerator


class MutableExpander:
    """A Streamlit expander that can be renamed and dynamically expanded/collapsed."""

    def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
        """Create a new MutableExpander.

        Parameters
        ----------
        parent_container
            The `st.container` that the expander will be created inside.

            The expander transparently deletes and recreates its underlying
            `st.expander` instance when its label changes, and it uses
            `parent_container` to ensure it recreates this underlying expander in the
            same location onscreen.
        label
            The expander's initial label.
        expanded
            The expander's initial `expanded` value.
        """
        self._label = label
        self._expanded = expanded
        self._parent_cursor = parent_container.empty()
        self._container = self._parent_cursor.expander(label, expanded)
        self._child_records: List[ChildRecord] = []

    @property
    def label(self) -> str:
        """The expander's label string."""
        return self._label

    @property
    def expanded(self) -> bool:
        """True if the expander was created with `expanded=True`."""
        return self._expanded

    def clear(self) -> None:
        """Remove the container and its contents entirely. A cleared container can't
        be reused.
        """
        self._container = self._parent_cursor.empty()
        self._child_records.clear()

    def append_copy(self, other: MutableExpander) -> None:
        """Append a copy of another MutableExpander's children to this
        MutableExpander.
        """
        other_records = other._child_records.copy()
        for record in other_records:
            self._create_child(record.type, record.kwargs)

    def update(
        self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
    ) -> None:
        """Change the expander's label and expanded state"""
        if new_label is None:
            new_label = self._label
        if new_expanded is None:
            new_expanded = self._expanded

        if self._label == new_label and self._expanded == new_expanded:
            # No change!
            return

        self._label = new_label
        self._expanded = new_expanded
        self._container = self._parent_cursor.expander(new_label, new_expanded)

        prev_records = self._child_records
        self._child_records = []

        # Replay all children into the new container
        for record in prev_records:
            self._create_child(record.type, record.kwargs)

    def markdown(
        self,
        body: SupportsStr,
        unsafe_allow_html: bool = False,
        *,
        help: Optional[str] = None,
        index: Optional[int] = None,
    ) -> int:
        """Add a Markdown element to the container and return its index."""
        kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
        new_dg = self._get_dg(index).markdown(**kwargs)  # type: ignore[arg-type]
        record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
        return self._add_record(record, index)

    def exception(
        self, exception: BaseException, *, index: Optional[int] = None
    ) -> int:
        """Add an Exception element to the container and return its index."""
        kwargs = {"exception": exception}
        new_dg = self._get_dg(index).exception(**kwargs)
        record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg)
        return self._add_record(record, index)

    def _create_child(self, type: ChildType, kwargs: Dict[str, Any]) -> None:
        """Create a new child with the given params"""
        if type == ChildType.MARKDOWN:
            self.markdown(**kwargs)
        elif type == ChildType.EXCEPTION:
            self.exception(**kwargs)
        else:
            raise RuntimeError(f"Unexpected child type {type}")

    def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
        """Add a ChildRecord to self._children. If `index` is specified, replace
        the existing record at that index. Otherwise, append the record to the
        end of the list.

        Return the index of the added record.
        """
        if index is not None:
            # Replace existing child
            self._child_records[index] = record
            return index

        # Append new child
        self._child_records.append(record)
        return len(self._child_records) - 1

    def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
        if index is not None:
            # Existing index: reuse child's DeltaGenerator
            assert 0 <= index < len(self._child_records), f"Bad index: {index}"
            return self._child_records[index].dg

        # No index: use container's DeltaGenerator
        return self._container
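Since `MutableExpander` does the heavy lifting for the label/expanded updates, here's a small hypothetical script exercising it in isolation, using only the methods defined above:

```python
import streamlit as st

from langchain.callbacks.streamlit.mutable_expander import MutableExpander

expander = MutableExpander(st.container(), label="Working...", expanded=True)

# Children are recorded as they're added, so they can be replayed later.
idx = expander.markdown("first draft")
expander.markdown("another line")

# Passing an existing index rewrites that child in place.
expander.markdown("final draft", index=idx)

# Changing the label deletes and recreates the underlying st.expander at the
# same onscreen position, then replays the recorded children into it.
expander.update(new_label="Done", new_expanded=False)
```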
langchain/callbacks/streamlit/streamlit_callback_handler.py (new file, 406 lines)
@@ -0,0 +1,406 @@
"""Callback Handler that prints to streamlit."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
from langchain.schema import AgentAction, AgentFinish, LLMResult

if TYPE_CHECKING:
    from streamlit.delta_generator import DeltaGenerator


def _convert_newlines(text: str) -> str:
    """Convert newline characters to markdown newline sequences
    (space, space, newline).
    """
    return text.replace("\n", "  \n")


CHECKMARK_EMOJI = "✅"
THINKING_EMOJI = ":thinking_face:"
HISTORY_EMOJI = ":books:"
EXCEPTION_EMOJI = "⚠️"


class LLMThoughtState(Enum):
    # The LLM is thinking about what to do next. We don't know which tool we'll run.
    THINKING = "THINKING"
    # The LLM has decided to run a tool. We don't have results from the tool yet.
    RUNNING_TOOL = "RUNNING_TOOL"
    # We have results from the tool.
    COMPLETE = "COMPLETE"


class ToolRecord(NamedTuple):
    name: str
    input_str: str


class LLMThoughtLabeler:
    """
    Generates markdown labels for LLMThought containers. Pass a custom
    subclass of this to StreamlitCallbackHandler to override its default
    labeling logic.
    """

    def get_initial_label(self) -> str:
        """Return the markdown label for a new LLMThought that doesn't have
        an associated tool yet.
        """
        return f"{THINKING_EMOJI} **Thinking...**"

    def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
        """Return the label for an LLMThought that has an associated
        tool.

        Parameters
        ----------
        tool
            The tool's ToolRecord

        is_complete
            True if the thought is complete; False if the thought
            is still receiving input.

        Returns
        -------
        The markdown label for the thought's container.

        """
        input = tool.input_str
        name = tool.name
        emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
        if name == "_Exception":
            emoji = EXCEPTION_EMOJI
            name = "Parsing error"
        idx = min([60, len(input)])
        input = input[0:idx]
        if len(tool.input_str) > idx:
            input = input + "..."
        input = input.replace("\n", " ")
        label = f"{emoji} **{name}:** {input}"
        return label

    def get_history_label(self) -> str:
        """Return a markdown label for the special 'history' container
        that contains overflow thoughts.
        """
        return f"{HISTORY_EMOJI} **History**"

    def get_final_agent_thought_label(self) -> str:
        """Return the markdown label for the agent's final thought -
        the "Now I have the answer" thought, that doesn't involve
        a tool.
        """
        return f"{CHECKMARK_EMOJI} **Complete!**"


class LLMThought:
    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ):
        self._container = MutableExpander(
            parent_container=parent_container,
            label=labeler.get_initial_label(),
            expanded=expanded,
        )
        self._state = LLMThoughtState.THINKING
        self._llm_token_stream = ""
        self._llm_token_writer_idx: Optional[int] = None
        self._last_tool: Optional[ToolRecord] = None
        self._collapse_on_complete = collapse_on_complete
        self._labeler = labeler

    @property
    def container(self) -> MutableExpander:
        """The container we're writing into."""
        return self._container

    @property
    def last_tool(self) -> Optional[ToolRecord]:
        """The last tool executed by this thought"""
        return self._last_tool

    def _reset_llm_token_stream(self) -> None:
        self._llm_token_stream = ""
        self._llm_token_writer_idx = None

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
        self._reset_llm_token_stream()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # This is only called when the LLM is initialized with `streaming=True`
        self._llm_token_stream += _convert_newlines(token)
        self._llm_token_writer_idx = self._container.markdown(
            self._llm_token_stream, index=self._llm_token_writer_idx
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # `response` is the concatenation of all the tokens received by the LLM.
        # If we're receiving streaming tokens from `on_llm_new_token`, this response
        # data is redundant
        self._reset_llm_token_stream()

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        self._container.markdown("**LLM encountered an error...**")
        self._container.exception(error)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        # Called with the name of the tool we're about to run (in `serialized[name]`),
        # and its input. We change our container's label to be the tool name.
        self._state = LLMThoughtState.RUNNING_TOOL
        tool_name = serialized["name"]
        self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
        self._container.update(
            new_label=self._labeler.get_tool_label(self._last_tool, is_complete=False)
        )

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        self._container.markdown(f"**{output}**")

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        self._container.markdown("**Tool encountered an error...**")
        self._container.exception(error)

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        # Called when we're about to kick off a new tool. The `action` data
        # tells us the tool we're about to use, and the input we'll give it.
        # We don't output anything here, because we'll receive this same data
        # when `on_tool_start` is called immediately after.
        pass

    def complete(self, final_label: Optional[str] = None) -> None:
        """Finish the thought."""
        if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
            assert (
                self._last_tool is not None
            ), "_last_tool should never be null when _state == RUNNING_TOOL"
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )
        self._state = LLMThoughtState.COMPLETE
        if self._collapse_on_complete:
            self._container.update(new_label=final_label, new_expanded=False)
        else:
            self._container.update(new_label=final_label)

    def clear(self) -> None:
        """Remove the thought from the screen. A cleared thought can't be reused."""
        self._container.clear()


class StreamlitCallbackHandler(BaseCallbackHandler):
    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = True,
        collapse_completed_thoughts: bool = True,
        thought_labeler: Optional[LLMThoughtLabeler] = None,
    ):
        """Create a StreamlitCallbackHandler instance.

        Parameters
        ----------
        parent_container
            The `st.container` that will contain all the Streamlit elements that the
            Handler creates.
        max_thought_containers
            The max number of completed LLM thought containers to show at once. When
            this threshold is reached, a new thought will cause the oldest thoughts to
            be collapsed into a "History" expander. Defaults to 4.
        expand_new_thoughts
            Each LLM "thought" gets its own `st.expander`. This param controls whether
            that expander is expanded by default. Defaults to True.
        collapse_completed_thoughts
            If True, LLM thought expanders will be collapsed when completed.
            Defaults to True.
        thought_labeler
            An optional custom LLMThoughtLabeler instance. If unspecified, the handler
            will use the default thought labeling logic. Defaults to None.
        """
        self._parent_container = parent_container
        self._history_parent = parent_container.container()
        self._history_container: Optional[MutableExpander] = None
        self._current_thought: Optional[LLMThought] = None
        self._completed_thoughts: List[LLMThought] = []
        self._max_thought_containers = max(max_thought_containers, 1)
        self._expand_new_thoughts = expand_new_thoughts
        self._collapse_completed_thoughts = collapse_completed_thoughts
        self._thought_labeler = thought_labeler or LLMThoughtLabeler()

    def _require_current_thought(self) -> LLMThought:
        """Return our current LLMThought. Raise an error if we have no current
        thought.
        """
        if self._current_thought is None:
            raise RuntimeError("Current LLMThought is unexpectedly None!")
        return self._current_thought

    def _get_last_completed_thought(self) -> Optional[LLMThought]:
        """Return our most recent completed LLMThought, or None if we don't have one."""
        if len(self._completed_thoughts) > 0:
            return self._completed_thoughts[len(self._completed_thoughts) - 1]
        return None

    @property
    def _num_thought_containers(self) -> int:
        """The number of 'thought containers' we're currently showing: the
        number of completed thought containers, the history container (if it exists),
        and the current thought container (if it exists).
        """
        count = len(self._completed_thoughts)
        if self._history_container is not None:
            count += 1
        if self._current_thought is not None:
            count += 1
        return count

    def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
        """Complete the current thought, optionally assigning it a new label.
        Add it to our _completed_thoughts list.
        """
        thought = self._require_current_thought()
        thought.complete(final_label)
        self._completed_thoughts.append(thought)
        self._current_thought = None

    def _prune_old_thought_containers(self) -> None:
        """If we have too many thoughts onscreen, move older thoughts to the
        'history container.'
        """
        while (
            self._num_thought_containers > self._max_thought_containers
            and len(self._completed_thoughts) > 0
        ):
            # Create our history container if it doesn't exist, and if
            # max_thought_containers is > 1. (if max_thought_containers is 1, we don't
            # have room to show history.)
            if self._history_container is None and self._max_thought_containers > 1:
                self._history_container = MutableExpander(
                    self._history_parent,
                    label=self._thought_labeler.get_history_label(),
                    expanded=False,
                )

            oldest_thought = self._completed_thoughts.pop(0)
            if self._history_container is not None:
                self._history_container.markdown(oldest_thought.container.label)
                self._history_container.append_copy(oldest_thought.container)
            oldest_thought.clear()

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        if self._current_thought is None:
            self._current_thought = LLMThought(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

        # We don't prune_old_thought_containers here, because our container won't
        # be visible until it has a child.

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_new_token(token, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_end(response, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        self._require_current_thought().on_llm_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        self._require_current_thought().on_tool_end(
            output, color, observation_prefix, llm_prefix, **kwargs
        )
        self._complete_current_thought()

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        self._require_current_thought().on_tool_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        pass

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        self._require_current_thought().on_agent_action(action, color, **kwargs)
        self._prune_old_thought_containers()

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        if self._current_thought is not None:
            self._current_thought.complete(
                self._thought_labeler.get_final_agent_thought_label()
            )
            self._current_thought = None
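And a sketch of the `thought_labeler` hook: the `QuietLabeler` subclass below is hypothetical, but the overridden method, the `ToolRecord` import, and the constructor call all come from the code above:

```python
import streamlit as st

from langchain.callbacks.streamlit import LLMThoughtLabeler, StreamlitCallbackHandler
from langchain.callbacks.streamlit.streamlit_callback_handler import ToolRecord


class QuietLabeler(LLMThoughtLabeler):
    """Label thoughts with just the tool name, skipping the input preview."""

    def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
        emoji = "✅" if is_complete else ":thinking_face:"
        return f"{emoji} **{tool.name}**"


handler = StreamlitCallbackHandler(
    st.container(),
    thought_labeler=QuietLabeler(),
)
```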
poetry.lock (generated, 703 lines changed)
File diff suppressed because it is too large
pyproject.toml
@@ -110,8 +110,7 @@ langchainplus-sdk = ">=0.0.13"
 awadb = {version = "^0.3.3", optional = true}
 azure-search-documents = {version = "11.4.0a20230509004", source = "azure-sdk-dev", optional = true}
 openllm = {version = ">=0.1.6", optional = true}
-# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out.
-#streamlit = {version = "^1.18.0", optional = true}
+streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}

 [tool.poetry.group.docs.dependencies]
 autodoc_pydantic = "^1.8.0"
@@ -157,7 +156,7 @@ optional = true
 # 2. Add the package name to the extended_testing extra (find it below)
 # 3. Relock the poetry file
 #    poetry lock --no-update
-# 4. Favor unit tests not integration tests.
+# 4. Favor unit tests not integration tests.
 #    Use the @pytest.mark.requires(pkg_name) decorator in unit_tests.
 # Your tests should not rely on network access, as it prevents other
 # developers from being able to easily run them.
@@ -332,8 +331,7 @@ extended_testing = [
     "html2text",
     "py-trello",
     "scikit-learn",
-    # now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out.
-    # "streamlit",
+    "streamlit",
     "pyspark",
     "openai"
 ]
tests/integration_tests/callbacks/test_streamlit_callback.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""Integration tests for the StreamlitCallbackHandler module."""

import pytest

from langchain.agents import AgentType, initialize_agent, load_tools

# Import the internal StreamlitCallbackHandler from its module - and not from
# the `langchain.callbacks.streamlit` package - so that we don't end up using
# Streamlit's externally-provided callback handler.
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler,
)
from langchain.llms import OpenAI


@pytest.mark.requires("streamlit")
def test_streamlit_callback_agent() -> None:
    import streamlit as st

    streamlit_callback = StreamlitCallbackHandler(st.container())

    llm = OpenAI(temperature=0)
    tools = load_tools(["serpapi", "llm-math"], llm=llm)
    agent = initialize_agent(
        tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
    )
    agent.run(
        "Who is Olivia Wilde's boyfriend? "
        "What is his current age raised to the 0.23 power?",
        callbacks=[streamlit_callback],
    )
tests/unit_tests/callbacks/test_streamlit_callback.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import builtins
import unittest
from typing import Any
from unittest import mock
from unittest.mock import MagicMock

from langchain.callbacks.streamlit import StreamlitCallbackHandler


class TestImport(unittest.TestCase):
    """Test the StreamlitCallbackHandler 'auto-updating' API"""

    def setUp(self) -> None:
        self.builtins_import = builtins.__import__

    def tearDown(self) -> None:
        builtins.__import__ = self.builtins_import

    @mock.patch("langchain.callbacks.streamlit._InternalStreamlitCallbackHandler")
    def test_create_internal_handler(self, mock_internal_handler: Any) -> None:
        """If we're using a Streamlit that does not expose its own
        StreamlitCallbackHandler, use our own implementation.
        """

        def external_import_error(
            name: str, globals: Any, locals: Any, fromlist: Any, level: int
        ) -> Any:
            if name == "streamlit.external.langchain":
                raise ImportError
            return self.builtins_import(name, globals, locals, fromlist, level)

        builtins.__import__ = external_import_error  # type: ignore[assignment]

        parent_container = MagicMock()
        thought_labeler = MagicMock()
        StreamlitCallbackHandler(
            parent_container,
            max_thought_containers=1,
            expand_new_thoughts=True,
            collapse_completed_thoughts=False,
            thought_labeler=thought_labeler,
        )

        # Our internal handler should be created
        mock_internal_handler.assert_called_once_with(
            parent_container,
            max_thought_containers=1,
            expand_new_thoughts=True,
            collapse_completed_thoughts=False,
            thought_labeler=thought_labeler,
        )

    def test_create_external_handler(self) -> None:
        """If we're using a Streamlit that *does* expose its own callback handler,
        delegate to that implementation.
        """

        mock_streamlit_module = MagicMock()

        def external_import_success(
            name: str, globals: Any, locals: Any, fromlist: Any, level: int
        ) -> Any:
            if name == "streamlit.external.langchain":
                return mock_streamlit_module
            return self.builtins_import(name, globals, locals, fromlist, level)

        builtins.__import__ = external_import_success  # type: ignore[assignment]

        parent_container = MagicMock()
        thought_labeler = MagicMock()
        StreamlitCallbackHandler(
            parent_container,
            max_thought_containers=1,
            expand_new_thoughts=True,
            collapse_completed_thoughts=False,
            thought_labeler=thought_labeler,
        )

        # Streamlit's handler should be created
        mock_streamlit_module.StreamlitCallbackHandler.assert_called_once_with(
            parent_container,
            max_thought_containers=1,
            expand_new_thoughts=True,
            collapse_completed_thoughts=False,
            thought_labeler=thought_labeler,
        )