[ColossalChat] Add PP support (#6001)
* support pp training
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update
* skip pp eval
* update all reduce
* update sft
* update ignore
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update no cache
* add eval
* remove fi
* remove debug
* remove parentheses to avoid warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Revert "add eval"
This reverts commit 3ab2f6fa32
.
* add all reduce
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -9,6 +9,8 @@ import torch.distributed as dist
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
+from colossalai.booster import Plugin
+
 
 class CycledDataLoader:
     """
@@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
     return tree_map(_to, x)
 
 
-def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
     """
     Perform all-reduce operation on the given tensor and compute the mean across all processes.
 
@@ -95,8 +97,13 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: The reduced tensor with mean computed across all processes.
     """
-    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
-    tensor.div_(dist.get_world_size())
+    # All reduce mean across DP group
+    if plugin is not None:
+        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
+        tensor.div_(plugin.dp_size)
+    else:
+        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+        tensor.div_(dist.get_world_size())
     return tensor
 
 
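In short, the patch scopes the mean-reduction to the plugin's data-parallel group: with pipeline parallelism enabled, the average is taken over DP replicas only, while the old world-wide reduction is kept as a fallback when no plugin is passed. Below is a minimal, self-contained sketch of the new behavior; the single-process gloo demo at the end is purely illustrative, and any Plugin-like object exposing dp_group and dp_size (the two attributes the patch relies on, e.g. a HybridParallelPlugin supplied by the booster) would exercise the first branch.

import torch
import torch.distributed as dist


def all_reduce_mean(tensor: torch.Tensor, plugin=None) -> torch.Tensor:
    """Mirror of the patched helper: average across the plugin's DP group
    when one is given, otherwise across the whole world."""
    if plugin is not None:
        # Under pipeline parallelism only data-parallel replicas hold
        # comparable values, so reduce within dp_group, not the world.
        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
        tensor.div_(plugin.dp_size)
    else:
        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
        tensor.div_(dist.get_world_size())
    return tensor


if __name__ == "__main__":
    # Single-process gloo group, just enough to exercise the fallback path.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", world_size=1, rank=0)
    x = torch.tensor([2.0, 4.0])
    print(all_reduce_mean(x))  # world size is 1, so the values are unchanged
    dist.destroy_process_group()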