[ColossalChat] Add PP support (#6001)

* support pp training

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update rm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test case

* fix

* change to 4

* fix eval

* test

* add pp

* hotfix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support pp training

* update rm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test case

* fix

* change to 4

* fix eval

* test

* add pp

* hotfix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* skip pp eval

* update all reduce

* update sft

* update ignore

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update no cache

* add eval

* remove fi

* remove debug

* remove parentheses to avoid warning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "add eval"

This reverts commit 3ab2f6fa32.

* add all reduce

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Tong Li
2024-08-21 10:47:39 +08:00
committed by GitHub
parent 0d3b0bd864
commit 39e2597426
16 changed files with 241 additions and 115 deletions

View File

@@ -9,6 +9,8 @@ import torch.distributed as dist
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from colossalai.booster import Plugin
class CycledDataLoader:
"""
@@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
return tree_map(_to, x)
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
"""
Perform all-reduce operation on the given tensor and compute the mean across all processes, or across the plugin's data-parallel group (`plugin.dp_group`) when `plugin` is provided.
@@ -95,8 +97,13 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: The reduced tensor (modified in place) with the mean computed across all processes, or across the data-parallel group when `plugin` is given.
"""
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
# All reduce mean across DP group
if plugin is not None:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
tensor.div_(plugin.dp_size)
else:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor