mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
core[patch]: On Chain Start Fix for Chain
Class (#26593)
- **Issue:** #26588 --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
bba7af903b
commit
154a5ff7ca
@ -32,8 +32,14 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
inputs (Dict[str, Any]): The inputs to the chain.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
|
||||
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201
|
||||
if "name" in kwargs:
|
||||
name = kwargs["name"]
|
||||
else:
|
||||
if serialized:
|
||||
name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
|
||||
else:
|
||||
name = "<unknown>"
|
||||
print(f"\n\n\033[1m> Entering new {name} chain...\033[0m") # noqa: T201
|
||||
|
||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain.
|
||||
|
44
libs/langchain/tests/unit_tests/callbacks/test_stdout.py
Normal file
44
libs/langchain/tests/unit_tests/callbacks/test_stdout.py
Normal file
@ -0,0 +1,44 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks import StdOutCallbackHandler
|
||||
from langchain.chains.base import CallbackManagerForChainRun, Chain
|
||||
|
||||
|
||||
class FakeChain(Chain):
|
||||
"""Fake chain class for testing purposes."""
|
||||
|
||||
be_correct: bool = True
|
||||
the_input_keys: List[str] = ["foo"]
|
||||
the_output_keys: List[str] = ["bar"]
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys."""
|
||||
return self.the_input_keys
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output key of bar."""
|
||||
return self.the_output_keys
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
return {"bar": "bar"}
|
||||
|
||||
|
||||
def test_stdoutcallback(capsys: pytest.CaptureFixture) -> Any:
|
||||
"""Test the stdout callback handler."""
|
||||
chain_test = FakeChain(callbacks=[StdOutCallbackHandler(color="red")])
|
||||
chain_test.invoke({"foo": "bar"})
|
||||
# Capture the output
|
||||
captured = capsys.readouterr()
|
||||
# Assert the output is as expected
|
||||
assert captured.out == (
|
||||
"\n\n\x1b[1m> Entering new FakeChain "
|
||||
"chain...\x1b[0m\n\n\x1b[1m> Finished chain.\x1b[0m\n"
|
||||
)
|
Loading…
Reference in New Issue
Block a user