From 4191f21f70cb23ead05075a0f24321cdc89bfeaa Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 15 Aug 2024 03:46:43 +0000
Subject: [PATCH] update all reduce

---
 applications/ColossalChat/coati/trainer/sft.py   |  3 ++-
 applications/ColossalChat/coati/trainer/utils.py | 12 +++++++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py
index 6322cb8df..fb2f9a765 100755
--- a/applications/ColossalChat/coati/trainer/sft.py
+++ b/applications/ColossalChat/coati/trainer/sft.py
@@ -114,7 +114,8 @@ class SFTTrainer(SLTrainer):
             )
             loss = outputs["loss"]
             if dist.get_rank() == dist.get_world_size() - 1:
-                step_bar.set_postfix({"train/loss": loss.item()})
+                global_loss = all_reduce_mean(loss, self.booster)
+                step_bar.set_postfix({"train/loss": global_loss.item()})
             step_bar.update()
             self.optimizer.step()
             self.optimizer.zero_grad()
diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py
index 3c836b4b4..c15c291b4 100755
--- a/applications/ColossalChat/coati/trainer/utils.py
+++ b/applications/ColossalChat/coati/trainer/utils.py
@@ -9,6 +9,8 @@ import torch.distributed as dist
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
+from colossalai.booster import Booster
+
 
 class CycledDataLoader:
     """
@@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
     return tree_map(_to, x)
 
 
-def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+def all_reduce_mean(tensor: torch.Tensor, booster: Booster) -> torch.Tensor:
     """
     Perform all-reduce operation on the given tensor and compute the mean across all processes.
 
@@ -95,8 +97,12 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: The reduced tensor with mean computed across all processes.
     """
-    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
-    tensor.div_(dist.get_world_size())
+    if booster is not None:
+        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
+        tensor.div_(booster.plugin.dp_size)
+    else:
+        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+        tensor.div_(dist.get_world_size())
     return tensor
 
 