mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-27 19:36:13 +00:00
* update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache * add DAPO support * remove format reward * fix filtering, still buggy * small fix * add DAPO support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tested multi-node training; fix bind_batch bug * fix conversation; support sleep mode * support reusing excessive samples * add dynamic batching control flag * add dynamic batching control flag * refactored * fix logging --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
147 lines
4.2 KiB
Python
Executable File
147 lines
4.2 KiB
Python
Executable File
"""
|
|
Training utilities for Coati.
|
|
"""
|
|
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.utils._pytree import tree_map
|
|
from torch.utils.data import DataLoader
|
|
|
|
from colossalai.booster import Plugin
|
|
|
|
|
|
class AnnealingScheduler:
    """
    Step-driven temperature schedule that moves linearly from ``start`` towards ``end``.

    Args:
        start: Temperature returned while the step counter is within the warm-up window.
        end: Temperature returned once the step counter has reached ``annealing_step``.
        warmup_steps (int): Last step (inclusive) for which ``start`` is returned unchanged.
        annealing_step (int): Step at and after which ``end`` is returned.

    Attributes:
        step (int): Step counter; starts at 0 and is advanced by ``step_forward``.
    """

    def __init__(self, start, end, warmup_steps=100, annealing_step=2000):
        self.step = 0
        self.start, self.end = start, end
        self.warmup_steps = warmup_steps
        self.annealing_step = annealing_step

    def get_temperature(self):
        """
        Returns the temperature for the current step.

        Returns:
            ``start`` while ``step <= warmup_steps``, ``end`` once ``step >= annealing_step``,
            and a linearly interpolated value in between.
        """
        if self.step <= self.warmup_steps:
            # Still inside the warm-up window: annealing has not started yet.
            return self.start
        if self.step >= self.annealing_step:
            return self.end
        # Progress is measured from step 0 (not from the end of warm-up), so the
        # first value after warm-up is already below ``start``.
        progress = self.step / self.annealing_step
        return self.start - progress * (self.start - self.end)

    def step_forward(self):
        """Advances the step counter by one."""
        self.step += 1
|
|
|
|
|
|
class CycledDataLoader:
    """
    Wraps a data loader so that it can be drawn from indefinitely.

    When the underlying iterator is exhausted, a new one is created and
    iteration continues from its first batch.

    Args:
        dataloader (DataLoader): The data loader to wrap.

    Attributes:
        dataloader (DataLoader): The wrapped data loader.
        count (int): Number of ``next()`` calls since the iterator was last restarted.
            It is reset to 0 by the call that triggers the restart.
        dataloader_iter (iterable): The active iterator, or None until the first ``next()`` call.

    Methods:
        next(): Returns the next batch, restarting the data loader when it runs out.
    """

    def __init__(
        self,
        dataloader: DataLoader,
    ) -> None:
        self.dataloader = dataloader
        self.dataloader_iter = None
        self.count = 0

    def next(self):
        """
        Fetches the next batch, starting over from the beginning once the data is exhausted.

        Returns:
            Any: The next batch produced by the wrapped data loader.
        """
        # The iterator is created lazily on the first request.
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)

        self.count += 1
        try:
            batch = next(self.dataloader_iter)
        except StopIteration:
            # Exhausted: reset the counter and serve the first batch of a fresh pass.
            self.count = 0
            self.dataloader_iter = iter(self.dataloader)
            batch = next(self.dataloader_iter)
        return batch
|
|
|
|
|
|
def is_rank_0() -> bool:
    """
    Tell whether this process should be treated as rank 0.

    Returns:
        bool: True when no process group has been initialized, or when the
            current process has rank 0; False otherwise.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    # Without an initialized process group there is only this process.
    return True
|
|
|
|
|
|
def to_device(x: Any, device: torch.device) -> Any:
    """
    Transfer every tensor inside a (possibly nested) structure to ``device``.

    Args:
        x (Any): A tensor or a nested structure containing tensors.
        device (torch.device): The device the tensors are moved to.

    Returns:
        Any: A structure of the same layout in which each tensor has been moved
            to ``device``; non-tensor leaves are passed through untouched.
    """

    def _move_leaf(leaf: Any):
        # Only tensors are moved; every other leaf is returned as-is.
        return leaf.to(device) if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move_leaf, x)
|
|
|
|
|
|
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
    """
    Average a tensor across processes with a sum all-reduce followed by a division.

    The tensor is modified in place and the same object is returned.

    Args:
        tensor (torch.Tensor): The tensor to be reduced.
        plugin (Plugin, optional): When given, the reduction runs over ``plugin.dp_group``
            and the sum is divided by ``plugin.dp_size``. When None, the reduction runs
            over the default process group and the sum is divided by the world size.

    Returns:
        torch.Tensor: The input tensor, now holding the mean across the participating processes.
    """
    use_plugin = plugin is not None

    # Sum over the plugin's DP group if a plugin is supplied, else over the default group.
    group = plugin.dp_group if use_plugin else None
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=group)

    # Turn the sum into a mean.
    divisor = plugin.dp_size if use_plugin else dist.get_world_size()
    tensor.div_(divisor)
    return tensor
|
|
|
|
|
|
def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
    """
    Sum a tensor across processes with an all-reduce.

    The tensor is modified in place and the same object is returned.

    Args:
        tensor (torch.Tensor): The tensor to be reduced.
        plugin (Plugin, optional): When given, the reduction runs over ``plugin.dp_group``.
            When None, the reduction runs over the default process group.

    Returns:
        torch.Tensor: The input tensor, now holding the sum across the participating processes.
    """
    # ``group=None`` selects the default process group.
    group = None if plugin is None else plugin.dp_group
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=group)
    return tensor
|