Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[feat] Support DAPO (#6263)
* update help information
* update style
* fix
* minor fix
* support PP training
* add pp support
* remove unused code
* address conversation
* fix memory leakage; support tp+pp
* move empty cache
* move empty cache
* add DAPO support
* remove format reward
* fix filtering, still buggy
* small fix
* add DAPO support
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* tested multi-node training; fix bind_batch bug
* fix conversation; support sleep mode
* support reusing excessive samples
* add dynamic batching control flag
* add dynamic batching control flag
* refactored
* fix logging

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from typing import Any, Dict, List
 
 import torch
@@ -26,6 +27,27 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
     return batch
 
 
+def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
+    max_len = defaultdict(int)
+    for sample in batches:
+        for k in sample:
+            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
+                max_len[k] = max(max_len[k], sample[k].size(-1))
+    for idx, sample in enumerate(batches):
+        for k in sample:
+            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
+                # right pad with 0s
+                if k in ["attention_mask", "action_mask"]:
+                    batches[idx][k] = torch.nn.functional.pad(
+                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
+                    )
+                else:
+                    batches[idx][k] = torch.nn.functional.pad(
+                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
+                    )
+    return batches
+
+
 def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     # compress mask to save bandwidth
     if "attention_mask" in batch:
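For illustration, here is a minimal, self-contained Python sketch (not taken from the repository) of what the right-padding added in pad_batch does; the keys, shapes, and values below are assumptions chosen only to show the behavior.

import torch
import torch.nn.functional as F

# Two samples of different sequence length, using the same keys pad_batch looks at.
samples = [
    {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.ones(1, 3, dtype=torch.bool)},
    {"input_ids": torch.tensor([[4, 5, 6, 7, 8]]), "attention_mask": torch.ones(1, 5, dtype=torch.bool)},
]

max_len = max(s["input_ids"].size(-1) for s in samples)
for s in samples:
    pad = max_len - s["input_ids"].size(-1)
    # ids are right-padded with 0, masks with False, up to the longest sample
    s["input_ids"] = F.pad(s["input_ids"], (0, pad), "constant", 0)
    s["attention_mask"] = F.pad(s["attention_mask"], (0, pad), "constant", False)

print([tuple(s["input_ids"].shape) for s in samples])  # [(1, 5), (1, 5)]

Padding every selected tensor to a common length is what allows samples of different response lengths to be combined afterwards (for example by bind_batch, whose signature appears in the hunk header above).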
@@ -113,3 +135,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
     mask_sum = mask.sum(dim=dim)
     mean = tensor / (mask_sum + 1e-8)
     return mean
+
+
+def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
+    """
+    Compute the masked sum of a tensor along a specified dimension.
+
+    Args:
+        tensor (torch.Tensor): The input tensor.
+        mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
+        dim (int, optional): The dimension along which to compute the sum. Default is 1.
+
+    Returns:
+        torch.Tensor: The masked sum tensor.
+
+    """
+    tensor = tensor * mask
+    return tensor.sum(dim=dim)
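For reference, a toy check (made-up values, not from the repository's tests) of what the new masked_sum computes: masked positions are zeroed before the sum, so tokens that were right-padded by pad_batch contribute nothing. Up to the 1e-8 stabilizer, dividing the same quantity by mask.sum(dim) matches the masked mean computed by masked_mean above.

import torch

def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # single-expression equivalent of the function added in the diff above
    return (tensor * mask).sum(dim=dim)

per_token_loss = torch.tensor([[0.5, 0.5, 0.5, 0.5]])
action_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # only the first two tokens are valid

print(masked_sum(per_token_loss, action_mask))                            # tensor([1.])
print(masked_sum(per_token_loss, action_mask) / action_mask.sum(dim=1))   # tensor([0.5000]), the masked mean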