[Feature] Support Distributed LogProb for GRPO Training (#6247)

* [fix] fix qwen VocabParallelLMHead1D and gather output

* fix tp bug

* fix consumer

* [feat] Support Distributed LogProb for GRPO Training

* [fix] fix loss func

* [fix] fix log prob plugin

* [fix] fix qwen modeling param

* [fix] rm comments

* [fix] rm hard-code;fix non-dist version

* [fix] fix test file param name and benchmark tp gather output=True/False

* [fix] rm non-dist version in dist log prob

* [fix] fix comments

* [fix] fix dis log prob plugin

* [fix] fix test case

* [fix] fix qwen VocabParallelLMHead1D and gather output

* [fix] fix DistLogProb comments

* [fix] restore tp size

* [fix] fix comments

* [fix] fix comment; fix LogSoftmax usage

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
duanjunwen
2025-03-18 17:47:55 +08:00
committed by YeAnbang
parent 35dabd718e
commit 455185345e
8 changed files with 233 additions and 12 deletions

View File

@@ -823,7 +823,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output