From d98c3f76c21c1a66ee9c379bdd98529d4b806e62 Mon Sep 17 00:00:00 2001 From: Dan Mirsky Date: Wed, 26 Feb 2025 11:54:24 -0800 Subject: [PATCH] core[patch]: Fix FileCallbackHandler name resolution, Fixes #29941 (#29942) - **Description:** Same changes as #26593 but for FileCallbackHandler - **Issue:** Fixes #29941 - **Dependencies:** None - **Twitter handle:** None - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --- libs/core/langchain_core/callbacks/file.py | 10 ++++- .../tests/unit_tests/callbacks/test_file.py | 45 +++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/callbacks/test_file.py diff --git a/libs/core/langchain_core/callbacks/file.py b/libs/core/langchain_core/callbacks/file.py index f7fe7e09595..cd20fbe4f71 100644 --- a/libs/core/langchain_core/callbacks/file.py +++ b/libs/core/langchain_core/callbacks/file.py @@ -47,9 +47,15 @@ class FileCallbackHandler(BaseCallbackHandler): inputs (Dict[str, Any]): The inputs to the chain. **kwargs (Any): Additional keyword arguments. """ - class_name = serialized.get("name", serialized.get("id", [""])[-1]) + if "name" in kwargs: + name = kwargs["name"] + else: + if serialized: + name = serialized.get("name", serialized.get("id", [""])[-1]) + else: + name = "" print_text( - f"\n\n\033[1m> Entering new {class_name} chain...\033[0m", + f"\n\n\033[1m> Entering new {name} chain...\033[0m", end="\n", file=self.file, ) diff --git a/libs/langchain/tests/unit_tests/callbacks/test_file.py b/libs/langchain/tests/unit_tests/callbacks/test_file.py new file mode 100644 index 00000000000..7d739af8a65 --- /dev/null +++ b/libs/langchain/tests/unit_tests/callbacks/test_file.py @@ -0,0 +1,45 @@ +import pathlib +from typing import Any, Dict, List, Optional + +import pytest + +from langchain.callbacks import FileCallbackHandler +from langchain.chains.base import CallbackManagerForChainRun, Chain + + +class FakeChain(Chain): + """Fake chain class for testing purposes.""" + + be_correct: bool = True + the_input_keys: List[str] = ["foo"] + the_output_keys: List[str] = ["bar"] + + @property + def input_keys(self) -> List[str]: + """Input keys.""" + return self.the_input_keys + + @property + def output_keys(self) -> List[str]: + """Output key of bar.""" + return self.the_output_keys + + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + return {"bar": "bar"} + + +def test_filecallback(capsys: pytest.CaptureFixture, tmp_path: pathlib.Path) -> Any: + """Test the file callback handler.""" + p = tmp_path / "output.log" + handler = FileCallbackHandler(str(p)) + chain_test = FakeChain(callbacks=[handler]) + chain_test.invoke({"foo": "bar"}) + # Assert the output is as expected + assert p.read_text() == ( + "\n\n\x1b[1m> Entering new FakeChain " + "chain...\x1b[0m\n\n\x1b[1m> Finished chain.\x1b[0m\n" + )