support PP training

This commit is contained in:
Tong Li
2025-04-01 11:24:09 +08:00
parent 6604654b2f
commit d961a5f725
5 changed files with 121 additions and 69 deletions

View File

@@ -54,7 +54,6 @@ class BaseConsumer:
self.model_config = model_config
self.plugin_config = plugin_config
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
self.device = get_current_device()
self.lr_scheduler = None

View File

@@ -96,7 +96,7 @@ class GRPOConsumer(BaseConsumer):
self.global_step = 0
if use_wandb and self.rank == 0:
name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
self.lr_scheduler = CosineAnnealingWarmupLR(
optimizer=self.optimizer,
@@ -168,72 +168,120 @@ class GRPOConsumer(BaseConsumer):
).repeat_interleave(self.num_generations, dim=0)
)
mean_kl, mean_loss = [], []
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
action_mask_forward_micro_batch = action_mask[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
loss_mask_forward_micro_batch = (
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
if loss_mask is not None
else None
)
advantages_forward_micro_batch = advantages[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
action_log_probs = calc_action_log_probs(
policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
if self.plugin.pp_size > 1:
# Support training with PP.
data_iter = iter([data])
with torch.no_grad():
reference_model_logits = self.reference_model(
reference_model_outputs = self.booster.execute_pipeline(
data_iter,
self.reference_model,
criterion=lambda outputs, inputs: outputs.logits.mean(), # dummy criterion
optimizer=None,
return_loss=False,
return_outputs=True,
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
data["input_ids"],
num_action,
self.plugin.shard_config,
)
else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = torch.zeros(
(old_action_log_probs.size(0), old_action_log_probs.size(1))
)
data["reference_action_log_probs"] = reference_action_log_probs
data_iter = iter([data])
def _criterion(outputs, inputs):
pass
outputs = self.booster.execute_pipeline(
data_iter,
self.policy_model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
)
loss = outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
loss = all_reduce_mean(loss, self.plugin)
mean_loss.append(loss.data)
else:
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
action_mask_forward_micro_batch = action_mask[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
loss_mask_forward_micro_batch = (
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
if loss_mask is not None
else None
)
advantages_forward_micro_batch = advantages[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
action_log_probs = calc_action_log_probs(
policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
if not skip_update:
self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss, self.plugin)
kl = all_reduce_mean(kl.mean(), self.plugin)
# Calculate accumulate value.
mean_kl.append(kl.data)
mean_loss.append(loss.data)
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch,
)
if not skip_update:
self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss, self.plugin)
kl = all_reduce_mean(kl.mean(), self.plugin)
# Calculate accumulate value.
mean_kl.append(kl.data)
mean_loss.append(loss.data)
reward = all_reduce_mean(reward.mean(), self.plugin)
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)

View File

@@ -31,7 +31,13 @@ if __name__ == "__main__":
default=1,
help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.")
parser.add_argument(
"-tmbs",
"--train-microbatch-size",
type=int,
default=2,
help="Number of samples per device. PP micro batchsize when PP is activated.",
)
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
args = parser.parse_args()
@@ -45,11 +51,7 @@ if __name__ == "__main__":
ray.init(address="local", namespace="ray-example")
inference_model_config = dict(path=args.model)
train_model_config = dict(
path=args.model,
# use_flash_attention_2=True,
# use_cache=False
)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
if args.backend == "transformers":
@@ -107,7 +109,7 @@ if __name__ == "__main__":
generate_config=generate_config,
num_generations=args.num_generations,
train_model_config=train_model_config,
plugin_config={},
plugin_config={"pp_size": 2, "tp_size": 1, "microbatch_size": 2, "zero_stage": 0},
inference_backend=args.backend,
master_addr="localhost",
master_port=29505,

View File

@@ -1411,8 +1411,10 @@ class HybridParallelPlugin(PipelinePluginBase):
)
# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
if (
not torch.is_grad_enabled()
or model.require_grad_sync == False
or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
):
return outputs

View File

@@ -284,6 +284,7 @@ class Qwen2PipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**kwargs,
):
r"""
Args: