update all reduce

Tong Li 2024-08-15 03:46:43 +00:00
parent e87cd8bcfb
commit 4191f21f70
2 changed files with 11 additions and 4 deletions

@@ -114,7 +114,8 @@ class SFTTrainer(SLTrainer):
             )
             loss = outputs["loss"]
             if dist.get_rank() == dist.get_world_size() - 1:
-                step_bar.set_postfix({"train/loss": loss.item()})
+                global_loss = all_reduce_mean(loss, self.booster)
+                step_bar.set_postfix({"train/loss": global_loss.item()})
                 step_bar.update()
             self.optimizer.step()
             self.optimizer.zero_grad()
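One thing to watch in this hunk: dist.all_reduce is a collective, so every rank in the group must reach the call. As rendered here, all_reduce_mean sits inside the `if dist.get_rank() == dist.get_world_size() - 1:` guard; if that indentation is faithful to the source, the other ranks never enter the collective and the job hangs. A minimal sketch of the usual pattern, reusing the names from the hunk above (reduce on every rank, gate only the logging):

    loss = outputs["loss"]

    # Every rank participates in the collective reduction...
    global_loss = all_reduce_mean(loss, self.booster)

    # ...but only the last rank owns the progress bar.
    if dist.get_rank() == dist.get_world_size() - 1:
        step_bar.set_postfix({"train/loss": global_loss.item()})
        step_bar.update()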

@@ -9,6 +9,8 @@ import torch.distributed as dist
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
+
+from colossalai.booster import Booster
 
 
 class CycledDataLoader:
     """
@@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
     return tree_map(_to, x)
 
 
-def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+def all_reduce_mean(tensor: torch.Tensor, booster: Booster) -> torch.Tensor:
     """
     Perform all-reduce operation on the given tensor and compute the mean across all processes.
@@ -95,8 +97,12 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: The reduced tensor with mean computed across all processes.
     """
-    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
-    tensor.div_(dist.get_world_size())
+    if booster is not None:
+        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
+        tensor.div_(booster.plugin.dp_size)
+    else:
+        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+        tensor.div_(dist.get_world_size())
     return tensor
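
Why the signature change: in a hybrid-parallel setup the loss should be averaged only over the data-parallel replicas, not the whole world, since tensor- and pipeline-parallel ranks do not hold independent batches. `booster.plugin.dp_group` and `booster.plugin.dp_size` (exposed by plugins such as Colossal-AI's HybridParallelPlugin) scope the reduction accordingly, while `booster=None` falls back to a world-wide average. Below is a self-contained sketch of that fallback path, runnable on CPU with the gloo backend; the two-process setup, port, and per-rank loss values are arbitrary choices for the demo:

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def all_reduce_mean(tensor: torch.Tensor, booster=None) -> torch.Tensor:
        # Same logic as the function above; booster=None takes the world-wide path.
        if booster is not None:
            dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
            tensor.div_(booster.plugin.dp_size)
        else:
            dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
            tensor.div_(dist.get_world_size())
        return tensor


    def worker(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Each rank pretends to have computed a different local loss.
        loss = torch.tensor(float(rank + 1))
        global_loss = all_reduce_mean(loss)

        # With ranks 0 and 1: (1.0 + 2.0) / 2 == 1.5, printed on every rank.
        print(f"rank {rank}: global mean loss = {global_loss.item()}")
        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2
        mp.spawn(worker, args=(world_size,), nprocs=world_size)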