mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
langchain[patch]: smith.evaluation.progress.ProgressBarCallback: Make output after progress bar ends configurable (#31583)
This commit is contained in:
parent
6105a5841b
commit
9de4f22205
@ -13,15 +13,19 @@ from langchain_core.outputs import LLMResult
|
||||
class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
|
||||
"""A simple progress bar for the console."""
|
||||
|
||||
def __init__(self, total: int, ncols: int = 50, **kwargs: Any):
|
||||
def __init__(
|
||||
self, total: int, ncols: int = 50, end_with: str = "\n", **kwargs: Any
|
||||
):
|
||||
"""Initialize the progress bar.
|
||||
|
||||
Args:
|
||||
total: int, the total number of items to be processed.
|
||||
ncols: int, the character width of the progress bar.
|
||||
end_with: str, last string to print after progress bar reaches end.
|
||||
"""
|
||||
self.total = total
|
||||
self.ncols = ncols
|
||||
self.end_with = end_with
|
||||
self.counter = 0
|
||||
self.lock = threading.Lock()
|
||||
self._print_bar()
|
||||
@ -37,7 +41,8 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
|
||||
progress = self.counter / self.total
|
||||
arrow = "-" * int(round(progress * self.ncols) - 1) + ">"
|
||||
spaces = " " * (self.ncols - len(arrow))
|
||||
print(f"\r[{arrow + spaces}] {self.counter}/{self.total}", end="") # noqa: T201
|
||||
end = "" if self.counter < self.total else self.end_with
|
||||
print(f"\r[{arrow + spaces}] {self.counter}/{self.total}", end=end) # noqa: T201
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user