diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore
index 7b361d38e..5a4bb905f 100755
--- a/applications/ColossalChat/.gitignore
+++ b/applications/ColossalChat/.gitignore
@@ -161,3 +161,9 @@ applications/ColossalChat/sft_data
 applications/ColossalChat/prompt_data
 applications/ColossalChat/preference_data
 applications/ColossalChat/temp
+
+# Testing data
+/kto_data/
+/preference_data/
+/prompt_data/
+/sft_data/
diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py
index 298fb30ee..33b241c05 100755
--- a/applications/ColossalChat/coati/trainer/sft.py
+++ b/applications/ColossalChat/coati/trainer/sft.py
@@ -114,9 +114,7 @@ class SFTTrainer(SLTrainer):
                 )
                 loss = outputs["loss"]
                 if dist.get_rank() == dist.get_world_size() - 1:
-                    global_loss = all_reduce_mean(loss, self.booster)
-                    step_bar.set_postfix({"train/loss": global_loss.item()})
-                    step_bar.update()
+                    step_bar.set_postfix({"train/loss": loss.item()})
                 self.optimizer.step()
                 self.optimizer.zero_grad()
         else:
@@ -200,9 +198,8 @@ class SFTTrainer(SLTrainer):
                 )
                 loss = outputs["loss"]
                 if dist.get_rank() == dist.get_world_size() - 1:
-                    global_loss = all_reduce_mean(loss, self.booster)
-                    step_bar.set_postfix({"eval/loss": global_loss.item()})
-                    self.accumulative_meter.add("loss", global_loss.item())
+                    step_bar.set_postfix({"eval/loss": loss.item()})
+                    self.accumulative_meter.add("loss", loss.item())
                 step_bar.update()

         if dist.get_rank() == dist.get_world_size() - 1:
diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py
index c15c291b4..e87993c38 100755
--- a/applications/ColossalChat/coati/trainer/utils.py
+++ b/applications/ColossalChat/coati/trainer/utils.py
@@ -87,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
     return tree_map(_to, x)


-def all_reduce_mean(tensor: torch.Tensor, booster: Booster) -> torch.Tensor:
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     """
     Perform all-reduce operation on the given tensor and compute the mean across all processes.

@@ -97,12 +97,8 @@ def all_reduce_mean(tensor: torch.Tensor, booster: Booster) -> torch.Tensor:
     Returns:
         torch.Tensor: The reduced tensor with mean computed across all processes.
     """
-    if booster is not None:
-        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
-        tensor.div_(booster.plugin.dp_size)
-    else:
-        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
-        tensor.div_(dist.get_world_size())
+    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+    tensor.div_(dist.get_world_size())
     return tensor
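
After this patch, all_reduce_mean no longer takes a booster and always sums over the default process group, then divides by dist.get_world_size(). A minimal standalone sketch of that behavior follows; the gloo backend, TCP rendezvous address, and two-process spawn are illustrative assumptions for demonstration, not part of the patch.

    # Sketch of the simplified all_reduce_mean semantics (assumptions:
    # gloo backend, 2 CPU processes, port 29500 -- all illustrative).
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
        # Sum the tensor across every rank in the default group,
        # then divide in place by the world size to get the mean.
        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
        tensor.div_(dist.get_world_size())
        return tensor


    def worker(rank: int, world_size: int):
        dist.init_process_group(
            "gloo",
            init_method="tcp://127.0.0.1:29500",
            rank=rank,
            world_size=world_size,
        )
        loss = torch.tensor(float(rank))  # rank 0 holds 0.0, rank 1 holds 1.0
        all_reduce_mean(loss)             # every rank now holds 0.5
        print(f"rank {rank}: mean loss = {loss.item()}")
        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(worker, args=(2,), nprocs=2)

Each rank contributes its local value; after the collective sum, every rank divides by the world size and ends up holding the same mean (0.5 in this two-rank example), which is the semantics the trimmed helper in coati/trainer/utils.py implements over the default group.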