Fix missing chain classname in StdOutCallbackHandler.on_chain_start (#6124)

Retrieves the name of the class from new location as of commit
18af149e91


Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com>
This commit is contained in:
Sam Coward
2023-07-13 03:05:36 -04:00
committed by GitHub
parent af3f401015
commit 224199083b

View File

@@ -37,7 +37,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None: ) -> None:
"""Print out that we are entering a chain.""" """Print out that we are entering a chain."""
class_name = serialized.get("name", "") class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: