From f4c0e3cc15305b390ace862d820d54d3895a64f1 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 18 Nov 2023 12:24:49 -0500 Subject: [PATCH] move streaming stdout (#13559) --- .../langchain/callbacks/streaming_stdout.py | 67 +------------------ .../schema/callbacks/streaming_stdout.py | 67 +++++++++++++++++++ 2 files changed, 69 insertions(+), 65 deletions(-) create mode 100644 libs/langchain/langchain/schema/callbacks/streaming_stdout.py diff --git a/libs/langchain/langchain/callbacks/streaming_stdout.py b/libs/langchain/langchain/callbacks/streaming_stdout.py index 38d1850dd0e..26be79bd192 100644 --- a/libs/langchain/langchain/callbacks/streaming_stdout.py +++ b/libs/langchain/langchain/callbacks/streaming_stdout.py @@ -1,67 +1,4 @@ """Callback Handler streams to stdout on new llm token.""" -import sys -from typing import Any, Dict, List +from langchain.schema.callbacks.streaming_stdout import StreamingStdOutCallbackHandler -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult -from langchain.schema.messages import BaseMessage - - -class StreamingStdOutCallbackHandler(BaseCallbackHandler): - """Callback handler for streaming. Only works with LLMs that support streaming.""" - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running.""" - - def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - **kwargs: Any, - ) -> None: - """Run when LLM starts running.""" - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" - sys.stdout.write(token) - sys.stdout.flush() - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - - def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: - """Run when LLM errors.""" - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Run when chain starts running.""" - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" - - def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: - """Run when chain errors.""" - - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> None: - """Run when tool starts running.""" - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run on agent action.""" - pass - - def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" - - def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: - """Run when tool errors.""" - - def on_text(self, text: str, **kwargs: Any) -> None: - """Run on arbitrary text.""" - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run on agent end.""" +__all__ = ["StreamingStdOutCallbackHandler"] diff --git a/libs/langchain/langchain/schema/callbacks/streaming_stdout.py b/libs/langchain/langchain/schema/callbacks/streaming_stdout.py new file mode 100644 index 00000000000..dd0896801a8 --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/streaming_stdout.py @@ -0,0 +1,67 @@ +"""Callback Handler streams to stdout on new llm token.""" +import sys +from typing import Any, Dict, List + +from langchain.schema import AgentAction, AgentFinish, LLMResult +from langchain.schema.callbacks.base import BaseCallbackHandler +from langchain.schema.messages import BaseMessage + + +class StreamingStdOutCallbackHandler(BaseCallbackHandler): + """Callback handler for streaming. Only works with LLMs that support streaming.""" + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any, + ) -> None: + """Run when LLM starts running.""" + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + sys.stdout.write(token) + sys.stdout.flush() + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when LLM errors.""" + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when chain errors.""" + + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run on agent action.""" + pass + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when tool errors.""" + + def on_text(self, text: str, **kwargs: Any) -> None: + """Run on arbitrary text.""" + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end."""