[Distributed RLHF] Integration of PP (#6257)

* update help information

* update style

* fix

* minor fix

* support PP training

* add pp support

* remove unused code

* address conversation

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Author: YeAnbang
Date: 2025-04-09 13:23:24 +08:00
Committed by: YeAnbang
Parent: 12da4d14aa
Commit: 5d79b9e692
7 changed files with 263 additions and 116 deletions


@@ -1412,8 +1412,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         )

         # run with gradients accumulation
-        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        if (
+            not torch.is_grad_enabled()
+            or model.require_grad_sync == False
+            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
         ):
             return outputs
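
A minimal sketch of how the new torch.is_grad_enabled() guard is expected to be exercised: when the pipeline schedule runs under torch.no_grad() (e.g. rollout or evaluation in RLHF), execute_pipeline now returns the outputs early instead of falling through to gradient synchronization. The booster, model, criterion, optimizer, and eval_dataloader objects below are assumed to come from a standard HybridParallelPlugin/Booster setup and are not part of this diff.

import torch

# Sketch only: `booster`, `model`, `criterion`, `optimizer`, and
# `eval_dataloader` are assumed from the usual HybridParallelPlugin setup.
model.eval()
with torch.no_grad():
    outputs = booster.execute_pipeline(
        iter(eval_dataloader),
        model,
        criterion=criterion,
        optimizer=optimizer,
        return_loss=True,
    )
# With autograd disabled, the `not torch.is_grad_enabled()` branch returns
# immediately, skipping the gradient synchronization that follows it.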


@@ -275,6 +275,7 @@ class Qwen2PipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
+        **kwargs,
     ):
         r"""
         Args:
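
The added **kwargs lets the sharded pipeline forward accept extra keyword arguments forwarded by callers under pipeline parallelism without raising a TypeError. A minimal, self-contained sketch of that pattern follows; the function and argument names are illustrative, not the actual Qwen2PipelineForwards code.

from typing import List, Optional

def pipeline_stage_forward(
    hidden_states: Optional[object] = None,
    stage_index: Optional[List[int]] = None,
    shard_config: object = None,
    **kwargs,  # absorbs keyword args this stage does not consume
):
    # Unused kwargs (e.g. generation or cache flags forwarded by the caller)
    # are simply ignored instead of crashing the call.
    return hidden_states

# A caller that forwards an extra keyword argument now succeeds:
pipeline_stage_forward(hidden_states=None, stage_index=[0, 1], some_extra_flag=True)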