[Feature] Support Distributed LogProb for GRPO Training (#6247)

* [fix] fix qwen VocabParallelLMHead1D and gather output

* fix tp bug

* fix consumer

* [feat] Support Distributed LogProb for GRPO Training

* [fix] fix loss func

* [fix] fix log prob plugin

* [fix] fix qwen modeling param

* [fix] rm comments

* [fix] rm hard-code;fix non-dist version

* [fix] fix test file param name and benchmark tp gather output=True/False

* [fix] rm non-dist version in dist log prob

* [fix] fix comments

* [fix] fix dis log prob plugin

* [fix] fix test case

* [fix] fix qwen VocabParallelLMHead1D and gather output

* [fix] fix DistLogProb comments

* [fix] restore tp size

* [fix] fix comments

* [fix] fix comment; fix LogSoftmax usage

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
duanjunwen
2025-03-18 17:47:55 +08:00
committed by GitHub
parent bc0171d392
commit 7795d4c50d
8 changed files with 233 additions and 12 deletions

View File

@@ -2,6 +2,8 @@ from typing import Any, Dict, List
import torch
from colossalai.shardformer.layer.loss import dist_log_prob
def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
batches = []
@@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
return per_label_logps.squeeze(-1)
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
def calc_action_log_probs(
logits: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int,
shard_config,
vocab_size: int = None,
) -> torch.Tensor:
"""Calculate action log probs.
Args:
output (torch.Tensor): Output tensor of Actor.forward.logits.
logits (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
shard_config
vocab_size
Returns:
torch.Tensor: Action log probs.
"""
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
# labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
# logits: torch.Tensor, # [B, S, Vocab_size]
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
log_probs = log_probs.squeeze(-1)
return log_probs[:, -num_actions:]