mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-24 10:41:07 +00:00
update all reduce
This commit is contained in:
parent
e87cd8bcfb
commit
4191f21f70
@ -114,7 +114,8 @@ class SFTTrainer(SLTrainer):
|
|||||||
)
|
)
|
||||||
loss = outputs["loss"]
|
loss = outputs["loss"]
|
||||||
if dist.get_rank() == dist.get_world_size() - 1:
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
step_bar.set_postfix({"train/loss": loss.item()})
|
global_loss = all_reduce_mean(loss, self.booster)
|
||||||
|
step_bar.set_postfix({"train/loss": global_loss.item()})
|
||||||
step_bar.update()
|
step_bar.update()
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
@ -9,6 +9,8 @@ import torch.distributed as dist
|
|||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
|
||||||
|
|
||||||
class CycledDataLoader:
|
class CycledDataLoader:
|
||||||
"""
|
"""
|
||||||
@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
|
|||||||
return tree_map(_to, x)
|
return tree_map(_to, x)
|
||||||
|
|
||||||
|
|
||||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
def all_reduce_mean(tensor: torch.Tensor, booster: Booster) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Perform all-reduce operation on the given tensor and compute the mean across all processes.
|
Perform all-reduce operation on the given tensor and compute the mean across all processes.
|
||||||
|
|
||||||
@ -95,8 +97,12 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: The reduced tensor with mean computed across all processes.
|
torch.Tensor: The reduced tensor with mean computed across all processes.
|
||||||
"""
|
"""
|
||||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
if booster is not None:
|
||||||
tensor.div_(dist.get_world_size())
|
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
|
||||||
|
tensor.div_(booster.plugin.dp_size)
|
||||||
|
else:
|
||||||
|
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||||
|
tensor.div_(dist.get_world_size())
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user