diff --git a/libs/langchain/langchain/smith/evaluation/progress.py b/libs/langchain/langchain/smith/evaluation/progress.py index af94ebb511e..4c1d1101662 100644 --- a/libs/langchain/langchain/smith/evaluation/progress.py +++ b/libs/langchain/langchain/smith/evaluation/progress.py @@ -13,15 +13,19 @@ from langchain_core.outputs import LLMResult class ProgressBarCallback(base_callbacks.BaseCallbackHandler): """A simple progress bar for the console.""" - def __init__(self, total: int, ncols: int = 50, **kwargs: Any): + def __init__( + self, total: int, ncols: int = 50, end_with: str = "\n", **kwargs: Any + ): """Initialize the progress bar. Args: total: int, the total number of items to be processed. ncols: int, the character width of the progress bar. + end_with: str, last string to print after progress bar reaches end. """ self.total = total self.ncols = ncols + self.end_with = end_with self.counter = 0 self.lock = threading.Lock() self._print_bar() @@ -37,7 +41,8 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler): progress = self.counter / self.total arrow = "-" * int(round(progress * self.ncols) - 1) + ">" spaces = " " * (self.ncols - len(arrow)) - print(f"\r[{arrow + spaces}] {self.counter}/{self.total}", end="") # noqa: T201 + end = "" if self.counter < self.total else self.end_with + print(f"\r[{arrow + spaces}] {self.counter}/{self.total}", end=end) # noqa: T201 def on_chain_error( self,