mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
- **Description:** Same changes as #26593 but for FileCallbackHandler - **Issue:** Fixes #29941 - **Dependencies:** None - **Twitter handle:** None - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/
This commit is contained in:
parent
b3885c124f
commit
d98c3f76c2
@ -47,9 +47,15 @@ class FileCallbackHandler(BaseCallbackHandler):
|
|||||||
inputs (Dict[str, Any]): The inputs to the chain.
|
inputs (Dict[str, Any]): The inputs to the chain.
|
||||||
**kwargs (Any): Additional keyword arguments.
|
**kwargs (Any): Additional keyword arguments.
|
||||||
"""
|
"""
|
||||||
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
|
if "name" in kwargs:
|
||||||
|
name = kwargs["name"]
|
||||||
|
else:
|
||||||
|
if serialized:
|
||||||
|
name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
|
||||||
|
else:
|
||||||
|
name = "<unknown>"
|
||||||
print_text(
|
print_text(
|
||||||
f"\n\n\033[1m> Entering new {class_name} chain...\033[0m",
|
f"\n\n\033[1m> Entering new {name} chain...\033[0m",
|
||||||
end="\n",
|
end="\n",
|
||||||
file=self.file,
|
file=self.file,
|
||||||
)
|
)
|
||||||
|
45
libs/langchain/tests/unit_tests/callbacks/test_file.py
Normal file
45
libs/langchain/tests/unit_tests/callbacks/test_file.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import pathlib
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.callbacks import FileCallbackHandler
|
||||||
|
from langchain.chains.base import CallbackManagerForChainRun, Chain
|
||||||
|
|
||||||
|
|
||||||
|
class FakeChain(Chain):
|
||||||
|
"""Fake chain class for testing purposes."""
|
||||||
|
|
||||||
|
be_correct: bool = True
|
||||||
|
the_input_keys: List[str] = ["foo"]
|
||||||
|
the_output_keys: List[str] = ["bar"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Input keys."""
|
||||||
|
return self.the_input_keys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Output key of bar."""
|
||||||
|
return self.the_output_keys
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, str],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
return {"bar": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_filecallback(capsys: pytest.CaptureFixture, tmp_path: pathlib.Path) -> Any:
|
||||||
|
"""Test the file callback handler."""
|
||||||
|
p = tmp_path / "output.log"
|
||||||
|
handler = FileCallbackHandler(str(p))
|
||||||
|
chain_test = FakeChain(callbacks=[handler])
|
||||||
|
chain_test.invoke({"foo": "bar"})
|
||||||
|
# Assert the output is as expected
|
||||||
|
assert p.read_text() == (
|
||||||
|
"\n\n\x1b[1m> Entering new FakeChain "
|
||||||
|
"chain...\x1b[0m\n\n\x1b[1m> Finished chain.\x1b[0m\n"
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user